Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-19 20:18:22 -05:00)

Compare commits: 40 commits, make-old-w... to claude/tes...
Commits (SHA1):
edba0c5ca6
3f29f71dd6
00207eb4c9
4a52b7eca0
97847f59f7
22ca8955c5
43cbe2e011
a318832414
843c487500
47a3a5ef41
ec00aa951a
36fb1ea004
a81ac150da
49ee087496
fc25e008b3
b0855e8cf2
5e2146dd76
103a62c9da
fc8434fb30
3ae08cd48e
4db13837b9
df87867625
e503126170
7ee28197a3
818de26d24
cb08def96c
ac2daee5f8
266e0d79d4
01f443190e
bdba0033de
b87c64ce38
003affca43
290d0d9a9b
fba61c72ed
79d45a15d0
66f0d97ca2
5894a8fcdf
dff8efa35d
e26822998f
4a7bc006a8
.branchlet.json (new file, 37 lines)
@@ -0,0 +1,37 @@
{
  "worktreeCopyPatterns": [
    ".env*",
    ".vscode/**",
    ".auth/**",
    ".claude/**",
    "autogpt_platform/.env*",
    "autogpt_platform/backend/.env*",
    "autogpt_platform/frontend/.env*",
    "autogpt_platform/frontend/.auth/**",
    "autogpt_platform/db/docker/.env*"
  ],
  "worktreeCopyIgnores": [
    "**/node_modules/**",
    "**/dist/**",
    "**/.git/**",
    "**/Thumbs.db",
    "**/.DS_Store",
    "**/.next/**",
    "**/__pycache__/**",
    "**/.ruff_cache/**",
    "**/.pytest_cache/**",
    "**/*.pyc",
    "**/playwright-report/**",
    "**/logs/**",
    "**/site/**"
  ],
  "worktreePathTemplate": "$BASE_PATH.worktree",
  "postCreateCmd": [
    "cd autogpt_platform/autogpt_libs && poetry install",
    "cd autogpt_platform/backend && poetry install && poetry run prisma generate",
    "cd autogpt_platform/frontend && pnpm install",
    "cd docs && pip install -r requirements.txt"
  ],
  "terminalCommand": "code .",
  "deleteBranchWithWorktree": false
}
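The copy patterns and ignores above are plain globs. A minimal sketch of how such an include/ignore pair could be applied when seeding a new worktree; the glob semantics here are an assumption for illustration, not Branchlet's documented behavior:

# Hypothetical helper: list repo files selected by worktreeCopyPatterns,
# excluding anything matched by worktreeCopyIgnores. Pattern semantics are
# assumed (fnmatch-style), not taken from the Branchlet tool itself.
import fnmatch
import json
from pathlib import Path

def files_to_copy(repo_root: str, config_path: str = ".branchlet.json") -> list[Path]:
    root = Path(repo_root)
    cfg = json.loads((root / config_path).read_text())
    selected: list[Path] = []
    for path in root.rglob("*"):
        if not path.is_file():
            continue
        rel = path.relative_to(root).as_posix()
        if not any(fnmatch.fnmatch(rel, p) for p in cfg["worktreeCopyPatterns"]):
            continue
        if any(fnmatch.fnmatch(rel, p) for p in cfg["worktreeCopyIgnores"]):
            continue  # caches, build output, VCS internals
        selected.append(path)
    return selected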
@@ -16,6 +16,7 @@
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py

# Platform - Market
!autogpt_platform/market/market/
.github/workflows/classic-autogpt-ci.yml (2 changed lines)
@@ -29,7 +29,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.12", "3.13", "3.14"]
        python-version: ["3.10"]
        platform-os: [ubuntu, macos, macos-arm64, windows]
    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
.github/workflows/classic-autogpts-ci.yml (13 changed lines)
@@ -11,6 +11,9 @@ on:
      - 'classic/original_autogpt/**'
      - 'classic/forge/**'
      - 'classic/benchmark/**'
      - 'classic/run'
      - 'classic/cli.py'
      - 'classic/setup.py'
      - '!**/*.md'
  pull_request:
    branches: [ master, dev, release-* ]
@@ -19,6 +22,9 @@ on:
      - 'classic/original_autogpt/**'
      - 'classic/forge/**'
      - 'classic/benchmark/**'
      - 'classic/run'
      - 'classic/cli.py'
      - 'classic/setup.py'
      - '!**/*.md'

defaults:
@@ -53,15 +59,10 @@
        run: |
          curl -sSL https://install.python-poetry.org | python -

      - name: Install dependencies
        working-directory: ./classic/${{ matrix.agent-name }}/
        run: poetry install

      - name: Run regression tests
        run: |
          ./run agent start ${{ matrix.agent-name }}
          cd ${{ matrix.agent-name }}
          poetry run serve &
          sleep 10 # Wait for server to start
          poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
          poetry run agbenchmark --test=WriteFile
        env:
.github/workflows/classic-benchmark-ci.yml (11 changed lines)
@@ -23,7 +23,7 @@ defaults:
    shell: bash

env:
  min-python-version: '3.12'
  min-python-version: '3.10'

jobs:
  test:
@@ -33,7 +33,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.12", "3.13", "3.14"]
        python-version: ["3.10"]
        platform-os: [ubuntu, macos, macos-arm64, windows]
    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
    defaults:
@@ -128,16 +128,11 @@
        run: |
          curl -sSL https://install.python-poetry.org | python -

      - name: Install agent dependencies
        working-directory: classic/${{ matrix.agent-name }}
        run: poetry install

      - name: Run regression tests
        working-directory: classic
        run: |
          ./run agent start ${{ matrix.agent-name }}
          cd ${{ matrix.agent-name }}
          poetry run python -m forge &
          sleep 10 # Wait for server to start

          set +e # Ignore non-zero exit codes and continue execution
          echo "Running the following command: poetry run agbenchmark --maintain --mock"
.github/workflows/classic-forge-ci.yml (2 changed lines)
@@ -31,7 +31,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.12", "3.13", "3.14"]
        python-version: ["3.10"]
        platform-os: [ubuntu, macos, macos-arm64, windows]
    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
.github/workflows/classic-frontend-ci.yml (new file, 60 lines)
@@ -0,0 +1,60 @@
name: Classic - Frontend CI/CD

on:
  push:
    branches:
      - master
      - dev
      - 'ci-test*' # This will match any branch that starts with "ci-test"
    paths:
      - 'classic/frontend/**'
      - '.github/workflows/classic-frontend-ci.yml'
  pull_request:
    paths:
      - 'classic/frontend/**'
      - '.github/workflows/classic-frontend-ci.yml'

jobs:
  build:
    permissions:
      contents: write
      pull-requests: write
    runs-on: ubuntu-latest
    env:
      BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}

    steps:
      - name: Checkout Repo
        uses: actions/checkout@v4

      - name: Setup Flutter
        uses: subosito/flutter-action@v2
        with:
          flutter-version: '3.13.2'

      - name: Build Flutter to Web
        run: |
          cd classic/frontend
          flutter build web --base-href /app/

      # - name: Commit and Push to ${{ env.BUILD_BRANCH }}
      #   if: github.event_name == 'push'
      #   run: |
      #     git config --local user.email "action@github.com"
      #     git config --local user.name "GitHub Action"
      #     git add classic/frontend/build/web
      #     git checkout -B ${{ env.BUILD_BRANCH }}
      #     git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
      #     git push -f origin ${{ env.BUILD_BRANCH }}

      - name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
        if: github.event_name == 'push'
        uses: peter-evans/create-pull-request@v7
        with:
          add-paths: classic/frontend/build/web
          base: ${{ github.ref_name }}
          branch: ${{ env.BUILD_BRANCH }}
          delete-branch: true
          title: "Update frontend build in `${{ github.ref_name }}`"
          body: "This PR updates the frontend build based on commit ${{ github.sha }}."
          commit-message: "Update frontend build based on commit ${{ github.sha }}"
.github/workflows/classic-python-checks.yml (4 changed lines)
@@ -59,7 +59,7 @@ jobs:
    needs: get-changed-parts
    runs-on: ubuntu-latest
    env:
      min-python-version: "3.12"
      min-python-version: "3.10"

    strategy:
      matrix:
@@ -111,7 +111,7 @@ jobs:
    needs: get-changed-parts
    runs-on: ubuntu-latest
    env:
      min-python-version: "3.12"
      min-python-version: "3.10"

    strategy:
      matrix:
.github/workflows/claude-dependabot.yml (2 changed lines)
@@ -74,7 +74,7 @@ jobs:

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
        run: poetry run prisma generate
        run: poetry run prisma generate && poetry run gen-prisma-stub

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
.github/workflows/claude.yml (2 changed lines)
@@ -90,7 +90,7 @@ jobs:

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
        run: poetry run prisma generate
        run: poetry run prisma generate && poetry run gen-prisma-stub

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
.github/workflows/copilot-setup-steps.yml (12 changed lines)
@@ -72,7 +72,7 @@ jobs:

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
        run: poetry run prisma generate
        run: poetry run prisma generate && poetry run gen-prisma-stub

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
@@ -108,6 +108,16 @@
      #   run: pnpm playwright install --with-deps chromium

      # Docker setup for development environment
      - name: Free up disk space
        run: |
          # Remove large unused tools to free disk space for Docker builds
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL
          sudo docker system prune -af
          df -h

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
.github/workflows/platform-backend-ci.yml (2 changed lines)
@@ -134,7 +134,7 @@ jobs:
        run: poetry install

      - name: Generate Prisma Client
        run: poetry run prisma generate
        run: poetry run prisma generate && poetry run gen-prisma-stub

      - id: supabase
        name: Start Supabase
.gitignore (3 changed lines)
@@ -3,7 +3,6 @@
classic/original_autogpt/keys.py
classic/original_autogpt/*.json
auto_gpt_workspace/*
.autogpt/
*.mpeg
.env
# Root .env files
@@ -178,5 +177,5 @@ autogpt_platform/backend/settings.py

*.ign.*
.test-contents
**/.claude/settings.local.json
.claude/settings.local.json
/autogpt_platform/backend/logs
@@ -12,6 +12,7 @@ reset-db:
	rm -rf db/docker/volumes/db/data
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate
	cd backend && poetry run gen-prisma-stub

# View logs for core services
logs-core:
@@ -33,6 +34,7 @@ init-env:
migrate:
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate
	cd backend && poetry run gen-prisma-stub

run-backend:
	cd backend && poetry run app
@@ -48,7 +48,8 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub

FROM debian:13-slim AS server_dependencies
@@ -489,7 +489,7 @@ async def update_agent_version_in_library(
    agent_graph_version: int,
) -> library_model.LibraryAgent:
    """
    Updates the agent version in the library if useGraphIsActiveVersion is True.
    Updates the agent version in the library for any agent owned by the user.

    Args:
        user_id: Owner of the LibraryAgent.
@@ -498,20 +498,31 @@ async def update_agent_version_in_library(

    Raises:
        DatabaseError: If there's an error with the update.
        NotFoundError: If no library agent is found for this user and agent.
    """
    logger.debug(
        f"Updating agent version in library for user #{user_id}, "
        f"agent #{agent_graph_id} v{agent_graph_version}"
    )
    try:
        library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
    async with transaction() as tx:
        library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise(
            where={
                "userId": user_id,
                "agentGraphId": agent_graph_id,
                "useGraphIsActiveVersion": True,
            },
        )
        lib = await prisma.models.LibraryAgent.prisma().update(

        # Delete any conflicting LibraryAgent for the target version
        await prisma.models.LibraryAgent.prisma(tx).delete_many(
            where={
                "userId": user_id,
                "agentGraphId": agent_graph_id,
                "agentGraphVersion": agent_graph_version,
                "id": {"not": library_agent.id},
            }
        )

        lib = await prisma.models.LibraryAgent.prisma(tx).update(
            where={"id": library_agent.id},
            data={
                "AgentGraph": {
@@ -525,13 +536,13 @@ async def update_agent_version_in_library(
            },
            include={"AgentGraph": True},
        )
    if lib is None:
        raise NotFoundError(f"Library agent {library_agent.id} not found")

        return library_model.LibraryAgent.from_db(lib)
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error updating agent version in library: {e}")
        raise DatabaseError("Failed to update agent version in library") from e
    if lib is None:
        raise NotFoundError(
            f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
        )

    return library_model.LibraryAgent.from_db(lib)


async def update_library_agent(
@@ -825,6 +836,7 @@ async def add_store_agent_to_library(
                }
            },
            "isCreatedByUser": False,
            "useGraphIsActiveVersion": False,
            "settings": SafeJson(
                _initialize_graph_settings(graph_model).model_dump()
            ),
@@ -48,6 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
    id: str
    graph_id: str
    graph_version: int
    owner_user_id: str  # ID of user who owns/created this agent graph

    image_url: str | None

@@ -163,6 +164,7 @@ class LibraryAgent(pydantic.BaseModel):
            id=agent.id,
            graph_id=agent.agentGraphId,
            graph_version=agent.agentGraphVersion,
            owner_user_id=agent.userId,
            image_url=agent.imageUrl,
            creator_name=creator_name,
            creator_image_url=creator_image_url,
@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
            id="test-agent-1",
            graph_id="test-agent-1",
            graph_version=1,
            owner_user_id=test_user_id,
            name="Test Agent 1",
            description="Test Description 1",
            image_url=None,
@@ -64,6 +65,7 @@ async def test_get_library_agents_success(
            id="test-agent-2",
            graph_id="test-agent-2",
            graph_version=1,
            owner_user_id=test_user_id,
            name="Test Agent 2",
            description="Test Description 2",
            image_url=None,
@@ -138,6 +140,7 @@ async def test_get_favorite_library_agents_success(
            id="test-agent-1",
            graph_id="test-agent-1",
            graph_version=1,
            owner_user_id=test_user_id,
            name="Favorite Agent 1",
            description="Test Favorite Description 1",
            image_url=None,
@@ -205,6 +208,7 @@ def test_add_agent_to_library_success(
            id="test-library-agent-id",
            graph_id="test-agent-1",
            graph_version=1,
            owner_user_id=test_user_id,
            name="Test Agent 1",
            description="Test Description 1",
            image_url=None,
@@ -614,6 +614,7 @@ async def get_store_submissions(
        submission_models = []
        for sub in submissions:
            submission_model = store_model.StoreSubmission(
                listing_id=sub.listing_id,
                agent_id=sub.agent_id,
                agent_version=sub.agent_version,
                name=sub.name,
@@ -667,35 +668,48 @@ async def delete_store_submission(
    submission_id: str,
) -> bool:
    """
    Delete a store listing submission as the submitting user.
    Delete a store submission version as the submitting user.

    Args:
        user_id: ID of the authenticated user
        submission_id: ID of the submission to be deleted
        submission_id: StoreListingVersion ID to delete

    Returns:
        bool: True if the submission was successfully deleted, False otherwise
        bool: True if successfully deleted
    """
    logger.debug(f"Deleting store submission {submission_id} for user {user_id}")

    try:
        # Verify the submission belongs to this user
        submission = await prisma.models.StoreListing.prisma().find_first(
            where={"agentGraphId": submission_id, "owningUserId": user_id}
        # Find the submission version with ownership check
        version = await prisma.models.StoreListingVersion.prisma().find_first(
            where={"id": submission_id}, include={"StoreListing": True}
        )

        if not submission:
            logger.warning(f"Submission not found for user {user_id}: {submission_id}")
            raise store_exceptions.SubmissionNotFoundError(
                f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
        if (
            not version
            or not version.StoreListing
            or version.StoreListing.owningUserId != user_id
        ):
            raise store_exceptions.SubmissionNotFoundError("Submission not found")

        # Prevent deletion of approved submissions
        if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
            raise store_exceptions.InvalidOperationError(
                "Cannot delete approved submissions"
            )

        # Delete the submission
        await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})

        logger.debug(
            f"Successfully deleted submission {submission_id} for user {user_id}"
        # Delete the version
        await prisma.models.StoreListingVersion.prisma().delete(
            where={"id": version.id}
        )

        # Clean up empty listing if this was the last version
        remaining = await prisma.models.StoreListingVersion.prisma().count(
            where={"storeListingId": version.storeListingId}
        )
        if remaining == 0:
            await prisma.models.StoreListing.prisma().delete(
                where={"id": version.storeListingId}
            )

        return True

    except Exception as e:
@@ -759,9 +773,15 @@ async def create_store_submission(
            logger.warning(
                f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
            )
            raise store_exceptions.AgentNotFoundError(
                f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
            )
            # Provide more user-friendly error message when agent_id is empty
            if not agent_id or agent_id.strip() == "":
                raise store_exceptions.AgentNotFoundError(
                    "No agent selected. Please select an agent before submitting to the store."
                )
            else:
                raise store_exceptions.AgentNotFoundError(
                    f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
                )

        # Check if listing already exists for this agent
        existing_listing = await prisma.models.StoreListing.prisma().find_first(
@@ -833,6 +853,7 @@ async def create_store_submission(
        logger.debug(f"Created store listing for agent {agent_id}")
        # Return submission details
        return store_model.StoreSubmission(
            listing_id=listing.id,
            agent_id=agent_id,
            agent_version=agent_version,
            name=name,
@@ -944,81 +965,56 @@ async def edit_store_submission(
        # Currently we are not allowing user to update the agent associated with a submission
        # If we allow it in future, then we need a check here to verify the agent belongs to this user.

        # Check if we can edit this submission
        if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
        # Only allow editing of PENDING submissions
        if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING:
            raise store_exceptions.InvalidOperationError(
                "Cannot edit a rejected submission"
            )

        # For APPROVED submissions, we need to create a new version
        if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
            # Create a new version for the existing listing
            return await create_store_version(
                user_id=user_id,
                agent_id=current_version.agentGraphId,
                agent_version=current_version.agentGraphVersion,
                store_listing_id=current_version.storeListingId,
                name=name,
                video_url=video_url,
                agent_output_demo_url=agent_output_demo_url,
                image_urls=image_urls,
                description=description,
                sub_heading=sub_heading,
                categories=categories,
                changes_summary=changes_summary,
                recommended_schedule_cron=recommended_schedule_cron,
                instructions=instructions,
                f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited."
            )

        # For PENDING submissions, we can update the existing version
        elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING:
            # Update the existing version
            updated_version = await prisma.models.StoreListingVersion.prisma().update(
                where={"id": store_listing_version_id},
                data=prisma.types.StoreListingVersionUpdateInput(
                    name=name,
                    videoUrl=video_url,
                    agentOutputDemoUrl=agent_output_demo_url,
                    imageUrls=image_urls,
                    description=description,
                    categories=categories,
                    subHeading=sub_heading,
                    changesSummary=changes_summary,
                    recommendedScheduleCron=recommended_schedule_cron,
                    instructions=instructions,
                ),
            )

            logger.debug(
                f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
            )

            if not updated_version:
                raise DatabaseError("Failed to update store listing version")
            return store_model.StoreSubmission(
                agent_id=current_version.agentGraphId,
                agent_version=current_version.agentGraphVersion,
        # Update the existing version
        updated_version = await prisma.models.StoreListingVersion.prisma().update(
            where={"id": store_listing_version_id},
            data=prisma.types.StoreListingVersionUpdateInput(
                name=name,
                sub_heading=sub_heading,
                slug=current_version.StoreListing.slug,
                videoUrl=video_url,
                agentOutputDemoUrl=agent_output_demo_url,
                imageUrls=image_urls,
                description=description,
                instructions=instructions,
                image_urls=image_urls,
                date_submitted=updated_version.submittedAt or updated_version.createdAt,
                status=updated_version.submissionStatus,
                runs=0,
                rating=0.0,
                store_listing_version_id=updated_version.id,
                changes_summary=changes_summary,
                video_url=video_url,
                categories=categories,
                version=updated_version.version,
            )
                subHeading=sub_heading,
                changesSummary=changes_summary,
                recommendedScheduleCron=recommended_schedule_cron,
                instructions=instructions,
            ),
        )

        else:
            raise store_exceptions.InvalidOperationError(
                f"Cannot edit submission with status: {current_version.submissionStatus}"
            )
        logger.debug(
            f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
        )

        if not updated_version:
            raise DatabaseError("Failed to update store listing version")
        return store_model.StoreSubmission(
            listing_id=current_version.StoreListing.id,
            agent_id=current_version.agentGraphId,
            agent_version=current_version.agentGraphVersion,
            name=name,
            sub_heading=sub_heading,
            slug=current_version.StoreListing.slug,
            description=description,
            instructions=instructions,
            image_urls=image_urls,
            date_submitted=updated_version.submittedAt or updated_version.createdAt,
            status=updated_version.submissionStatus,
            runs=0,
            rating=0.0,
            store_listing_version_id=updated_version.id,
            changes_summary=changes_summary,
            video_url=video_url,
            categories=categories,
            version=updated_version.version,
        )

    except (
        store_exceptions.SubmissionNotFoundError,
@@ -1097,38 +1093,78 @@ async def create_store_version(
                f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
            )

        # Get the latest version number
        latest_version = listing.Versions[0] if listing.Versions else None

        next_version = (latest_version.version + 1) if latest_version else 1

        # Create a new version for the existing listing
        new_version = await prisma.models.StoreListingVersion.prisma().create(
            data=prisma.types.StoreListingVersionCreateInput(
                version=next_version,
                agentGraphId=agent_id,
                agentGraphVersion=agent_version,
                name=name,
                videoUrl=video_url,
                agentOutputDemoUrl=agent_output_demo_url,
                imageUrls=image_urls,
                description=description,
                instructions=instructions,
                categories=categories,
                subHeading=sub_heading,
                submissionStatus=prisma.enums.SubmissionStatus.PENDING,
                submittedAt=datetime.now(),
                changesSummary=changes_summary,
                recommendedScheduleCron=recommended_schedule_cron,
                storeListingId=store_listing_id,
        # Check if there's already a PENDING submission for this agent (any version)
        existing_pending_submission = (
            await prisma.models.StoreListingVersion.prisma().find_first(
                where=prisma.types.StoreListingVersionWhereInput(
                    storeListingId=store_listing_id,
                    agentGraphId=agent_id,
                    submissionStatus=prisma.enums.SubmissionStatus.PENDING,
                    isDeleted=False,
                )
            )
        )

        # Handle existing pending submission and create new one atomically
        async with transaction() as tx:
            # Get the latest version number first
            latest_listing = await prisma.models.StoreListing.prisma(tx).find_first(
                where=prisma.types.StoreListingWhereInput(
                    id=store_listing_id, owningUserId=user_id
                ),
                include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
            )

            if not latest_listing:
                raise store_exceptions.ListingNotFoundError(
                    f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
                )

            latest_version = (
                latest_listing.Versions[0] if latest_listing.Versions else None
            )
            next_version = (latest_version.version + 1) if latest_version else 1

            # If there's an existing pending submission, delete it atomically before creating new one
            if existing_pending_submission:
                logger.info(
                    f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate"
                )
                await prisma.models.StoreListingVersion.prisma(tx).delete(
                    where={"id": existing_pending_submission.id}
                )
                logger.debug(
                    f"Deleted existing pending submission {existing_pending_submission.id}"
                )

            # Create a new version for the existing listing
            new_version = await prisma.models.StoreListingVersion.prisma(tx).create(
                data=prisma.types.StoreListingVersionCreateInput(
                    version=next_version,
                    agentGraphId=agent_id,
                    agentGraphVersion=agent_version,
                    name=name,
                    videoUrl=video_url,
                    agentOutputDemoUrl=agent_output_demo_url,
                    imageUrls=image_urls,
                    description=description,
                    instructions=instructions,
                    categories=categories,
                    subHeading=sub_heading,
                    submissionStatus=prisma.enums.SubmissionStatus.PENDING,
                    submittedAt=datetime.now(),
                    changesSummary=changes_summary,
                    recommendedScheduleCron=recommended_schedule_cron,
                    storeListingId=store_listing_id,
                )
            )

        logger.debug(
            f"Created new version for listing {store_listing_id} of agent {agent_id}"
        )
        # Return submission details
        return store_model.StoreSubmission(
            listing_id=listing.id,
            agent_id=agent_id,
            agent_version=agent_version,
            name=name,
@@ -1708,15 +1744,12 @@ async def review_store_submission(

        # Convert to Pydantic model for consistency
        return store_model.StoreSubmission(
            listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
            agent_id=submission.agentGraphId,
            agent_version=submission.agentGraphVersion,
            name=submission.name,
            sub_heading=submission.subHeading,
            slug=(
                submission.StoreListing.slug
                if hasattr(submission, "storeListing") and submission.StoreListing
                else ""
            ),
            slug=(submission.StoreListing.slug if submission.StoreListing else ""),
            description=submission.description,
            instructions=submission.instructions,
            image_urls=submission.imageUrls or [],
@@ -1818,9 +1851,7 @@ async def get_admin_listings_with_versions(
        where = prisma.types.StoreListingWhereInput(**where_dict)
        include = prisma.types.StoreListingInclude(
            Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
                order_by=prisma.types._StoreListingVersion_version_OrderByInput(
                    version="desc"
                )
                order_by={"version": "desc"}
            ),
            OwningUser=True,
        )
@@ -1845,6 +1876,7 @@ async def get_admin_listings_with_versions(
            # If we have versions, turn them into StoreSubmission models
            for version in listing.Versions or []:
                version_model = store_model.StoreSubmission(
                    listing_id=listing.id,
                    agent_id=version.agentGraphId,
                    agent_version=version.agentGraphVersion,
                    name=version.name,
@@ -110,6 +110,7 @@ class Profile(pydantic.BaseModel):


class StoreSubmission(pydantic.BaseModel):
    listing_id: str
    agent_id: str
    agent_version: int
    name: str
@@ -164,8 +165,12 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel):


class StoreSubmissionRequest(pydantic.BaseModel):
    agent_id: str
    agent_version: int
    agent_id: str = pydantic.Field(
        ..., min_length=1, description="Agent ID cannot be empty"
    )
    agent_version: int = pydantic.Field(
        ..., gt=0, description="Agent version must be greater than 0"
    )
    slug: str
    name: str
    sub_heading: str
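The tightened StoreSubmissionRequest fields reject empty agent IDs and non-positive versions at parse time. A minimal sketch of the expected behavior, trimmed to just the two constrained fields (the remaining request fields are elided):

# Rough sketch: pydantic.Field constraints fail fast, before any store logic runs.
import pydantic

class StoreSubmissionRequest(pydantic.BaseModel):
    agent_id: str = pydantic.Field(..., min_length=1, description="Agent ID cannot be empty")
    agent_version: int = pydantic.Field(..., gt=0, description="Agent version must be greater than 0")

StoreSubmissionRequest(agent_id="agent-123", agent_version=1)  # accepted
try:
    StoreSubmissionRequest(agent_id="", agent_version=0)
except pydantic.ValidationError as err:
    print(err.error_count(), "validation errors")  # both fields fail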
@@ -138,6 +138,7 @@ def test_creator_details():

def test_store_submission():
    submission = store_model.StoreSubmission(
        listing_id="listing123",
        agent_id="agent123",
        agent_version=1,
        sub_heading="Test subheading",
@@ -159,6 +160,7 @@ def test_store_submissions_response():
    response = store_model.StoreSubmissionsResponse(
        submissions=[
            store_model.StoreSubmission(
                listing_id="listing123",
                agent_id="agent123",
                agent_version=1,
                sub_heading="Test subheading",

@@ -521,6 +521,7 @@ def test_get_submissions_success(
    mocked_value = store_model.StoreSubmissionsResponse(
        submissions=[
            store_model.StoreSubmission(
                listing_id="test-listing-id",
                name="Test Agent",
                description="Test agent description",
                image_urls=["test.jpg"],
@@ -39,7 +39,7 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.blocks.llm import LlmModel
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -113,7 +113,7 @@ async def lifespan_context(app: fastapi.FastAPI):

    await backend.data.user.migrate_and_encrypt_user_integrations()
    await backend.data.graph.fix_llm_provider_credentials()
    await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
    await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
    await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

    with launch_darkly_context():
@@ -1,6 +1,7 @@
from typing import Any

from backend.blocks.llm import (
    DEFAULT_LLM_MODEL,
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    AIBlockBase,
@@ -49,7 +50,7 @@ class AIConditionBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.GPT4O,
            default=DEFAULT_LLM_MODEL,
            description="The language model to use for evaluating the condition.",
            advanced=False,
        )
@@ -81,7 +82,7 @@ class AIConditionBlock(AIBlockBase):
                "condition": "the input is an email address",
                "yes_value": "Valid email",
                "no_value": "Not an email",
                "model": LlmModel.GPT4O,
                "model": DEFAULT_LLM_MODEL,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
@@ -6,6 +6,9 @@ import hashlib
import hmac
import logging
from enum import Enum
from typing import cast

from prisma.types import Serializable

from backend.sdk import (
    BaseWebhooksManager,
@@ -84,7 +87,9 @@ class AirtableWebhookManager(BaseWebhooksManager):
        # update webhook config
        await update_webhook(
            webhook.id,
            config={"base_id": base_id, "cursor": response.cursor},
            config=cast(
                dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
            ),
        )

        event_type = "notification"
@@ -182,13 +182,10 @@ class DataForSeoRelatedKeywordsBlock(Block):
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", [])
                if isinstance(first_result, dict)
                else []
            )
            # Ensure items is never None
            if items is None:
            # Handle missing key, null value, or valid list value
            if isinstance(first_result, dict):
                items = first_result.get("items") or []
            else:
                items = []
            for item in items:
                # Extract keyword_data from the item
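The change above matters because a dict default only applies when the key is absent, not when its value is null. A quick illustration of the difference:

# Why `.get("items") or []` is used above instead of `.get("items", [])`:
first_result = {"items": None}                    # API returned the key with a null value
assert first_result.get("items", []) is None      # default ignored; key exists, value is None
assert (first_result.get("items") or []) == []    # falsy None is replaced with an empty list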
autogpt_platform/backend/backend/blocks/google/docs.py (new file, 2896 lines; diff suppressed because it is too large)
autogpt_platform/backend/backend/blocks/helpers/review.py (new file, 184 lines)
@@ -0,0 +1,184 @@
"""
Shared helpers for Human-In-The-Loop (HITL) review functionality.
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
"""

import logging
from typing import Any, Optional

from prisma.enums import ReviewStatus
from pydantic import BaseModel

from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client

logger = logging.getLogger(__name__)


class ReviewDecision(BaseModel):
    """Result of a review decision."""

    should_proceed: bool
    message: str
    review_result: ReviewResult


class HITLReviewHelper:
    """Helper class for Human-In-The-Loop review operations."""

    @staticmethod
    async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
        """Create or retrieve a human review from the database."""
        return await get_database_manager_async_client().get_or_create_human_review(
            **kwargs
        )

    @staticmethod
    async def update_node_execution_status(**kwargs) -> None:
        """Update the execution status of a node."""
        await async_update_node_execution_status(
            db_client=get_database_manager_async_client(), **kwargs
        )

    @staticmethod
    async def update_review_processed_status(
        node_exec_id: str, processed: bool
    ) -> None:
        """Update the processed status of a review."""
        return await get_database_manager_async_client().update_review_processed_status(
            node_exec_id, processed
        )

    @staticmethod
    async def _handle_review_request(
        input_data: Any,
        user_id: str,
        node_exec_id: str,
        graph_exec_id: str,
        graph_id: str,
        graph_version: int,
        execution_context: ExecutionContext,
        block_name: str = "Block",
        editable: bool = False,
    ) -> Optional[ReviewResult]:
        """
        Handle a review request for a block that requires human review.

        Args:
            input_data: The input data to be reviewed
            user_id: ID of the user requesting the review
            node_exec_id: ID of the node execution
            graph_exec_id: ID of the graph execution
            graph_id: ID of the graph
            graph_version: Version of the graph
            execution_context: Current execution context
            block_name: Name of the block requesting review
            editable: Whether the reviewer can edit the data

        Returns:
            ReviewResult if review is complete, None if waiting for human input

        Raises:
            Exception: If review creation or status update fails
        """
        # Skip review if safe mode is disabled - return auto-approved result
        if not execution_context.safe_mode:
            logger.info(
                f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
            )
            return ReviewResult(
                data=input_data,
                status=ReviewStatus.APPROVED,
                message="Auto-approved (safe mode disabled)",
                processed=True,
                node_exec_id=node_exec_id,
            )

        result = await HITLReviewHelper.get_or_create_human_review(
            user_id=user_id,
            node_exec_id=node_exec_id,
            graph_exec_id=graph_exec_id,
            graph_id=graph_id,
            graph_version=graph_version,
            input_data=input_data,
            message=f"Review required for {block_name} execution",
            editable=editable,
        )

        if result is None:
            logger.info(
                f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
            )
            await HITLReviewHelper.update_node_execution_status(
                exec_id=node_exec_id,
                status=ExecutionStatus.REVIEW,
            )
            return None  # Signal that execution should pause

        # Mark review as processed if not already done
        if not result.processed:
            await HITLReviewHelper.update_review_processed_status(
                node_exec_id=node_exec_id, processed=True
            )

        return result

    @staticmethod
    async def handle_review_decision(
        input_data: Any,
        user_id: str,
        node_exec_id: str,
        graph_exec_id: str,
        graph_id: str,
        graph_version: int,
        execution_context: ExecutionContext,
        block_name: str = "Block",
        editable: bool = False,
    ) -> Optional[ReviewDecision]:
        """
        Handle a review request and return the decision in a single call.

        Args:
            input_data: The input data to be reviewed
            user_id: ID of the user requesting the review
            node_exec_id: ID of the node execution
            graph_exec_id: ID of the graph execution
            graph_id: ID of the graph
            graph_version: Version of the graph
            execution_context: Current execution context
            block_name: Name of the block requesting review
            editable: Whether the reviewer can edit the data

        Returns:
            ReviewDecision if review is complete (approved/rejected),
            None if execution should pause (awaiting review)
        """
        review_result = await HITLReviewHelper._handle_review_request(
            input_data=input_data,
            user_id=user_id,
            node_exec_id=node_exec_id,
            graph_exec_id=graph_exec_id,
            graph_id=graph_id,
            graph_version=graph_version,
            execution_context=execution_context,
            block_name=block_name,
            editable=editable,
        )

        if review_result is None:
            # Still awaiting review - return None to pause execution
            return None

        # Review is complete, determine outcome
        should_proceed = review_result.status == ReviewStatus.APPROVED
        message = review_result.message or (
            "Execution approved by reviewer"
            if should_proceed
            else "Execution rejected by reviewer"
        )

        return ReviewDecision(
            should_proceed=should_proceed, message=message, review_result=review_result
        )
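A minimal usage sketch for a block that wants its execution gated by the helper above. The block skeleton and output names are illustrative assumptions; only the handle_review_decision call and the ReviewDecision fields come from the file itself:

# Illustrative only: how a block's async-generator run() might defer to the helper.
from backend.blocks.helpers.review import HITLReviewHelper

async def run_with_review(self, input_data, *, user_id, node_exec_id,
                          graph_exec_id, graph_id, graph_version,
                          execution_context):
    decision = await HITLReviewHelper.handle_review_decision(
        input_data=input_data.data,
        user_id=user_id,
        node_exec_id=node_exec_id,
        graph_exec_id=graph_exec_id,
        graph_id=graph_id,
        graph_version=graph_version,
        execution_context=execution_context,
        block_name=self.name,
        editable=False,
    )
    if decision is None:
        return  # node is parked in REVIEW status; nothing to yield yet
    if decision.should_proceed:
        yield "approved_data", decision.review_result.data
    else:
        yield "rejected_data", decision.review_result.data
    if decision.message:
        yield "review_message", decision.message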
@@ -3,6 +3,7 @@ from typing import Any

from prisma.enums import ReviewStatus

from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.block import (
    Block,
    BlockCategory,
@@ -11,11 +12,9 @@ from backend.data.block import (
    BlockSchemaOutput,
    BlockType,
)
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.execution import ExecutionContext
from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client

logger = logging.getLogger(__name__)

@@ -72,32 +71,26 @@ class HumanInTheLoopBlock(Block):
                ("approved_data", {"name": "John Doe", "age": 30}),
            ],
            test_mock={
                "get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
                    data={"name": "John Doe", "age": 30},
                    status=ReviewStatus.APPROVED,
                    message="",
                    processed=False,
                    node_exec_id="test-node-exec-id",
                ),
                "update_node_execution_status": lambda *_args, **_kwargs: None,
                "update_review_processed_status": lambda *_args, **_kwargs: None,
                "handle_review_decision": lambda **kwargs: type(
                    "ReviewDecision",
                    (),
                    {
                        "should_proceed": True,
                        "message": "Test approval message",
                        "review_result": ReviewResult(
                            data={"name": "John Doe", "age": 30},
                            status=ReviewStatus.APPROVED,
                            message="",
                            processed=False,
                            node_exec_id="test-node-exec-id",
                        ),
                    },
                )(),
            },
        )

    async def get_or_create_human_review(self, **kwargs):
        return await get_database_manager_async_client().get_or_create_human_review(
            **kwargs
        )

    async def update_node_execution_status(self, **kwargs):
        return await async_update_node_execution_status(
            db_client=get_database_manager_async_client(), **kwargs
        )

    async def update_review_processed_status(self, node_exec_id: str, processed: bool):
        return await get_database_manager_async_client().update_review_processed_status(
            node_exec_id, processed
        )
    async def handle_review_decision(self, **kwargs):
        return await HITLReviewHelper.handle_review_decision(**kwargs)

    async def run(
        self,
@@ -109,7 +102,7 @@ class HumanInTheLoopBlock(Block):
        graph_id: str,
        graph_version: int,
        execution_context: ExecutionContext,
        **kwargs,
        **_kwargs,
    ) -> BlockOutput:
        if not execution_context.safe_mode:
            logger.info(
@@ -119,48 +112,28 @@ class HumanInTheLoopBlock(Block):
            yield "review_message", "Auto-approved (safe mode disabled)"
            return

        try:
            result = await self.get_or_create_human_review(
                user_id=user_id,
                node_exec_id=node_exec_id,
                graph_exec_id=graph_exec_id,
                graph_id=graph_id,
                graph_version=graph_version,
                input_data=input_data.data,
                message=input_data.name,
                editable=input_data.editable,
            )
        except Exception as e:
            logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
            raise
        decision = await self.handle_review_decision(
            input_data=input_data.data,
            user_id=user_id,
            node_exec_id=node_exec_id,
            graph_exec_id=graph_exec_id,
            graph_id=graph_id,
            graph_version=graph_version,
            execution_context=execution_context,
            block_name=self.name,
            editable=input_data.editable,
        )

        if result is None:
            logger.info(
                f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
            )
            try:
                await self.update_node_execution_status(
                    exec_id=node_exec_id,
                    status=ExecutionStatus.REVIEW,
                )
                return
            except Exception as e:
                logger.error(
                    f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
                )
                raise
        if decision is None:
            return

        if not result.processed:
            await self.update_review_processed_status(
                node_exec_id=node_exec_id, processed=True
            )
        status = decision.review_result.status
        if status == ReviewStatus.APPROVED:
            yield "approved_data", decision.review_result.data
        elif status == ReviewStatus.REJECTED:
            yield "rejected_data", decision.review_result.data
        else:
            raise RuntimeError(f"Unexpected review status: {status}")

        if result.status == ReviewStatus.APPROVED:
            yield "approved_data", result.data
            if result.message:
                yield "review_message", result.message

        elif result.status == ReviewStatus.REJECTED:
            yield "rejected_data", result.data
            if result.message:
                yield "review_message", result.message
        if decision.message:
            yield "review_message", decision.message
@@ -92,8 +92,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    O1 = "o1"
    O1_MINI = "o1-mini"
    # GPT-5 models
    GPT5 = "gpt-5-2025-08-07"
    GPT5_2 = "gpt-5.2-2025-12-11"
    GPT5_1 = "gpt-5.1-2025-11-13"
    GPT5 = "gpt-5-2025-08-07"
    GPT5_MINI = "gpt-5-mini-2025-08-07"
    GPT5_NANO = "gpt-5-nano-2025-08-07"
    GPT5_CHAT = "gpt-5-chat-latest"
@@ -194,8 +195,9 @@ MODEL_METADATA = {
    LlmModel.O1: ModelMetadata("openai", 200000, 100000),  # o1-2024-12-17
    LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536),  # o1-mini-2024-09-12
    # GPT-5 models
    LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
    LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
    LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
    LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
    LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
    LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
    LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
@@ -303,6 +305,8 @@ MODEL_METADATA = {
    LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}

DEFAULT_LLM_MODEL = LlmModel.GPT5_2

for model in LlmModel:
    if model not in MODEL_METADATA:
        raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
@@ -790,7 +794,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.GPT4O,
            default=DEFAULT_LLM_MODEL,
            description="The language model to use for answering the prompt.",
            advanced=False,
        )
@@ -855,7 +859,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
            input_schema=AIStructuredResponseGeneratorBlock.Input,
            output_schema=AIStructuredResponseGeneratorBlock.Output,
            test_input={
                "model": LlmModel.GPT4O,
                "model": DEFAULT_LLM_MODEL,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expected_format": {
                    "key1": "value1",
@@ -1221,7 +1225,7 @@ class AITextGeneratorBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.GPT4O,
            default=DEFAULT_LLM_MODEL,
            description="The language model to use for answering the prompt.",
            advanced=False,
        )
@@ -1317,7 +1321,7 @@ class AITextSummarizerBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.GPT4O,
            default=DEFAULT_LLM_MODEL,
            description="The language model to use for summarizing the text.",
        )
        focus: str = SchemaField(
@@ -1534,7 +1538,7 @@ class AIConversationBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.GPT4O,
            default=DEFAULT_LLM_MODEL,
            description="The language model to use for the conversation.",
        )
        credentials: AICredentials = AICredentialsField()
@@ -1572,7 +1576,7 @@ class AIConversationBlock(AIBlockBase):
                    },
                    {"role": "user", "content": "Where was it played?"},
                ],
                "model": LlmModel.GPT4O,
                "model": DEFAULT_LLM_MODEL,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
@@ -1635,7 +1639,7 @@ class AIListGeneratorBlock(AIBlockBase):
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.GPT4O,
            default=DEFAULT_LLM_MODEL,
            description="The language model to use for generating the list.",
            advanced=True,
        )
@@ -1692,7 +1696,7 @@ class AIListGeneratorBlock(AIBlockBase):
                    "drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
                    "fictional worlds."
                ),
                "model": LlmModel.GPT4O,
                "model": DEFAULT_LLM_MODEL,
                "credentials": TEST_CREDENTIALS_INPUT,
                "max_retries": 3,
                "force_json_output": False,
(File diff suppressed because it is too large)
@@ -226,7 +226,7 @@ class SmartDecisionMakerBlock(Block):
        )
        model: llm.LlmModel = SchemaField(
            title="LLM Model",
            default=llm.LlmModel.GPT4O,
            default=llm.DEFAULT_LLM_MODEL,
            description="The language model to use for answering the prompt.",
            advanced=False,
        )
@@ -391,8 +391,12 @@ class SmartDecisionMakerBlock(Block):
        """
        block = sink_node.block

        # Use custom name from node metadata if set, otherwise fall back to block.name
        custom_name = sink_node.metadata.get("customized_name")
        tool_name = custom_name if custom_name else block.name

        tool_function: dict[str, Any] = {
            "name": SmartDecisionMakerBlock.cleanup(block.name),
            "name": SmartDecisionMakerBlock.cleanup(tool_name),
            "description": block.description,
        }
        sink_block_input_schema = block.input_schema
@@ -489,8 +493,12 @@ class SmartDecisionMakerBlock(Block):
                f"Sink graph metadata not found: {graph_id} {graph_version}"
            )

        # Use custom name from node metadata if set, otherwise fall back to graph name
        custom_name = sink_node.metadata.get("customized_name")
        tool_name = custom_name if custom_name else sink_graph_meta.name

        tool_function: dict[str, Any] = {
            "name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
            "name": SmartDecisionMakerBlock.cleanup(tool_name),
            "description": sink_graph_meta.description,
        }

@@ -975,10 +983,28 @@ class SmartDecisionMakerBlock(Block):
        graph_version: int,
        execution_context: ExecutionContext,
        execution_processor: "ExecutionProcessor",
        nodes_to_skip: set[str] | None = None,
        **kwargs,
    ) -> BlockOutput:

        tool_functions = await self._create_tool_node_signatures(node_id)
        original_tool_count = len(tool_functions)

        # Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
        if nodes_to_skip:
            tool_functions = [
                tf
                for tf in tool_functions
                if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
            ]

            # Only raise error if we had tools but they were all filtered out
            if original_tool_count > 0 and not tool_functions:
                raise ValueError(
                    "No available tools to execute - all downstream nodes are unavailable "
                    "(possibly due to missing optional credentials)"
                )

        yield "tool_functions", json.dumps(tool_functions)

        conversation_history = input_data.conversation_history or []
@@ -196,6 +196,15 @@ class TestXMLParserBlockSecurity:
            async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
                pass

    async def test_rejects_text_outside_root(self):
        """Ensure parser surfaces readable errors for invalid root text."""
        block = XMLParserBlock()
        invalid_xml = "<root><child>value</child></root> trailing"

        with pytest.raises(ValueError, match="text outside the root element"):
            async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
                pass


class TestStoreMediaFileSecurity:
    """Test file storage security limits."""
@@ -28,7 +28,7 @@ class TestLLMStatsTracking:

        response = await llm.llm_call(
            credentials=llm.TEST_CREDENTIALS,
            llm_model=llm.LlmModel.GPT4O,
            llm_model=llm.DEFAULT_LLM_MODEL,
            prompt=[{"role": "user", "content": "Hello"}],
            max_tokens=100,
        )
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
        input_data = llm.AIStructuredResponseGeneratorBlock.Input(
            prompt="Test prompt",
            expected_format={"key1": "desc1", "key2": "desc2"},
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore # type: ignore
        )

@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
        # Run the block
        input_data = llm.AITextGeneratorBlock.Input(
            prompt="Generate text",
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
        input_data = llm.AIStructuredResponseGeneratorBlock.Input(
            prompt="Test prompt",
            expected_format={"key1": "desc1", "key2": "desc2"},
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
            retry=2,
        )
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:

        input_data = llm.AITextSummarizerBlock.Input(
            text=long_text,
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
            max_tokens=100,  # Small chunks
            chunk_overlap=10,
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
        # Test with very short text (should only need 1 chunk + 1 final summary)
        input_data = llm.AITextSummarizerBlock.Input(
            text="This is a short text.",
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
            max_tokens=1000,  # Large enough to avoid chunking
        )
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
                {"role": "assistant", "content": "Hi there!"},
                {"role": "user", "content": "How are you?"},
            ],
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
        # Run the block
        input_data = llm.AIListGeneratorBlock.Input(
            focus="test items",
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
            max_retries=3,
        )
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
        input_data = llm.AIStructuredResponseGeneratorBlock.Input(
            prompt="Test",
            expected_format={"result": "desc"},
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
        # Create input data
        input_data = llm.AITextSummarizerBlock.Input(
            text="Some text to summarize",
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
            style=llm.SummaryStyle.BULLET_POINTS,
        )
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
        # Create input data
        input_data = llm.AITextSummarizerBlock.Input(
            text="Some text to summarize",
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
            style=llm.SummaryStyle.BULLET_POINTS,
            max_tokens=1000,
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
        # Create input data
        input_data = llm.AITextSummarizerBlock.Input(
            text="Some text to summarize",
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
        )

@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
        # Create input data
        input_data = llm.AITextSummarizerBlock.Input(
            text="Some text to summarize",
            model=llm.LlmModel.GPT4O,
            model=llm.DEFAULT_LLM_MODEL,
            credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
            max_tokens=1000,
        )
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
        # Create input data
||||
input_data = llm.AITextSummarizerBlock.Input(
|
||||
text="Some text to summarize",
|
||||
model=llm.LlmModel.GPT4O,
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Standalone tests for pin name sanitization that can run without full backend dependencies.
|
||||
|
||||
These tests verify the core sanitization logic independently of the full system.
|
||||
Run with: python -m pytest test_pin_sanitization_standalone.py -v
|
||||
Or simply: python test_pin_sanitization_standalone.py
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
# Simulate the exact cleanup function from SmartDecisionMakerBlock
|
||||
def cleanup(s: str) -> str:
|
||||
"""Clean up names for use as tool function names."""
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
|
||||
|
||||
|
||||
# Simulate the key parts of parse_execution_output
|
||||
def simulate_tool_routing(
|
||||
emit_key: str,
|
||||
sink_node_id: str,
|
||||
sink_pin_name: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Simulate the routing comparison from parse_execution_output.
|
||||
|
||||
Returns True if routing would succeed, False otherwise.
|
||||
"""
|
||||
if not emit_key.startswith("tools_^_") or "_~_" not in emit_key:
|
||||
return False
|
||||
|
||||
# Extract routing info from emit key: tools_^_{node_id}_~_{field}
|
||||
selector = emit_key[8:] # Remove "tools_^_"
|
||||
target_node_id, target_input_pin = selector.split("_~_", 1)
|
||||
|
||||
# Current (buggy) comparison - direct string comparison
|
||||
return target_node_id == sink_node_id and target_input_pin == sink_pin_name
|
||||
|
||||
|
||||
def simulate_fixed_tool_routing(
|
||||
emit_key: str,
|
||||
sink_node_id: str,
|
||||
sink_pin_name: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Simulate the FIXED routing comparison.
|
||||
|
||||
The fix: sanitize sink_pin_name before comparison.
|
||||
"""
|
||||
if not emit_key.startswith("tools_^_") or "_~_" not in emit_key:
|
||||
return False
|
||||
|
||||
selector = emit_key[8:]
|
||||
target_node_id, target_input_pin = selector.split("_~_", 1)
|
||||
|
||||
# Fixed comparison - sanitize sink_pin_name
|
||||
return target_node_id == sink_node_id and target_input_pin == cleanup(sink_pin_name)
|
||||
|
||||
|
||||
class TestCleanupFunction:
|
||||
"""Tests for the cleanup function."""
|
||||
|
||||
def test_spaces_to_underscores(self):
|
||||
assert cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"
|
||||
|
||||
def test_mixed_case_to_lowercase(self):
|
||||
assert cleanup("MaxKeywordDifficulty") == "maxkeyworddifficulty"
|
||||
|
||||
def test_special_chars_to_underscores(self):
|
||||
assert cleanup("field@name!") == "field_name_"
|
||||
assert cleanup("CPC ($)") == "cpc____"
|
||||
|
||||
def test_preserves_valid_chars(self):
|
||||
assert cleanup("valid_name-123") == "valid_name-123"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert cleanup("") == ""
|
||||
|
||||
def test_consecutive_spaces(self):
|
||||
assert cleanup("a b") == "a___b"
|
||||
|
||||
def test_unicode(self):
|
||||
assert cleanup("café") == "caf_"
|
||||
|
||||
|
||||
class TestCurrentRoutingBehavior:
|
||||
"""Tests demonstrating the current (buggy) routing behavior."""
|
||||
|
||||
def test_exact_match_works(self):
|
||||
"""When names match exactly, routing works."""
|
||||
emit_key = "tools_^_node-123_~_query"
|
||||
assert simulate_tool_routing(emit_key, "node-123", "query") is True
|
||||
|
||||
def test_spaces_cause_failure(self):
|
||||
"""When sink_pin has spaces, routing fails."""
|
||||
sanitized = cleanup("Max Keyword Difficulty")
|
||||
emit_key = f"tools_^_node-123_~_{sanitized}"
|
||||
assert simulate_tool_routing(emit_key, "node-123", "Max Keyword Difficulty") is False
|
||||
|
||||
def test_special_chars_cause_failure(self):
|
||||
"""When sink_pin has special chars, routing fails."""
|
||||
sanitized = cleanup("CPC ($)")
|
||||
emit_key = f"tools_^_node-123_~_{sanitized}"
|
||||
assert simulate_tool_routing(emit_key, "node-123", "CPC ($)") is False
|
||||
|
||||
|
||||
class TestFixedRoutingBehavior:
|
||||
"""Tests demonstrating the fixed routing behavior."""
|
||||
|
||||
def test_exact_match_still_works(self):
|
||||
"""When names match exactly, routing still works."""
|
||||
emit_key = "tools_^_node-123_~_query"
|
||||
assert simulate_fixed_tool_routing(emit_key, "node-123", "query") is True
|
||||
|
||||
def test_spaces_work_with_fix(self):
|
||||
"""With the fix, spaces in sink_pin work."""
|
||||
sanitized = cleanup("Max Keyword Difficulty")
|
||||
emit_key = f"tools_^_node-123_~_{sanitized}"
|
||||
assert simulate_fixed_tool_routing(emit_key, "node-123", "Max Keyword Difficulty") is True
|
||||
|
||||
def test_special_chars_work_with_fix(self):
|
||||
"""With the fix, special chars in sink_pin work."""
|
||||
sanitized = cleanup("CPC ($)")
|
||||
emit_key = f"tools_^_node-123_~_{sanitized}"
|
||||
assert simulate_fixed_tool_routing(emit_key, "node-123", "CPC ($)") is True
|
||||
|
||||
|
||||
class TestBugReproduction:
|
||||
"""Exact reproduction of the reported bug."""
|
||||
|
||||
def test_max_keyword_difficulty_bug(self):
|
||||
"""
|
||||
Reproduce the exact bug from the issue:
|
||||
|
||||
"For this agent specifically the input pin has space and unsanitized,
|
||||
the frontend somehow connect without sanitizing creating a link like:
|
||||
tools_^_767682f5-..._~_Max Keyword Difficulty
|
||||
but what's produced by backend is
|
||||
tools_^_767682f5-..._~_max_keyword_difficulty
|
||||
so the tool calls go into the void"
|
||||
"""
|
||||
node_id = "767682f5-fake-uuid"
|
||||
original_field = "Max Keyword Difficulty"
|
||||
sanitized_field = cleanup(original_field)
|
||||
|
||||
# What backend produces (emit key)
|
||||
emit_key = f"tools_^_{node_id}_~_{sanitized_field}"
|
||||
assert emit_key == f"tools_^_{node_id}_~_max_keyword_difficulty"
|
||||
|
||||
# What frontend link has (sink_pin_name)
|
||||
frontend_sink = original_field
|
||||
|
||||
# Current behavior: FAILS
|
||||
assert simulate_tool_routing(emit_key, node_id, frontend_sink) is False
|
||||
|
||||
# With fix: WORKS
|
||||
assert simulate_fixed_tool_routing(emit_key, node_id, frontend_sink) is True
|
||||
|
||||
|
||||
class TestCommonFieldNamePatterns:
|
||||
"""Test common field name patterns that could cause issues."""
|
||||
|
||||
FIELD_NAMES = [
|
||||
"Max Keyword Difficulty",
|
||||
"Search Volume (Monthly)",
|
||||
"CPC ($)",
|
||||
"User's Input",
|
||||
"Target URL",
|
||||
"API Response",
|
||||
"Query #1",
|
||||
"First Name",
|
||||
"Last Name",
|
||||
"Email Address",
|
||||
"Phone Number",
|
||||
"Total Cost ($)",
|
||||
"Discount (%)",
|
||||
"Created At",
|
||||
"Updated At",
|
||||
"Is Active",
|
||||
]
|
||||
|
||||
def test_current_behavior_fails_for_special_names(self):
|
||||
"""Current behavior fails for names with spaces/special chars."""
|
||||
failed = []
|
||||
for name in self.FIELD_NAMES:
|
||||
sanitized = cleanup(name)
|
||||
emit_key = f"tools_^_node_~_{sanitized}"
|
||||
if not simulate_tool_routing(emit_key, "node", name):
|
||||
failed.append(name)
|
||||
|
||||
# All names with spaces should fail
|
||||
names_with_spaces = [n for n in self.FIELD_NAMES if " " in n or any(c in n for c in "()$%#'")]
|
||||
assert set(failed) == set(names_with_spaces)
|
||||
|
||||
def test_fixed_behavior_works_for_all_names(self):
|
||||
"""Fixed behavior works for all names."""
|
||||
for name in self.FIELD_NAMES:
|
||||
sanitized = cleanup(name)
|
||||
emit_key = f"tools_^_node_~_{sanitized}"
|
||||
assert simulate_fixed_tool_routing(emit_key, "node", name) is True, f"Failed for: {name}"
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""Run all tests manually without pytest."""
|
||||
import traceback
|
||||
|
||||
test_classes = [
|
||||
TestCleanupFunction,
|
||||
TestCurrentRoutingBehavior,
|
||||
TestFixedRoutingBehavior,
|
||||
TestBugReproduction,
|
||||
TestCommonFieldNamePatterns,
|
||||
]
|
||||
|
||||
total = 0
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for test_class in test_classes:
|
||||
print(f"\n{test_class.__name__}:")
|
||||
instance = test_class()
|
||||
for name in dir(instance):
|
||||
if name.startswith("test_"):
|
||||
total += 1
|
||||
try:
|
||||
getattr(instance, name)()
|
||||
print(f" ✓ {name}")
|
||||
passed += 1
|
||||
except AssertionError as e:
|
||||
print(f" ✗ {name}: {e}")
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
print(f" ✗ {name}: {e}")
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Total: {total}, Passed: {passed}, Failed: {failed}")
|
||||
return failed == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
success = run_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -233,7 +233,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
# Create test input
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Should I continue with this task?",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -335,7 +335,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2, # Set retry to 2 for testing
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -402,7 +402,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -462,7 +462,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -526,7 +526,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Search for keywords",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -648,7 +648,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
agent_mode_max_iterations=0,
|
||||
@@ -722,7 +722,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Simple prompt",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -778,7 +778,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Another test",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
@@ -931,7 +931,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
# Test agent mode with max_iterations = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Complete this task using tools",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
|
||||
)
|
||||
@@ -1020,7 +1020,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
# Test default behavior (traditional mode)
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.LlmModel.GPT4O,
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
agent_mode_max_iterations=0, # Traditional mode
|
||||
)
|
||||
@@ -1057,3 +1057,153 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
) # Should yield individual tool parameters
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert "conversations" in outputs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_blocks():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the customized name (cleaned up)
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version
|
||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_falls_back_to_block_name():
|
||||
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "input"
|
||||
|
||||
# Call the function directly
|
||||
result = await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the block's default name
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "storevalueblock" # Default block name cleaned
|
||||
assert result["function"]["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_uses_customized_name_for_agents():
|
||||
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node with customized_name in metadata
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-agent-node-id"
|
||||
mock_node.metadata = {"customized_name": "My Custom Agent"}
|
||||
mock_node.input_default = {
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_version": 1,
|
||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
||||
}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "test_input"
|
||||
|
||||
# Mock the database client
|
||||
mock_graph_meta = MagicMock()
|
||||
mock_graph_meta.name = "Original Agent Name"
|
||||
mock_graph_meta.description = "Agent description"
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the customized name (cleaned up)
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "my_custom_agent" # Cleaned version
|
||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
|
||||
"""Test that agent node falls back to graph name when no customized_name."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
# Create a mock node without customized_name
|
||||
mock_node = MagicMock(spec=Node)
|
||||
mock_node.id = "test-agent-node-id"
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.input_default = {
|
||||
"graph_id": "test-graph-id",
|
||||
"graph_version": 1,
|
||||
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
|
||||
}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
mock_link.sink_name = "test_input"
|
||||
|
||||
# Mock the database client
|
||||
mock_graph_meta = MagicMock()
|
||||
mock_graph_meta.name = "Original Agent Name"
|
||||
mock_graph_meta.description = "Agent description"
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
|
||||
|
||||
with patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
result = await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
mock_node, [mock_link]
|
||||
)
|
||||
|
||||
# Verify the tool name uses the graph's default name
|
||||
assert result["type"] == "function"
|
||||
assert result["function"]["name"] == "original_agent_name" # Graph name cleaned
|
||||
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
@@ -0,0 +1,916 @@
|
||||
"""
|
||||
Tests for SmartDecisionMaker agent mode specific failure modes.
|
||||
|
||||
Covers failure modes:
|
||||
2. Silent Tool Failures in Agent Mode
|
||||
3. Unbounded Agent Mode Iterations
|
||||
10. Unbounded Agent Iterations
|
||||
12. Stale Credentials in Agent Mode
|
||||
13. Tool Signature Cache Invalidation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.smart_decision_maker import (
|
||||
SmartDecisionMakerBlock,
|
||||
ExecutionParams,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
|
||||
class TestSilentToolFailuresInAgentMode:
|
||||
"""
|
||||
Tests for Failure Mode #2: Silent Tool Failures in Agent Mode
|
||||
|
||||
When tool execution fails in agent mode, the error is converted to a
|
||||
tool response and execution continues silently.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_failure_converted_to_response(self):
|
||||
"""
|
||||
Test that tool execution failures are silently converted to responses.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# First response: tool call
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "failing_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"param": "value"})
|
||||
|
||||
mock_response_1 = MagicMock()
|
||||
mock_response_1.response = None
|
||||
mock_response_1.tool_calls = [mock_tool_call]
|
||||
mock_response_1.prompt_tokens = 50
|
||||
mock_response_1.completion_tokens = 25
|
||||
mock_response_1.reasoning = None
|
||||
mock_response_1.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
|
||||
# Second response: finish after seeing error
|
||||
mock_response_2 = MagicMock()
|
||||
mock_response_2.response = "I encountered an error"
|
||||
mock_response_2.tool_calls = []
|
||||
mock_response_2.prompt_tokens = 30
|
||||
mock_response_2.completion_tokens = 15
|
||||
mock_response_2.reasoning = None
|
||||
mock_response_2.raw_response = {"role": "assistant", "content": "I encountered an error"}
|
||||
|
||||
llm_call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal llm_call_count
|
||||
llm_call_count += 1
|
||||
if llm_call_count == 1:
|
||||
return mock_response_1
|
||||
return mock_response_2
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "failing_tool",
|
||||
"_sink_node_id": "sink-node",
|
||||
"_field_mapping": {"param": "param"},
|
||||
"parameters": {
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Mock database client that will fail
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node.side_effect = Exception("Database connection failed!")
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Do something",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=5,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# The execution completed (didn't crash)
|
||||
assert "finished" in outputs or "conversations" in outputs
|
||||
|
||||
# BUG: The tool failure was silent - user doesn't know what happened
|
||||
# The error was just logged and converted to a tool response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_failure_causes_infinite_retry_loop(self):
|
||||
"""
|
||||
Test scenario where LLM keeps calling the same failing tool.
|
||||
|
||||
If tool fails but LLM doesn't realize it, it may keep trying.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
call_count = 0
|
||||
max_calls = 10 # Limit for test
|
||||
|
||||
def create_tool_call_response():
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = f"call_{call_count}"
|
||||
mock_tool_call.function.name = "persistent_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"retry": call_count})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{call_count}"}]
|
||||
}
|
||||
return mock_response
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
if call_count >= max_calls:
|
||||
# Eventually finish to prevent actual infinite loop in test
|
||||
final = MagicMock()
|
||||
final.response = "Giving up"
|
||||
final.tool_calls = []
|
||||
final.prompt_tokens = 10
|
||||
final.completion_tokens = 5
|
||||
final.reasoning = None
|
||||
final.raw_response = {"role": "assistant", "content": "Giving up"}
|
||||
return final
|
||||
|
||||
return create_tool_call_response()
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "persistent_tool",
|
||||
"_sink_node_id": "sink-node",
|
||||
"_field_mapping": {"retry": "retry"},
|
||||
"parameters": {
|
||||
"properties": {"retry": {"type": "integer"}},
|
||||
"required": ["retry"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node.side_effect = Exception("Always fails!")
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Keep trying",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=-1, # Infinite mode!
|
||||
)
|
||||
|
||||
# Use timeout to prevent actual infinite loop
|
||||
try:
|
||||
async with asyncio.timeout(5):
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
except asyncio.TimeoutError:
|
||||
pass # Expected if we hit infinite loop
|
||||
|
||||
# Document that many calls were made before we gave up
|
||||
assert call_count >= max_calls - 1, \
|
||||
f"Expected many retries, got {call_count}"
|
||||
|
||||
|
||||
class TestUnboundedAgentIterations:
|
||||
"""
|
||||
Tests for Failure Mode #3 and #10: Unbounded Agent Mode Iterations
|
||||
|
||||
With max_iterations = -1, the agent can run forever, consuming
|
||||
unlimited tokens and compute resources.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_infinite_mode_requires_llm_to_stop(self):
|
||||
"""
|
||||
Test that infinite mode (-1) only stops when LLM stops making tool calls.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
iterations = 0
|
||||
max_test_iterations = 20
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal iterations
|
||||
iterations += 1
|
||||
|
||||
if iterations >= max_test_iterations:
|
||||
# Stop to prevent actual infinite loop
|
||||
resp = MagicMock()
|
||||
resp.response = "Finally done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
# Keep making tool calls
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iterations}"
|
||||
tool_call.function.name = "counter_tool"
|
||||
tool_call.function.arguments = json.dumps({"count": iterations})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iterations}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "counter_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"count": "count"},
|
||||
"parameters": {
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
"required": ["count"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {"count": 1})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Count forever",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=-1, # INFINITE MODE
|
||||
)
|
||||
|
||||
async with asyncio.timeout(10):
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# We ran many iterations before stopping
|
||||
assert iterations == max_test_iterations
|
||||
# BUG: No built-in safeguard against runaway iterations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_limit_enforced(self):
|
||||
"""
|
||||
Test that max_iterations limit is properly enforced.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
iterations = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal iterations
|
||||
iterations += 1
|
||||
|
||||
# Always make tool calls (never finish voluntarily)
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iterations}"
|
||||
tool_call.function.name = "endless_tool"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iterations}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "endless_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
MAX_ITERATIONS = 3
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Run forever",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=MAX_ITERATIONS,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Should have stopped at max iterations
|
||||
assert iterations == MAX_ITERATIONS
|
||||
assert "finished" in outputs
|
||||
assert "limit reached" in outputs["finished"].lower()
|
||||
|
||||
|
||||
class TestStaleCredentialsInAgentMode:
|
||||
"""
|
||||
Tests for Failure Mode #12: Stale Credentials in Agent Mode
|
||||
|
||||
Credentials are validated once at start but can expire during
|
||||
long-running agent mode executions.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credentials_not_revalidated_between_iterations(self):
|
||||
"""
|
||||
Test that credentials are used without revalidation in agent mode.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
credential_check_count = 0
|
||||
iteration = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal credential_check_count, iteration
|
||||
iteration += 1
|
||||
|
||||
# Simulate credential check (in real code this happens in llm_call)
|
||||
credential_check_count += 1
|
||||
|
||||
if iteration >= 3:
|
||||
resp = MagicMock()
|
||||
resp.response = "Done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iteration}"
|
||||
tool_call.function.name = "test_tool"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test credentials",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=5,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Credentials were checked on each LLM call but not refreshed
|
||||
# If they expired mid-execution, we'd get auth errors
|
||||
assert credential_check_count == iteration
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credential_expiration_mid_execution(self):
|
||||
"""
|
||||
Test what happens when credentials expire during agent mode.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
iteration = 0
|
||||
|
||||
async def mock_llm_call_with_expiration(**kwargs):
|
||||
nonlocal iteration
|
||||
iteration += 1
|
||||
|
||||
if iteration >= 3:
|
||||
# Simulate credential expiration
|
||||
raise Exception("401 Unauthorized: API key expired")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iteration}"
|
||||
tool_call.function.name = "test_tool"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call_with_expiration), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test credentials",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=10,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Should have an error output
|
||||
assert "error" in outputs
|
||||
assert "expired" in outputs["error"].lower() or "unauthorized" in outputs["error"].lower()
|
||||
|
||||
|
||||
class TestToolSignatureCacheInvalidation:
|
||||
"""
|
||||
Tests for Failure Mode #13: Tool Signature Cache Invalidation
|
||||
|
||||
Tool signatures are created once at the start of run() but the
|
||||
graph could change during agent mode execution.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signatures_created_once_at_start(self):
|
||||
"""
|
||||
Test that tool signatures are only created once, not refreshed.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
signature_creation_count = 0
|
||||
iteration = 0
|
||||
|
||||
original_create_signatures = block._create_tool_node_signatures
|
||||
|
||||
async def counting_create_signatures(node_id):
|
||||
nonlocal signature_creation_count
|
||||
signature_creation_count += 1
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_v1",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal iteration
|
||||
iteration += 1
|
||||
|
||||
if iteration >= 3:
|
||||
resp = MagicMock()
|
||||
resp.response = "Done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iteration}"
|
||||
tool_call.function.name = "tool_v1"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", side_effect=counting_create_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test signatures",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=5,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Signatures were only created once, even though we had multiple iterations
|
||||
assert signature_creation_count == 1
|
||||
assert iteration >= 3 # We had multiple iterations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stale_signatures_cause_tool_mismatch(self):
|
||||
"""
|
||||
Test scenario where tool definitions change but agent uses stale signatures.
|
||||
"""
|
||||
# This documents the potential issue:
|
||||
# 1. Agent starts with tool_v1
|
||||
# 2. User modifies graph, tool becomes tool_v2
|
||||
# 3. Agent still thinks tool_v1 exists
|
||||
# 4. LLM calls tool_v1, but it no longer exists
|
||||
|
||||
# Since signatures are created once at start and never refreshed,
|
||||
# any changes to the graph during execution won't be reflected.
|
||||
|
||||
# This is more of a documentation test - the actual fix would
|
||||
# require either:
|
||||
# a) Refreshing signatures periodically
|
||||
# b) Locking the graph during execution
|
||||
# c) Checking tool existence before each call
|
||||
pass
|
||||
|
||||
|
||||
class TestAgentModeConversationManagement:
|
||||
"""Tests for conversation management in agent mode."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_grows_with_iterations(self):
|
||||
"""
|
||||
Test that conversation history grows correctly with each iteration.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
iteration = 0
|
||||
conversation_lengths = []
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal iteration
|
||||
iteration += 1
|
||||
|
||||
# Record conversation length at each call
|
||||
prompt = kwargs.get("prompt", [])
|
||||
conversation_lengths.append(len(prompt))
|
||||
|
||||
if iteration >= 3:
|
||||
resp = MagicMock()
|
||||
resp.response = "Done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = f"call_{iteration}"
|
||||
tool_call.function.name = "test_tool"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": f"call_{iteration}"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test conversation",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=5,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Conversation should grow with each iteration
|
||||
# Each iteration adds: assistant message + tool response
|
||||
assert len(conversation_lengths) == 3
|
||||
for i in range(1, len(conversation_lengths)):
|
||||
assert conversation_lengths[i] > conversation_lengths[i-1], \
|
||||
f"Conversation should grow: {conversation_lengths}"
|
||||
@@ -0,0 +1,525 @@
"""
Tests for SmartDecisionMaker concurrency issues and race conditions.

Covers failure modes:
1. Conversation History Race Condition
4. Concurrent Execution State Sharing
7. Race in Pending Tool Calls
11. Race in Pending Tool Call Retrieval
14. Concurrent State Sharing
"""

import asyncio
import json
import threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from backend.blocks.smart_decision_maker import (
    SmartDecisionMakerBlock,
    get_pending_tool_calls,
    _create_tool_response,
    _get_tool_requests,
    _get_tool_responses,
)


class TestConversationHistoryRaceCondition:
    """
    Tests for Failure Mode #1: Conversation History Race Condition

    When multiple executions share conversation history, concurrent
    modifications can cause data loss or corruption.
    """

    def test_get_pending_tool_calls_with_concurrent_modification(self):
        """
        Test that concurrent modifications to conversation history
        can cause inconsistent pending tool call counts.
        """
        # Shared conversation history
        conversation_history = [
            {
                "role": "assistant",
                "content": [
                    {"type": "tool_use", "id": "toolu_1"},
                    {"type": "tool_use", "id": "toolu_2"},
                    {"type": "tool_use", "id": "toolu_3"},
                ]
            }
        ]

        results = []
        errors = []

        def reader_thread():
            """Repeatedly read pending calls."""
            for _ in range(100):
                try:
                    pending = get_pending_tool_calls(conversation_history)
                    results.append(len(pending))
                except Exception as e:
                    errors.append(str(e))

        def writer_thread():
            """Modify conversation while readers are active."""
            for i in range(50):
                # Add a tool response
                conversation_history.append({
                    "role": "user",
                    "content": [{"type": "tool_result", "tool_use_id": f"toolu_{(i % 3) + 1}"}]
                })
                # Remove it
                if len(conversation_history) > 1:
                    conversation_history.pop()

        # Run concurrent readers and writers
        threads = []
        for _ in range(3):
            threads.append(threading.Thread(target=reader_thread))
            threads.append(threading.Thread(target=writer_thread))

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # The issue: results may be inconsistent due to race conditions
        # In a correct implementation, we'd expect consistent results
        # Document that this CAN produce inconsistent results
        assert len(results) > 0, "Should have some results"
        # Note: This test documents the race condition exists
        # When fixed, all results should be consistent

    def test_prompt_list_mutation_race(self):
        """
        Test that mutating prompt list during iteration can cause issues.
        """
        prompt = []
        errors = []

        def appender():
            for i in range(100):
                prompt.append({"role": "user", "content": f"msg_{i}"})

        def extender():
            for i in range(100):
                prompt.extend([{"role": "assistant", "content": f"resp_{i}"}])

        def reader():
            for _ in range(100):
                try:
                    # Iterate while others modify
                    _ = [p for p in prompt if p.get("role") == "user"]
                except RuntimeError as e:
                    # "dictionary changed size during iteration" or similar
                    errors.append(str(e))

        threads = [
            threading.Thread(target=appender),
            threading.Thread(target=extender),
            threading.Thread(target=reader),
        ]

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Document that race conditions can occur
        # In production, this could cause silent data corruption

    @pytest.mark.asyncio
    async def test_concurrent_block_runs_share_state(self):
        """
        Test that concurrent runs on same block instance can share state incorrectly.

        This is Failure Mode #14: Concurrent State Sharing
        """
        import backend.blocks.llm as llm_module
        from backend.data.execution import ExecutionContext

        block = SmartDecisionMakerBlock()

        # Track all outputs from all runs
        all_outputs = []
        lock = threading.Lock()

        async def run_block(run_id: int):
            """Run the block with a unique run_id."""
            mock_response = MagicMock()
            mock_response.response = f"Response for run {run_id}"
            mock_response.tool_calls = []  # No tool calls, just finish
            mock_response.prompt_tokens = 50
            mock_response.completion_tokens = 25
            mock_response.reasoning = None
            mock_response.raw_response = {"role": "assistant", "content": f"Run {run_id}"}

            mock_tool_signatures = []

            with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
                mock_llm.return_value = mock_response

                with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
                    input_data = SmartDecisionMakerBlock.Input(
                        prompt=f"Prompt for run {run_id}",
                        model=llm_module.DEFAULT_LLM_MODEL,
                        credentials=llm_module.TEST_CREDENTIALS_INPUT,
                        agent_mode_max_iterations=0,
                    )

                    mock_execution_context = ExecutionContext(safe_mode=False)
                    mock_execution_processor = MagicMock()

                    outputs = {}
                    async for output_name, output_data in block.run(
                        input_data,
                        credentials=llm_module.TEST_CREDENTIALS,
                        graph_id=f"graph-{run_id}",
                        node_id=f"node-{run_id}",
                        graph_exec_id=f"exec-{run_id}",
                        node_exec_id=f"node-exec-{run_id}",
                        user_id=f"user-{run_id}",
                        graph_version=1,
                        execution_context=mock_execution_context,
                        execution_processor=mock_execution_processor,
                    ):
                        outputs[output_name] = output_data

                    with lock:
                        all_outputs.append((run_id, outputs))

        # Run multiple concurrent executions
        tasks = [run_block(i) for i in range(5)]
        await asyncio.gather(*tasks)

        # Verify each run got its own response (no cross-contamination)
        for run_id, outputs in all_outputs:
            if "finished" in outputs:
                assert f"run {run_id}" in outputs["finished"].lower() or outputs["finished"] == f"Response for run {run_id}", \
                    f"Run {run_id} may have received contaminated response: {outputs}"


class TestPendingToolCallRace:
    """
    Tests for Failure Mode #7 and #11: Race in Pending Tool Calls

    The get_pending_tool_calls function can race with modifications
    to the conversation history, causing StopIteration or incorrect counts.
    """

    def test_pending_tool_calls_counter_accuracy(self):
        """Test that pending tool call counting is accurate."""
        conversation = [
            # Assistant makes 3 tool calls
            {
                "role": "assistant",
                "content": [
                    {"type": "tool_use", "id": "call_1"},
                    {"type": "tool_use", "id": "call_2"},
                    {"type": "tool_use", "id": "call_3"},
                ]
            },
            # User provides 1 response
            {
                "role": "user",
                "content": [
                    {"type": "tool_result", "tool_use_id": "call_1"}
                ]
            }
        ]

        pending = get_pending_tool_calls(conversation)

        # Should have 2 pending (call_2, call_3)
        assert len(pending) == 2
        assert "call_2" in pending
        assert "call_3" in pending
        assert pending["call_2"] == 1
        assert pending["call_3"] == 1

    def test_pending_tool_calls_duplicate_responses(self):
        """Test handling of duplicate tool responses."""
        conversation = [
            {
                "role": "assistant",
                "content": [{"type": "tool_use", "id": "call_1"}]
            },
            # Duplicate responses for same call
            {
                "role": "user",
                "content": [{"type": "tool_result", "tool_use_id": "call_1"}]
            },
            {
                "role": "user",
                "content": [{"type": "tool_result", "tool_use_id": "call_1"}]
            }
        ]

        pending = get_pending_tool_calls(conversation)

        # call_1 has count -1 (1 request - 2 responses)
        # Should not be in pending (count <= 0)
        assert "call_1" not in pending or pending.get("call_1", 0) <= 0

    def test_empty_conversation_no_pending(self):
        """Test that empty conversation has no pending calls."""
        assert get_pending_tool_calls([]) == {}
        assert get_pending_tool_calls(None) == {}

    def test_next_iter_on_empty_dict_raises_stop_iteration(self):
        """
        Document the StopIteration vulnerability.

        If pending_tool_calls becomes empty between the check and
        next(iter(...)), StopIteration is raised.
        """
        pending = {}

        # This is the pattern used in smart_decision_maker.py:1019
        # if pending_tool_calls and ...:
        #     first_call_id = next(iter(pending_tool_calls.keys()))

        with pytest.raises(StopIteration):
            next(iter(pending.keys()))

        # Safe pattern should be:
        # first_call_id = next(iter(pending_tool_calls.keys()), None)
        safe_result = next(iter(pending.keys()), None)
        assert safe_result is None


class TestToolRequestResponseParsing:
    """Tests for tool request/response parsing edge cases."""

    def test_get_tool_requests_openai_format(self):
        """Test parsing OpenAI format tool requests."""
        entry = {
            "role": "assistant",
            "tool_calls": [
                {"id": "call_abc123"},
                {"id": "call_def456"},
            ]
        }

        requests = _get_tool_requests(entry)
        assert requests == ["call_abc123", "call_def456"]

    def test_get_tool_requests_anthropic_format(self):
        """Test parsing Anthropic format tool requests."""
        entry = {
            "role": "assistant",
            "content": [
                {"type": "tool_use", "id": "toolu_abc123"},
                {"type": "text", "text": "Let me call this tool"},
                {"type": "tool_use", "id": "toolu_def456"},
            ]
        }

        requests = _get_tool_requests(entry)
        assert requests == ["toolu_abc123", "toolu_def456"]

    def test_get_tool_requests_non_assistant_role(self):
        """Non-assistant roles should return empty list."""
        entry = {"role": "user", "tool_calls": [{"id": "call_123"}]}
        assert _get_tool_requests(entry) == []

    def test_get_tool_responses_openai_format(self):
        """Test parsing OpenAI format tool responses."""
        entry = {
            "role": "tool",
            "tool_call_id": "call_abc123",
            "content": "Result"
        }

        responses = _get_tool_responses(entry)
        assert responses == ["call_abc123"]

    def test_get_tool_responses_anthropic_format(self):
        """Test parsing Anthropic format tool responses."""
        entry = {
            "role": "user",
            "content": [
                {"type": "tool_result", "tool_use_id": "toolu_abc123"},
                {"type": "tool_result", "tool_use_id": "toolu_def456"},
            ]
        }

        responses = _get_tool_responses(entry)
        assert responses == ["toolu_abc123", "toolu_def456"]

    def test_get_tool_responses_mixed_content(self):
        """Test parsing responses with mixed content types."""
        entry = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Here are the results"},
                {"type": "tool_result", "tool_use_id": "toolu_123"},
                {"type": "image", "url": "http://example.com/img.png"},
            ]
        }

        responses = _get_tool_responses(entry)
        assert responses == ["toolu_123"]


class TestConcurrentToolSignatureCreation:
    """Tests for concurrent tool signature creation."""

    @pytest.mark.asyncio
    async def test_concurrent_signature_creation_same_node(self):
        """
        Test that concurrent signature creation for same node
        doesn't cause issues.
        """
        block = SmartDecisionMakerBlock()

        mock_node = Mock()
        mock_node.id = "test-node"
        mock_node.block = Mock()
        mock_node.block.name = "TestBlock"
        mock_node.block.description = "Test"
        mock_node.block.input_schema = Mock()
        mock_node.block.input_schema.jsonschema = Mock(
            return_value={"properties": {}, "required": []}
        )
        mock_node.block.input_schema.get_field_schema = Mock(
            return_value={"type": "string", "description": "test"}
        )

        mock_links = [
            Mock(sink_name="field1", sink_id="test-node", source_id="source"),
            Mock(sink_name="field2", sink_id="test-node", source_id="source"),
        ]

        # Run multiple concurrent signature creations
        tasks = [
            block._create_block_function_signature(mock_node, mock_links)
            for _ in range(10)
        ]

        results = await asyncio.gather(*tasks)

        # All results should be identical
        first = results[0]
        for i, result in enumerate(results[1:], 1):
            assert result["function"]["name"] == first["function"]["name"], \
                f"Result {i} has different name"
            assert set(result["function"]["parameters"]["properties"].keys()) == \
                set(first["function"]["parameters"]["properties"].keys()), \
                f"Result {i} has different properties"


class TestThreadSafetyOfCleanup:
    """Tests for thread safety of cleanup function."""

    def test_cleanup_is_thread_safe(self):
        """
        Test that cleanup function is thread-safe.

        Since it's a pure function with no shared state, it should be safe.
        """
        results = {}
        lock = threading.Lock()

        test_inputs = [
            "Max Keyword Difficulty",
            "Search Volume (Monthly)",
            "CPC ($)",
            "Target URL",
        ]

        def worker(input_str: str, thread_id: int):
            for _ in range(100):
                result = SmartDecisionMakerBlock.cleanup(input_str)
                with lock:
                    key = f"{thread_id}_{input_str}"
                    if key not in results:
                        results[key] = set()
                    results[key].add(result)

        threads = []
        for i, input_str in enumerate(test_inputs):
            for j in range(3):
                t = threading.Thread(target=worker, args=(input_str, i * 3 + j))
                threads.append(t)

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Each input should produce exactly one unique output
        for key, values in results.items():
            assert len(values) == 1, f"Non-deterministic cleanup for {key}: {values}"


class TestAsyncConcurrencyPatterns:
    """Tests for async concurrency patterns in the block."""

    @pytest.mark.asyncio
    async def test_multiple_async_runs_isolation(self):
        """
        Test that multiple async runs are properly isolated.
        """
        import backend.blocks.llm as llm_module
        from backend.data.execution import ExecutionContext

        block = SmartDecisionMakerBlock()

        run_count = 5
        results = []

        async def single_run(run_id: int):
            mock_response = MagicMock()
            mock_response.response = f"Unique response {run_id}"
            mock_response.tool_calls = []
            mock_response.prompt_tokens = 10
            mock_response.completion_tokens = 5
            mock_response.reasoning = None
            mock_response.raw_response = {"role": "assistant", "content": f"Run {run_id}"}

            # Add small random delay to increase chance of interleaving
            await asyncio.sleep(0.001 * (run_id % 3))

            with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
                mock_llm.return_value = mock_response

                with patch.object(block, "_create_tool_node_signatures", return_value=[]):
                    input_data = SmartDecisionMakerBlock.Input(
                        prompt=f"Prompt {run_id}",
                        model=llm_module.DEFAULT_LLM_MODEL,
                        credentials=llm_module.TEST_CREDENTIALS_INPUT,
                        agent_mode_max_iterations=0,
                    )

                    outputs = {}
                    async for name, value in block.run(
                        input_data,
                        credentials=llm_module.TEST_CREDENTIALS,
                        graph_id=f"g{run_id}",
                        node_id=f"n{run_id}",
                        graph_exec_id=f"e{run_id}",
                        node_exec_id=f"ne{run_id}",
                        user_id=f"u{run_id}",
                        graph_version=1,
                        execution_context=ExecutionContext(safe_mode=False),
                        execution_processor=MagicMock(),
                    ):
                        outputs[name] = value

                    return run_id, outputs

        # Run all concurrently
        tasks = [single_run(i) for i in range(run_count)]
        results = await asyncio.gather(*tasks)

        # Verify isolation
        for run_id, outputs in results:
            if "finished" in outputs:
                assert str(run_id) in outputs["finished"], \
                    f"Run {run_id} got wrong response: {outputs['finished']}"
@@ -0,0 +1,667 @@
"""
Tests for SmartDecisionMaker conversation handling and corruption scenarios.

Covers failure modes:
6. Conversation Corruption in Error Paths
And related conversation management issues.
"""

import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from backend.blocks.smart_decision_maker import (
    SmartDecisionMakerBlock,
    get_pending_tool_calls,
    _create_tool_response,
    _combine_tool_responses,
    _convert_raw_response_to_dict,
    _get_tool_requests,
    _get_tool_responses,
)


class TestConversationCorruptionInErrorPaths:
    """
    Tests for Failure Mode #6: Conversation Corruption in Error Paths

    When there's a logic error (orphaned tool output), the code appends
    it as a "user" message instead of proper tool response format,
    violating LLM conversation structure.
    """

    @pytest.mark.asyncio
    async def test_orphaned_tool_output_creates_user_message(self):
        """
        Test that orphaned tool output (no pending calls) creates wrong message type.
        """
        import backend.blocks.llm as llm_module
        from backend.data.execution import ExecutionContext

        block = SmartDecisionMakerBlock()

        # Response with no tool calls
        mock_response = MagicMock()
        mock_response.response = "No tools needed"
        mock_response.tool_calls = []
        mock_response.prompt_tokens = 50
        mock_response.completion_tokens = 25
        mock_response.reasoning = None
        mock_response.raw_response = {"role": "assistant", "content": "No tools needed"}

        with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = mock_response

            with patch.object(block, "_create_tool_node_signatures", return_value=[]):
                input_data = SmartDecisionMakerBlock.Input(
                    prompt="Test",
                    model=llm_module.DEFAULT_LLM_MODEL,
                    credentials=llm_module.TEST_CREDENTIALS_INPUT,
                    agent_mode_max_iterations=0,
                    # Orphaned tool output - no pending calls but we have output
                    last_tool_output={"result": "orphaned data"},
                    conversation_history=[],  # Empty - no pending calls
                )

                mock_execution_context = ExecutionContext(safe_mode=False)
                mock_execution_processor = MagicMock()

                outputs = {}
                async for name, value in block.run(
                    input_data,
                    credentials=llm_module.TEST_CREDENTIALS,
                    graph_id="test-graph",
                    node_id="test-node",
                    graph_exec_id="test-exec",
                    node_exec_id="test-node-exec",
                    user_id="test-user",
                    graph_version=1,
                    execution_context=mock_execution_context,
                    execution_processor=mock_execution_processor,
                ):
                    outputs[name] = value

        # Check the conversation for the orphaned output handling
        # The orphaned output is logged as error but may be added as user message
        # This is the BUG: should not add orphaned outputs to conversation

    def test_create_tool_response_anthropic_format(self):
        """Test that Anthropic format tool responses are created correctly."""
        response = _create_tool_response(
            "toolu_abc123",
            {"result": "success"}
        )

        assert response["role"] == "user"
        assert response["type"] == "message"
        assert isinstance(response["content"], list)
        assert response["content"][0]["type"] == "tool_result"
        assert response["content"][0]["tool_use_id"] == "toolu_abc123"

    def test_create_tool_response_openai_format(self):
        """Test that OpenAI format tool responses are created correctly."""
        response = _create_tool_response(
            "call_abc123",
            {"result": "success"}
        )

        assert response["role"] == "tool"
        assert response["tool_call_id"] == "call_abc123"
        assert "content" in response

    def test_tool_response_with_string_content(self):
        """Test tool response creation with string content."""
        response = _create_tool_response(
            "call_123",
            "Simple string result"
        )

        assert response["content"] == "Simple string result"

    def test_tool_response_with_complex_content(self):
        """Test tool response creation with complex JSON content."""
        complex_data = {
            "nested": {"key": "value"},
            "list": [1, 2, 3],
            "null": None,
        }

        response = _create_tool_response("call_123", complex_data)

        # Content should be JSON string
        parsed = json.loads(response["content"])
        assert parsed == complex_data


class TestCombineToolResponses:
    """Tests for combining multiple tool responses."""

    def test_combine_single_response_unchanged(self):
        """Test that single response is returned unchanged."""
        responses = [
            {
                "role": "user",
                "type": "message",
                "content": [{"type": "tool_result", "tool_use_id": "123"}]
            }
        ]

        result = _combine_tool_responses(responses)
        assert result == responses

    def test_combine_multiple_anthropic_responses(self):
        """Test combining multiple Anthropic responses."""
        responses = [
            {
                "role": "user",
                "type": "message",
                "content": [{"type": "tool_result", "tool_use_id": "123", "content": "a"}]
            },
            {
                "role": "user",
                "type": "message",
                "content": [{"type": "tool_result", "tool_use_id": "456", "content": "b"}]
            },
        ]

        result = _combine_tool_responses(responses)

        # Should be combined into single message
        assert len(result) == 1
        assert result[0]["role"] == "user"
        assert len(result[0]["content"]) == 2

    def test_combine_mixed_responses(self):
        """Test combining mixed Anthropic and OpenAI responses."""
        responses = [
            {
                "role": "user",
                "type": "message",
                "content": [{"type": "tool_result", "tool_use_id": "123"}]
            },
            {
                "role": "tool",
                "tool_call_id": "call_456",
                "content": "openai result"
            },
        ]

        result = _combine_tool_responses(responses)

        # Anthropic response combined, OpenAI kept separate
        assert len(result) == 2

    def test_combine_empty_list(self):
        """Test combining empty list."""
        result = _combine_tool_responses([])
        assert result == []


class TestConversationHistoryValidation:
    """Tests for conversation history validation."""

    def test_pending_tool_calls_basic(self):
        """Test basic pending tool call counting."""
        history = [
            {
                "role": "assistant",
                "content": [
                    {"type": "tool_use", "id": "call_1"},
                    {"type": "tool_use", "id": "call_2"},
                ]
            }
        ]

        pending = get_pending_tool_calls(history)

        assert len(pending) == 2
        assert "call_1" in pending
        assert "call_2" in pending

    def test_pending_tool_calls_with_responses(self):
        """Test pending calls after some responses."""
        history = [
            {
                "role": "assistant",
                "content": [
                    {"type": "tool_use", "id": "call_1"},
                    {"type": "tool_use", "id": "call_2"},
                ]
            },
            {
                "role": "user",
                "content": [
                    {"type": "tool_result", "tool_use_id": "call_1"}
                ]
            }
        ]

        pending = get_pending_tool_calls(history)

        assert len(pending) == 1
        assert "call_2" in pending
        assert "call_1" not in pending

    def test_pending_tool_calls_all_responded(self):
        """Test when all tool calls have responses."""
        history = [
            {
                "role": "assistant",
                "content": [{"type": "tool_use", "id": "call_1"}]
            },
            {
                "role": "user",
                "content": [{"type": "tool_result", "tool_use_id": "call_1"}]
            }
        ]

        pending = get_pending_tool_calls(history)

        assert len(pending) == 0

    def test_pending_tool_calls_openai_format(self):
        """Test pending calls with OpenAI format."""
        history = [
            {
                "role": "assistant",
                "tool_calls": [
                    {"id": "call_1"},
                    {"id": "call_2"},
                ]
            },
            {
                "role": "tool",
                "tool_call_id": "call_1",
                "content": "result"
            }
        ]

        pending = get_pending_tool_calls(history)

        assert len(pending) == 1
        assert "call_2" in pending


class TestConversationUpdateBehavior:
    """Tests for conversation update behavior."""

    @pytest.mark.asyncio
    async def test_conversation_includes_assistant_response(self):
        """Test that assistant responses are added to conversation."""
        import backend.blocks.llm as llm_module
        from backend.data.execution import ExecutionContext

        block = SmartDecisionMakerBlock()

        mock_response = MagicMock()
        mock_response.response = "Final answer"
        mock_response.tool_calls = []
        mock_response.prompt_tokens = 50
        mock_response.completion_tokens = 25
        mock_response.reasoning = None
        mock_response.raw_response = {"role": "assistant", "content": "Final answer"}

        with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = mock_response

            with patch.object(block, "_create_tool_node_signatures", return_value=[]):
                input_data = SmartDecisionMakerBlock.Input(
                    prompt="Test",
                    model=llm_module.DEFAULT_LLM_MODEL,
                    credentials=llm_module.TEST_CREDENTIALS_INPUT,
                    agent_mode_max_iterations=0,
                )

                mock_execution_context = ExecutionContext(safe_mode=False)
                mock_execution_processor = MagicMock()

                outputs = {}
                async for name, value in block.run(
                    input_data,
                    credentials=llm_module.TEST_CREDENTIALS,
                    graph_id="test-graph",
                    node_id="test-node",
                    graph_exec_id="test-exec",
                    node_exec_id="test-node-exec",
                    user_id="test-user",
                    graph_version=1,
                    execution_context=mock_execution_context,
                    execution_processor=mock_execution_processor,
                ):
                    outputs[name] = value

        # No conversations output when no tool calls (just finished)
        assert "finished" in outputs
        assert outputs["finished"] == "Final answer"

    @pytest.mark.asyncio
    async def test_conversation_with_tool_calls(self):
        """Test that tool calls are properly added to conversation."""
        import backend.blocks.llm as llm_module
        from backend.data.execution import ExecutionContext

        block = SmartDecisionMakerBlock()

        mock_tool_call = MagicMock()
        mock_tool_call.function.name = "test_tool"
        mock_tool_call.function.arguments = json.dumps({"param": "value"})

        mock_response = MagicMock()
        mock_response.response = None
        mock_response.tool_calls = [mock_tool_call]
        mock_response.prompt_tokens = 50
        mock_response.completion_tokens = 25
        mock_response.reasoning = "I'll use the test tool"
        mock_response.raw_response = {
            "role": "assistant",
            "content": None,
            "tool_calls": [{"id": "call_1"}]
        }

        mock_tool_signatures = [
            {
                "type": "function",
                "function": {
                    "name": "test_tool",
                    "_sink_node_id": "sink",
                    "_field_mapping": {"param": "param"},
                    "parameters": {
                        "properties": {"param": {"type": "string"}},
                        "required": ["param"],
                    },
                },
            }
        ]

        with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
            mock_llm.return_value = mock_response

            with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
                input_data = SmartDecisionMakerBlock.Input(
                    prompt="Test",
                    model=llm_module.DEFAULT_LLM_MODEL,
                    credentials=llm_module.TEST_CREDENTIALS_INPUT,
                    agent_mode_max_iterations=0,
                )

                mock_execution_context = ExecutionContext(safe_mode=False)
                mock_execution_processor = MagicMock()

                outputs = {}
                async for name, value in block.run(
                    input_data,
                    credentials=llm_module.TEST_CREDENTIALS,
                    graph_id="test-graph",
                    node_id="test-node",
                    graph_exec_id="test-exec",
                    node_exec_id="test-node-exec",
                    user_id="test-user",
                    graph_version=1,
                    execution_context=mock_execution_context,
                    execution_processor=mock_execution_processor,
                ):
                    outputs[name] = value

        # Should have conversations output
        assert "conversations" in outputs

        # Conversation should include the assistant message
        conversations = outputs["conversations"]
        has_assistant = any(
            msg.get("role") == "assistant"
            for msg in conversations
        )
        assert has_assistant


class TestConversationHistoryPreservation:
    """Tests for conversation history preservation across calls."""

    @pytest.mark.asyncio
    async def test_existing_history_preserved(self):
        """Test that existing conversation history is preserved."""
        import backend.blocks.llm as llm_module
        from backend.data.execution import ExecutionContext

        block = SmartDecisionMakerBlock()

        existing_history = [
            {"role": "user", "content": "Previous message 1"},
            {"role": "assistant", "content": "Previous response 1"},
            {"role": "user", "content": "Previous message 2"},
        ]

        mock_response = MagicMock()
        mock_response.response = "New response"
        mock_response.tool_calls = []
        mock_response.prompt_tokens = 50
        mock_response.completion_tokens = 25
        mock_response.reasoning = None
        mock_response.raw_response = {"role": "assistant", "content": "New response"}

        captured_prompt = []

        async def capture_llm_call(**kwargs):
            captured_prompt.extend(kwargs.get("prompt", []))
            return mock_response

        with patch("backend.blocks.llm.llm_call", side_effect=capture_llm_call):
            with patch.object(block, "_create_tool_node_signatures", return_value=[]):
                input_data = SmartDecisionMakerBlock.Input(
                    prompt="New message",
                    model=llm_module.DEFAULT_LLM_MODEL,
                    credentials=llm_module.TEST_CREDENTIALS_INPUT,
                    agent_mode_max_iterations=0,
                    conversation_history=existing_history,
                )

                mock_execution_context = ExecutionContext(safe_mode=False)
                mock_execution_processor = MagicMock()

                async for _ in block.run(
                    input_data,
                    credentials=llm_module.TEST_CREDENTIALS,
                    graph_id="test-graph",
                    node_id="test-node",
                    graph_exec_id="test-exec",
                    node_exec_id="test-node-exec",
                    user_id="test-user",
                    graph_version=1,
                    execution_context=mock_execution_context,
                    execution_processor=mock_execution_processor,
                ):
                    pass

        # Existing history should be in the prompt
        assert len(captured_prompt) >= len(existing_history)


class TestRawResponseConversion:
    """Tests for raw response to dict conversion."""

    def test_string_response(self):
        """Test conversion of string response."""
        result = _convert_raw_response_to_dict("Hello world")

        assert result == {"role": "assistant", "content": "Hello world"}

    def test_dict_response(self):
        """Test that dict response is passed through."""
        original = {"role": "assistant", "content": "test", "extra": "data"}
        result = _convert_raw_response_to_dict(original)

        assert result == original

    def test_object_response(self):
        """Test conversion of object response."""
        mock_obj = MagicMock()

        with patch("backend.blocks.smart_decision_maker.json.to_dict") as mock_to_dict:
            mock_to_dict.return_value = {"role": "assistant", "content": "converted"}
            result = _convert_raw_response_to_dict(mock_obj)

            mock_to_dict.assert_called_once_with(mock_obj)
            assert result["role"] == "assistant"


class TestConversationMessageStructure:
    """Tests for correct conversation message structure."""

    def test_system_message_not_duplicated(self):
        """Test that system messages are not duplicated."""
        from backend.util.prompt import MAIN_OBJECTIVE_PREFIX

        # Existing system message in history
        existing_history = [
            {"role": "system", "content": f"{MAIN_OBJECTIVE_PREFIX}Existing system prompt"},
        ]

        # The block should not add another system message
        # This is verified by checking the prompt passed to LLM

    def test_user_message_not_duplicated(self):
        """Test that user messages are not duplicated."""
        from backend.util.prompt import MAIN_OBJECTIVE_PREFIX

        # Existing user message with MAIN_OBJECTIVE_PREFIX
        existing_history = [
            {"role": "user", "content": f"{MAIN_OBJECTIVE_PREFIX}Existing user prompt"},
        ]

        # The block should not add another user message with same prefix
        # This is verified by checking the prompt passed to LLM

    def test_tool_response_after_tool_call(self):
        """Test that tool responses come after tool calls."""
        # Valid conversation structure
        valid_history = [
            {
                "role": "assistant",
                "content": [{"type": "tool_use", "id": "call_1"}]
            },
            {
                "role": "user",
                "content": [{"type": "tool_result", "tool_use_id": "call_1"}]
            }
        ]

        # This should be valid - tool result follows tool use
        pending = get_pending_tool_calls(valid_history)
        assert len(pending) == 0

    def test_orphaned_tool_response_detected(self):
        """Test detection of orphaned tool responses."""
        # Invalid: tool response without matching tool call
        invalid_history = [
            {
                "role": "user",
                "content": [{"type": "tool_result", "tool_use_id": "orphan_call"}]
            }
        ]

        pending = get_pending_tool_calls(invalid_history)

        # Orphan response creates negative count
        # Should have count -1 for orphan_call
        # But it's filtered out (count <= 0)
        assert "orphan_call" not in pending


class TestValidationErrorInConversation:
    """Tests for validation error handling in conversation."""

    @pytest.mark.asyncio
    async def test_validation_error_feedback_not_in_final_conversation(self):
        """
        Test that validation error feedback is not in final conversation output.

        When retrying due to validation errors, the error feedback should
        only be used for the retry prompt, not persisted in final conversation.
        """
        import backend.blocks.llm as llm_module
        from backend.data.execution import ExecutionContext

        block = SmartDecisionMakerBlock()

        call_count = 0

        async def mock_llm_call(**kwargs):
            nonlocal call_count
            call_count += 1

            if call_count == 1:
                # First call: invalid tool call
                mock_tool_call = MagicMock()
                mock_tool_call.function.name = "test_tool"
                mock_tool_call.function.arguments = json.dumps({"wrong": "param"})

                resp = MagicMock()
                resp.response = None
                resp.tool_calls = [mock_tool_call]
                resp.prompt_tokens = 50
                resp.completion_tokens = 25
                resp.reasoning = None
                resp.raw_response = {"role": "assistant", "content": None}
                return resp
            else:
                # Second call: finish
                resp = MagicMock()
                resp.response = "Done"
                resp.tool_calls = []
                resp.prompt_tokens = 50
                resp.completion_tokens = 25
                resp.reasoning = None
                resp.raw_response = {"role": "assistant", "content": "Done"}
                return resp

        mock_tool_signatures = [
            {
                "type": "function",
                "function": {
                    "name": "test_tool",
                    "_sink_node_id": "sink",
                    "_field_mapping": {"correct": "correct"},
                    "parameters": {
                        "properties": {"correct": {"type": "string"}},
                        "required": ["correct"],
                    },
                },
            }
        ]

        with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call):
            with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
                input_data = SmartDecisionMakerBlock.Input(
                    prompt="Test",
                    model=llm_module.DEFAULT_LLM_MODEL,
                    credentials=llm_module.TEST_CREDENTIALS_INPUT,
                    agent_mode_max_iterations=0,
                    retry=3,
                )

                mock_execution_context = ExecutionContext(safe_mode=False)
                mock_execution_processor = MagicMock()

                outputs = {}
                async for name, value in block.run(
                    input_data,
                    credentials=llm_module.TEST_CREDENTIALS,
                    graph_id="test-graph",
                    node_id="test-node",
                    graph_exec_id="test-exec",
                    node_exec_id="test-node-exec",
                    user_id="test-user",
                    graph_version=1,
                    execution_context=mock_execution_context,
                    execution_processor=mock_execution_processor,
                ):
                    outputs[name] = value

        # Should have finished successfully after retry
        assert "finished" in outputs

        # Note: In traditional mode (agent_mode_max_iterations=0),
        # conversations are only output when there are tool calls
        # After the retry succeeds with no tool calls, we just get "finished"
@@ -0,0 +1,671 @@
|
||||
"""
|
||||
Tests for SmartDecisionMaker data integrity failure modes.
|
||||
|
||||
Covers failure modes:
|
||||
6. Conversation Corruption in Error Paths
|
||||
7. Field Name Collision Not Detected
|
||||
8. No Type Validation in Dynamic Field Merging
|
||||
9. Unhandled Field Mapping Keys
|
||||
16. Silent Value Loss in Output Routing
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
|
||||
|
||||
class TestFieldNameCollisionDetection:
|
||||
"""
|
||||
Tests for Failure Mode #7: Field Name Collision Not Detected
|
||||
|
||||
When multiple field names sanitize to the same value,
|
||||
the last one silently overwrites previous mappings.
|
||||
"""
|
||||
|
||||
def test_different_names_same_sanitized_result(self):
|
||||
"""Test that different names can produce the same sanitized result."""
|
||||
cleanup = SmartDecisionMakerBlock.cleanup
|
||||
|
||||
# All these sanitize to "test_field"
|
||||
variants = [
|
||||
"test_field",
|
||||
"Test Field",
|
||||
"test field",
|
||||
"TEST_FIELD",
|
||||
"Test_Field",
|
||||
"test-field", # Note: hyphen is preserved, this is different
|
||||
]
|
||||
|
||||
sanitized = [cleanup(v) for v in variants]
|
||||
|
||||
# Count unique sanitized values
|
||||
unique = set(sanitized)
|
||||
# Most should collide (except hyphenated one)
|
||||
assert len(unique) < len(variants), \
|
||||
f"Expected collisions, got {unique}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collision_last_one_wins(self):
|
||||
"""Test that in case of collision, the last field mapping wins."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "string", "description": "test"}
|
||||
)
|
||||
|
||||
# Two fields that sanitize to the same name
|
||||
mock_links = [
|
||||
Mock(sink_name="Test Field", sink_id="test-node", source_id="source"),
|
||||
Mock(sink_name="test field", sink_id="test-node", source_id="source"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
|
||||
# Only one property (collision)
|
||||
assert len(properties) == 1
|
||||
assert "test_field" in properties
|
||||
|
||||
# The mapping has only the last one
|
||||
# This is the BUG: first field's mapping is lost
|
||||
assert field_mapping["test_field"] in ["Test Field", "test field"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collision_causes_data_loss(self):
|
||||
"""
|
||||
Test that field collision can cause actual data loss.
|
||||
|
||||
Scenario:
|
||||
1. Two fields "Field A" and "field a" both map to "field_a"
|
||||
2. LLM provides value for "field_a"
|
||||
3. Only one original field gets the value
|
||||
4. The other field's expected input is lost
|
||||
"""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Simulate processing tool calls with collision
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"field_a": "value_for_both" # LLM uses sanitized name
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
# Tool definition with collision in field mapping
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"field_a": {"type": "string"},
|
||||
},
|
||||
"required": ["field_a"],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
# BUG: Only one original name is stored
|
||||
# "Field A" was overwritten by "field a"
|
||||
"_field_mapping": {"field_a": "field a"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Only "field a" gets the value
|
||||
assert "field a" in input_data
|
||||
assert input_data["field a"] == "value_for_both"
|
||||
|
||||
# "Field A" is completely lost!
|
||||
assert "Field A" not in input_data
|
||||
|
||||
|
||||
class TestUnhandledFieldMappingKeys:
|
||||
"""
|
||||
Tests for Failure Mode #9: Unhandled Field Mapping Keys
|
||||
|
||||
When field_mapping is missing a key, the code falls back to
|
||||
the clean name, which may not be what the sink expects.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_field_mapping_falls_back_to_clean_name(self):
|
||||
"""Test that missing field mapping falls back to clean name."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"unmapped_field": "value"
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
# Tool definition with incomplete field mapping
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"unmapped_field": {"type": "string"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {}, # Empty! No mapping for unmapped_field
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Falls back to clean name (which IS the key since it's already clean)
|
||||
assert "unmapped_field" in input_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_field_mapping(self):
|
||||
"""Test behavior with partial field mapping."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"mapped_field": "value1",
|
||||
"unmapped_field": "value2",
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"mapped_field": {"type": "string"},
|
||||
"unmapped_field": {"type": "string"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
# Only one field is mapped
|
||||
"_field_mapping": {
|
||||
"mapped_field": "Original Mapped Field",
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Mapped field uses original name
|
||||
assert "Original Mapped Field" in input_data
|
||||
# Unmapped field uses clean name (fallback)
|
||||
assert "unmapped_field" in input_data
|
||||
|
||||
|
||||
class TestSilentValueLossInRouting:
|
||||
"""
|
||||
Tests for Failure Mode #16: Silent Value Loss in Output Routing
|
||||
|
||||
When routing fails in parse_execution_output, it returns None
|
||||
without any logging or indication of why it failed.
|
||||
"""
|
||||
|
||||
def test_routing_mismatch_returns_none_silently(self):
|
||||
"""Test that routing mismatch returns None without error."""
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
|
||||
output_item = ("tools_^_node-123_~_sanitized_name", "important_value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="Original Name", # Doesn't match sanitized_name
|
||||
)
|
||||
|
||||
# Silently returns None
|
||||
assert result is None
|
||||
# No way to distinguish "value is None" from "routing failed"
|
||||
|
||||
def test_wrong_node_id_returns_none(self):
|
||||
"""Test that wrong node ID returns None."""
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
|
||||
output_item = ("tools_^_node-123_~_field", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="different-node", # Wrong node
|
||||
sink_pin_name="field",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_wrong_selector_returns_none(self):
|
||||
"""Test that wrong selector returns None."""
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
|
||||
output_item = ("tools_^_node-123_~_field", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="different_selector", # Wrong selector
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="field",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_cannot_distinguish_none_value_from_routing_failure(self):
|
||||
"""
|
||||
Test that None as actual value is indistinguishable from routing failure.
|
||||
"""
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
|
||||
# Case 1: Actual None value
|
||||
output_with_none = ("field_name", None)
|
||||
result1 = parse_execution_output(
|
||||
output_with_none,
|
||||
link_output_selector="field_name",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
# Case 2: Routing failure
|
||||
output_mismatched = ("field_name", "value")
|
||||
result2 = parse_execution_output(
|
||||
output_mismatched,
|
||||
link_output_selector="different_field",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
# Both return None - cannot distinguish!
|
||||
assert result1 is None
|
||||
assert result2 is None
|
||||
|
||||
|
||||
class TestProcessToolCallsInputData:
|
||||
"""Tests for _process_tool_calls input data generation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_expected_args_included(self):
|
||||
"""Test that all expected arguments are included in input_data."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"provided_field": "value",
|
||||
# optional_field not provided
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"provided_field": {"type": "string"},
|
||||
"optional_field": {"type": "string"},
|
||||
},
|
||||
"required": ["provided_field"],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {
|
||||
"provided_field": "Provided Field",
|
||||
"optional_field": "Optional Field",
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Both fields should be in input_data
|
||||
assert "Provided Field" in input_data
|
||||
assert "Optional Field" in input_data
|
||||
|
||||
# Provided has value, optional is None
|
||||
assert input_data["Provided Field"] == "value"
|
||||
assert input_data["Optional Field"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_args_from_llm_ignored(self):
|
||||
"""Test that extra arguments from LLM not in schema are ignored."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"expected_field": "value",
|
||||
"unexpected_field": "should_be_ignored",
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"expected_field": {"type": "string"},
|
||||
# unexpected_field not in schema
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"expected_field": "Expected Field"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
input_data = processed[0].input_data
|
||||
|
||||
# Only expected field should be in input_data
|
||||
assert "Expected Field" in input_data
|
||||
assert "unexpected_field" not in input_data
|
||||
assert "Unexpected Field" not in input_data
|
||||
|
||||
|
||||
class TestToolCallMatching:
|
||||
"""Tests for tool call matching logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_not_found_skipped(self):
|
||||
"""Test that tool calls for unknown tools are skipped."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "unknown_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "known_tool", # Different name
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
"_sink_node_id": "sink",
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
# Unknown tool is skipped (not processed)
|
||||
assert len(processed) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_fallback(self):
|
||||
"""Test fallback when only one tool exists but name doesn't match."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "wrong_name"
|
||||
mock_tool_call.function.arguments = json.dumps({"field": "value"})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
# Only one tool defined
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "only_tool",
|
||||
"parameters": {
|
||||
"properties": {"field": {"type": "string"}},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"field": "Field"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
# Falls back to the only tool
|
||||
assert len(processed) == 1
|
||||
assert processed[0].input_data["Field"] == "value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_tool_calls_processed(self):
|
||||
"""Test that multiple tool calls are all processed."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call_1 = Mock()
|
||||
mock_tool_call_1.function.name = "tool_a"
|
||||
mock_tool_call_1.function.arguments = json.dumps({"a": "1"})
|
||||
|
||||
mock_tool_call_2 = Mock()
|
||||
mock_tool_call_2.function.name = "tool_b"
|
||||
mock_tool_call_2.function.arguments = json.dumps({"b": "2"})
|
||||
|
||||
mock_response.tool_calls = [mock_tool_call_1, mock_tool_call_2]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_a",
|
||||
"parameters": {
|
||||
"properties": {"a": {"type": "string"}},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink_a",
|
||||
"_field_mapping": {"a": "A"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_b",
|
||||
"parameters": {
|
||||
"properties": {"b": {"type": "string"}},
|
||||
"required": [],
|
||||
},
|
||||
"_sink_node_id": "sink_b",
|
||||
"_field_mapping": {"b": "B"},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 2
|
||||
assert processed[0].input_data["A"] == "1"
|
||||
assert processed[1].input_data["B"] == "2"
|
||||
|
||||
|
||||
class TestOutputEmitKeyGeneration:
|
||||
"""Tests for output emit key generation consistency."""
|
||||
|
||||
def test_emit_key_uses_sanitized_field_name(self):
|
||||
"""Test that emit keys use sanitized field names."""
|
||||
cleanup = SmartDecisionMakerBlock.cleanup
|
||||
|
||||
original_field = "Max Keyword Difficulty"
|
||||
sink_node_id = "node-123"
|
||||
|
||||
sanitized = cleanup(original_field)
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{sanitized}"
|
||||
|
||||
assert emit_key == "tools_^_node-123_~_max_keyword_difficulty"
|
||||
|
||||
def test_emit_key_format_consistent(self):
|
||||
"""Test that emit key format is consistent."""
|
||||
test_cases = [
|
||||
("field", "node", "tools_^_node_~_field"),
|
||||
("Field Name", "node-123", "tools_^_node-123_~_field_name"),
|
||||
("CPC ($)", "abc", "tools_^_abc_~_cpc____"),
|
||||
]
|
||||
|
||||
cleanup = SmartDecisionMakerBlock.cleanup
|
||||
|
||||
for original_field, node_id, expected in test_cases:
|
||||
sanitized = cleanup(original_field)
|
||||
emit_key = f"tools_^_{node_id}_~_{sanitized}"
|
||||
assert emit_key == expected, \
|
||||
f"Expected {expected}, got {emit_key}"
|
||||
|
||||
def test_emit_key_sanitization_idempotent(self):
|
||||
"""Test that sanitizing an already sanitized name gives same result."""
|
||||
cleanup = SmartDecisionMakerBlock.cleanup
|
||||
|
||||
original = "Test Field Name"
|
||||
first_clean = cleanup(original)
|
||||
second_clean = cleanup(first_clean)
|
||||
|
||||
assert first_clean == second_clean
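

# --- Illustrative sketch (not part of the test suite) ---
# The emit key format exercised above is "tools_^_<sink_node_id>_~_<sanitized_field>".
# A minimal helper for taking such a key apart might look like this; the helper name is
# hypothetical and the delimiters are simply the ones used by the tests in this file.
def split_emit_key(emit_key: str) -> tuple[str, str, str]:
    """Split "tools_^_<node_id>_~_<field>" into (selector, node_id, field)."""
    selector, rest = emit_key.split("_^_", 1)
    node_id, field = rest.split("_~_", 1)
    return selector, node_id, field


# Example: split_emit_key("tools_^_node-123_~_max_keyword_difficulty")
# == ("tools", "node-123", "max_keyword_difficulty")

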
class TestToolFunctionMetadata:
|
||||
"""Tests for tool function metadata handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sink_node_id_preserved(self):
|
||||
"""Test that _sink_node_id is preserved in tool function."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "specific-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "string", "description": "test"}
|
||||
)
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="field", sink_id="specific-node-id", source_id="source"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
assert signature["function"]["_sink_node_id"] == "specific-node-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_mapping_preserved(self):
|
||||
"""Test that _field_mapping is preserved in tool function."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "string", "description": "test"}
|
||||
)
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="Original Field Name", sink_id="test-node", source_id="source"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
assert "original_field_name" in field_mapping
|
||||
assert field_mapping["original_field_name"] == "Original Field Name"
|
||||
|
||||
|
||||
class TestRequiredFieldsHandling:
|
||||
"""Tests for required fields handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_fields_use_sanitized_names(self):
|
||||
"""Test that required fields array uses sanitized names."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={
|
||||
"properties": {},
|
||||
"required": ["Required Field", "Another Required"],
|
||||
}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "string", "description": "test"}
|
||||
)
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="Required Field", sink_id="test-node", source_id="source"),
|
||||
Mock(sink_name="Another Required", sink_id="test-node", source_id="source"),
|
||||
Mock(sink_name="Optional Field", sink_id="test-node", source_id="source"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
required = signature["function"]["parameters"]["required"]
|
||||
|
||||
# Should use sanitized names
|
||||
assert "required_field" in required
|
||||
assert "another_required" in required
|
||||
|
||||
# Original names should NOT be in required
|
||||
assert "Required Field" not in required
|
||||
assert "Another Required" not in required
|
||||
|
||||
# Optional field should not be required
|
||||
assert "optional_field" not in required
|
||||
assert "Optional Field" not in required
@@ -373,7 +373,7 @@ async def test_output_yielding_with_dynamic_fields():
     input_data = block.input_schema(
         prompt="Create a user dictionary",
         credentials=llm.TEST_CREDENTIALS_INPUT,
-        model=llm.LlmModel.GPT4O,
+        model=llm.DEFAULT_LLM_MODEL,
         agent_mode_max_iterations=0,  # Use traditional mode to test output yielding
     )

@@ -594,7 +594,7 @@ async def test_validation_errors_dont_pollute_conversation():
     input_data = block.input_schema(
         prompt="Test prompt",
         credentials=llm.TEST_CREDENTIALS_INPUT,
-        model=llm.LlmModel.GPT4O,
+        model=llm.DEFAULT_LLM_MODEL,
         retry=3,  # Allow retries
         agent_mode_max_iterations=1,
     )
@@ -0,0 +1,871 @@
"""
Tests for SmartDecisionMaker error handling failure modes.

Covers failure modes:
3. JSON Deserialization Without Exception Handling
4. Database Transaction Inconsistency
5. Missing Null Checks After Database Calls
15. Error Message Context Loss
17. No Validation of Dynamic Field Paths
"""

import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from backend.blocks.smart_decision_maker import (
    SmartDecisionMakerBlock,
    _convert_raw_response_to_dict,
    _create_tool_response,
)

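
# --- Illustrative sketch (not part of the test suite) ---
# The tests below argue that tool-call arguments returned by the LLM should not go through
# json.loads() unguarded. A minimal defensive wrapper could look like the following; the
# helper name and the empty-dict fallback are assumptions for illustration only.
def safe_parse_tool_arguments(raw_arguments: str) -> dict:
    """Parse LLM tool-call arguments, falling back to {} on malformed JSON."""
    try:
        parsed = json.loads(raw_arguments)
    except json.JSONDecodeError:
        # Single quotes, trailing commas, Python literals, etc. all land here instead of
        # crashing the whole block run.
        return {}
    return parsed if isinstance(parsed, dict) else {}

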
class TestJSONDeserializationErrors:
|
||||
"""
|
||||
Tests for Failure Mode #3: JSON Deserialization Without Exception Handling
|
||||
|
||||
When the LLM returns malformed JSON in tool call arguments, the json.loads()
|
||||
call fails without proper error handling.
|
||||
"""
|
||||
|
||||
def test_malformed_json_single_quotes(self):
|
||||
"""
|
||||
Test that single quotes in JSON cause parsing failure.
|
||||
|
||||
LLMs sometimes return {'key': 'value'} instead of {"key": "value"}
|
||||
"""
|
||||
malformed = "{'key': 'value'}"
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed)
|
||||
|
||||
def test_malformed_json_trailing_comma(self):
|
||||
"""
|
||||
Test that trailing commas cause parsing failure.
|
||||
"""
|
||||
malformed = '{"key": "value",}'
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed)
|
||||
|
||||
def test_malformed_json_unquoted_keys(self):
|
||||
"""
|
||||
Test that unquoted keys cause parsing failure.
|
||||
"""
|
||||
malformed = '{key: "value"}'
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed)
|
||||
|
||||
def test_malformed_json_python_none(self):
|
||||
"""
|
||||
Test that Python None instead of null causes failure.
|
||||
"""
|
||||
malformed = '{"key": None}'
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed)
|
||||
|
||||
def test_malformed_json_python_true_false(self):
|
||||
"""
|
||||
Test that Python True/False instead of true/false causes failure.
|
||||
"""
|
||||
malformed_true = '{"key": True}'
|
||||
malformed_false = '{"key": False}'
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed_true)
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(malformed_false)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_returns_malformed_json_crashes_block(self):
|
||||
"""
|
||||
Test that malformed JSON from LLM causes block to crash.
|
||||
|
||||
BUG: The json.loads() calls at lines 625, 706, and 1124 can raise JSONDecodeError,
which is not caught, causing the entire block to fail.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Create response with malformed JSON
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = "{'malformed': 'json'}" # Single quotes!
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {"malformed": {"type": "string"}}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm:
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
with patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
# BUG: This should raise JSONDecodeError
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class TestDatabaseTransactionInconsistency:
|
||||
"""
|
||||
Tests for Failure Mode #4: Database Transaction Inconsistency
|
||||
|
||||
When multiple database operations are performed in sequence,
|
||||
a failure partway through leaves the database in an inconsistent state.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_input_insertion_on_failure(self):
|
||||
"""
|
||||
Test that partial failures during multi-input insertion
leave the database in an inconsistent state.
|
||||
"""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# Track which inputs were inserted
|
||||
inserted_inputs = []
|
||||
call_count = 0
|
||||
|
||||
async def failing_upsert(node_id, graph_exec_id, input_name, input_data):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
# Fail on the third input
|
||||
if call_count == 3:
|
||||
raise Exception("Database connection lost!")
|
||||
|
||||
inserted_inputs.append(input_name)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.node_exec_id = "exec-id"
|
||||
return mock_result, {input_name: input_data}
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "multi_input_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"input1": "value1",
|
||||
"input2": "value2",
|
||||
"input3": "value3", # This one will fail
|
||||
"input4": "value4",
|
||||
"input5": "value5",
|
||||
})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multi_input_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {
|
||||
"input1": "input1",
|
||||
"input2": "input2",
|
||||
"input3": "input3",
|
||||
"input4": "input4",
|
||||
"input5": "input5",
|
||||
},
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"input1": {"type": "string"},
|
||||
"input2": {"type": "string"},
|
||||
"input3": {"type": "string"},
|
||||
"input4": {"type": "string"},
|
||||
"input5": {"type": "string"},
|
||||
},
|
||||
"required": ["input1", "input2", "input3", "input4", "input5"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_db_client.upsert_execution_input.side_effect = failing_upsert
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm, \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=1,
|
||||
)
|
||||
|
||||
# The block should fail, but some inputs were already inserted
|
||||
outputs = {}
|
||||
try:
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
except Exception:
|
||||
pass # Expected
|
||||
|
||||
# BUG: Some inputs were inserted before failure
|
||||
# Database is now in inconsistent state
|
||||
assert len(inserted_inputs) == 2, \
|
||||
f"Expected 2 inserted before failure, got {inserted_inputs}"
|
||||
assert "input1" in inserted_inputs
|
||||
assert "input2" in inserted_inputs
|
||||
# input3, input4, input5 were never inserted
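
# --- Illustrative sketch (not part of the test suite) ---
# One way to avoid the partial-write state exercised above is to perform all of the
# upserts inside a single transaction so that a mid-sequence failure rolls everything
# back. `db.transaction()` is a hypothetical async context manager here, NOT the real
# database client API; the actual fix would depend on what the client exposes.
async def upsert_inputs_atomically(db, node_id: str, graph_exec_id: str, inputs: dict):
    async with db.transaction():  # assumed API, for illustration only
        for input_name, input_data in inputs.items():
            await db.upsert_execution_input(node_id, graph_exec_id, input_name, input_data)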
|
||||
|
||||
|
||||
class TestMissingNullChecks:
|
||||
"""
|
||||
Tests for Failure Mode #5: Missing Null Checks After Database Calls
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_node_returns_none(self):
|
||||
"""
|
||||
Test handling when get_node returns None.
|
||||
"""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"param": "value"})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "nonexistent-node",
|
||||
"_field_mapping": {"param": "param"},
|
||||
"parameters": {
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_db_client.get_node.return_value = None # Node doesn't exist!
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", new_callable=AsyncMock) as mock_llm, \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=1,
|
||||
)
|
||||
|
||||
# Should raise ValueError for missing node
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_execution_outputs(self):
|
||||
"""
|
||||
Test handling when get_execution_outputs_by_node_exec_id returns empty.
|
||||
"""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
if call_count > 1:
|
||||
resp = MagicMock()
|
||||
resp.response = "Done"
|
||||
resp.tool_calls = []
|
||||
resp.prompt_tokens = 10
|
||||
resp.completion_tokens = 5
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": "Done"}
|
||||
return resp
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [mock_tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.node_exec_id = "exec-id"
|
||||
mock_db_client.upsert_execution_input.return_value = (mock_exec_result, {})
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {} # Empty!
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
mock_execution_processor.on_node_execution = AsyncMock(return_value=MagicMock(error=None))
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=2,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Empty outputs should be handled gracefully
|
||||
# (uses "Tool executed successfully" as fallback)
|
||||
assert "finished" in outputs or "conversations" in outputs
|
||||
|
||||
|
||||
class TestErrorMessageContextLoss:
|
||||
"""
|
||||
Tests for Failure Mode #15: Error Message Context Loss
|
||||
|
||||
When exceptions are caught and converted to strings, important
|
||||
debugging information is lost.
|
||||
"""
|
||||
|
||||
def test_exception_to_string_loses_traceback(self):
|
||||
"""
|
||||
Test that converting exception to string loses traceback.
|
||||
"""
|
||||
try:
|
||||
def inner():
|
||||
raise ValueError("Inner error")
|
||||
|
||||
def outer():
|
||||
inner()
|
||||
|
||||
outer()
|
||||
except Exception as e:
|
||||
error_string = str(e)
|
||||
error_repr = repr(e)
|
||||
|
||||
# String representation loses call stack
|
||||
assert "inner" not in error_string
|
||||
assert "outer" not in error_string
|
||||
|
||||
# Even repr doesn't have full traceback
|
||||
assert "Traceback" not in error_repr
|
||||
|
||||
def test_tool_response_loses_exception_type(self):
|
||||
"""
|
||||
Test that _create_tool_response loses exception type information.
|
||||
"""
|
||||
original_error = ConnectionError("Database unreachable")
|
||||
tool_response = _create_tool_response(
|
||||
"call_123",
|
||||
f"Tool execution failed: {str(original_error)}"
|
||||
)
|
||||
|
||||
content = tool_response.get("content", "")
|
||||
|
||||
# Original exception type is lost
|
||||
assert "ConnectionError" not in content
|
||||
# Only the message remains
|
||||
assert "Database unreachable" in content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_mode_error_response_lacks_context(self):
|
||||
"""
|
||||
Test that agent mode error responses lack debugging context.
|
||||
"""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({})
|
||||
|
||||
mock_response_1 = MagicMock()
|
||||
mock_response_1.response = None
|
||||
mock_response_1.tool_calls = [mock_tool_call]
|
||||
mock_response_1.prompt_tokens = 50
|
||||
mock_response_1.completion_tokens = 25
|
||||
mock_response_1.reasoning = None
|
||||
mock_response_1.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_use", "id": "call_1"}]
|
||||
}
|
||||
|
||||
mock_response_2 = MagicMock()
|
||||
mock_response_2.response = "Handled the error"
|
||||
mock_response_2.tool_calls = []
|
||||
mock_response_2.prompt_tokens = 30
|
||||
mock_response_2.completion_tokens = 15
|
||||
mock_response_2.reasoning = None
|
||||
mock_response_2.raw_response = {"role": "assistant", "content": "Handled"}
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return mock_response_1
|
||||
return mock_response_2
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {},
|
||||
"parameters": {"properties": {}, "required": []},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Create a complex error with nested cause
|
||||
class CustomDatabaseError(Exception):
|
||||
pass
|
||||
|
||||
def create_complex_error():
|
||||
try:
|
||||
raise ConnectionError("Network timeout after 30s")
|
||||
except ConnectionError as e:
|
||||
raise CustomDatabaseError("Failed to connect to database") from e
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
|
||||
# Make upsert raise the complex error
|
||||
try:
|
||||
create_complex_error()
|
||||
except CustomDatabaseError as e:
|
||||
mock_db_client.upsert_execution_input.side_effect = e
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures), \
|
||||
patch("backend.blocks.smart_decision_maker.get_database_manager_async_client", return_value=mock_db_client):
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=2,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Check conversation for error details
|
||||
conversations = outputs.get("conversations", [])
|
||||
error_found = False
|
||||
for msg in conversations:
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("type") == "tool_result":
|
||||
result_content = item.get("content", "")
|
||||
if "Error" in result_content or "failed" in result_content.lower():
|
||||
error_found = True
|
||||
# BUG: The error content lacks:
|
||||
# - Exception type (CustomDatabaseError)
|
||||
# - Chained cause (ConnectionError)
|
||||
# - Stack trace
|
||||
assert "CustomDatabaseError" not in result_content
|
||||
assert "ConnectionError" not in result_content
|
||||
|
||||
# Note: error_found may be False if the error prevented tool response creation
|
||||
|
||||
|
||||
class TestRawResponseConversion:
|
||||
"""Tests for _convert_raw_response_to_dict edge cases."""
|
||||
|
||||
def test_string_response_converted(self):
|
||||
"""Test that string responses are properly wrapped."""
|
||||
result = _convert_raw_response_to_dict("Hello, world!")
|
||||
assert result == {"role": "assistant", "content": "Hello, world!"}
|
||||
|
||||
def test_dict_response_unchanged(self):
|
||||
"""Test that dict responses are passed through."""
|
||||
original = {"role": "assistant", "content": "test", "extra": "field"}
|
||||
result = _convert_raw_response_to_dict(original)
|
||||
assert result == original
|
||||
|
||||
def test_object_response_converted(self):
|
||||
"""Test that objects are converted using json.to_dict."""
|
||||
mock_obj = MagicMock()
|
||||
|
||||
with patch("backend.blocks.smart_decision_maker.json.to_dict") as mock_to_dict:
|
||||
mock_to_dict.return_value = {"converted": True}
|
||||
result = _convert_raw_response_to_dict(mock_obj)
|
||||
mock_to_dict.assert_called_once_with(mock_obj)
|
||||
assert result == {"converted": True}
|
||||
|
||||
def test_none_response(self):
|
||||
"""Test handling of None response."""
|
||||
with patch("backend.blocks.smart_decision_maker.json.to_dict") as mock_to_dict:
|
||||
mock_to_dict.return_value = None
|
||||
result = _convert_raw_response_to_dict(None)
|
||||
# None is not a string or dict, so it goes through to_dict
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestValidationRetryMechanism:
|
||||
"""Tests for the validation and retry mechanism."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_error_triggers_retry(self):
|
||||
"""
|
||||
Test that validation errors trigger retry with feedback.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
prompt = kwargs.get("prompt", [])
|
||||
|
||||
if call_count == 1:
|
||||
# First call: return tool call with wrong parameter
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"wrong_param": "value"})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [mock_tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": None}
|
||||
return resp
|
||||
else:
|
||||
# Second call: check that error feedback was added
|
||||
has_error_feedback = any(
|
||||
"parameter errors" in str(msg.get("content", "")).lower()
|
||||
for msg in prompt
|
||||
)
|
||||
|
||||
# Return correct tool call
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"correct_param": "value"})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [mock_tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": None}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"correct_param": "correct_param"},
|
||||
"parameters": {
|
||||
"properties": {"correct_param": {"type": "string"}},
|
||||
"required": ["correct_param"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0, # Traditional mode
|
||||
retry=3,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for name, value in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[name] = value
|
||||
|
||||
# Should have made multiple calls due to retry
|
||||
assert call_count >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_retries_exceeded(self):
|
||||
"""
|
||||
Test behavior when max retries are exceeded.
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
async def mock_llm_call(**kwargs):
|
||||
# Always return invalid tool call
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "test_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({"wrong": "param"})
|
||||
|
||||
resp = MagicMock()
|
||||
resp.response = None
|
||||
resp.tool_calls = [mock_tool_call]
|
||||
resp.prompt_tokens = 50
|
||||
resp.completion_tokens = 25
|
||||
resp.reasoning = None
|
||||
resp.raw_response = {"role": "assistant", "content": None}
|
||||
return resp
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"_sink_node_id": "sink",
|
||||
"_field_mapping": {"correct": "correct"},
|
||||
"parameters": {
|
||||
"properties": {"correct": {"type": "string"}},
|
||||
"required": ["correct"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", side_effect=mock_llm_call), \
|
||||
patch.object(block, "_create_tool_node_signatures", return_value=mock_tool_signatures):
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
retry=2, # Only 2 retries
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
# Should raise ValueError after max retries
|
||||
with pytest.raises(ValueError, match="parameter errors"):
|
||||
async for _ in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph",
|
||||
node_id="test-node",
|
||||
graph_exec_id="test-exec",
|
||||
node_exec_id="test-node-exec",
|
||||
user_id="test-user",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
pass
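

# --- Illustrative sketch (not part of the test suite) ---
# The retry behaviour exercised above boils down to: call the LLM, validate the proposed
# tool arguments, and on validation errors append the error text as feedback and try again,
# up to `retry` times. Every name below is a hypothetical stand-in for illustration.
async def call_with_validation_retries(call_llm, validate, conversation: list, retry: int):
    last_errors = None
    for _ in range(retry):
        response = await call_llm(conversation)
        last_errors = validate(response)
        if not last_errors:
            return response
        conversation = conversation + [
            {
                "role": "user",
                "content": f"The tool call had parameter errors: {last_errors}. Please retry.",
            }
        ]
    raise ValueError(
        f"Tool call still had parameter errors after {retry} attempts: {last_errors}"
    )
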
@@ -0,0 +1,819 @@
"""
Comprehensive tests for SmartDecisionMakerBlock pin name sanitization.

This test file addresses the critical bug where field names with spaces/special characters
(e.g., "Max Keyword Difficulty") are not consistently sanitized between frontend and backend,
causing tool calls to "go into the void".

The core issue:
- Frontend connects link with original name: tools_^_{node_id}_~_Max Keyword Difficulty
- Backend emits with sanitized name: tools_^_{node_id}_~_max_keyword_difficulty
- parse_execution_output compares sink_pin_name directly without sanitization
- Result: mismatch causes tool calls to fail silently
"""

import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.dynamic_fields import (
    parse_execution_output,
    sanitize_pin_name,
)

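
# --- Illustrative sketch (not part of the test suite) ---
# The mismatch described in the module docstring disappears if both sides of the comparison
# are sanitized before routing. `pins_match` is a hypothetical helper, not the current
# parse_execution_output behaviour, and it assumes sanitize_pin_name lowercases names and
# replaces special characters with underscores.
def pins_match(emitted_pin_name: str, link_sink_pin_name: str) -> bool:
    return sanitize_pin_name(emitted_pin_name) == sanitize_pin_name(link_sink_pin_name)

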
class TestCleanupFunction:
|
||||
"""Tests for the SmartDecisionMakerBlock.cleanup() static method."""
|
||||
|
||||
def test_cleanup_spaces_to_underscores(self):
|
||||
"""Spaces should be replaced with underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"
|
||||
|
||||
def test_cleanup_mixed_case_to_lowercase(self):
|
||||
"""Mixed case should be converted to lowercase."""
|
||||
assert SmartDecisionMakerBlock.cleanup("MaxKeywordDifficulty") == "maxkeyworddifficulty"
|
||||
assert SmartDecisionMakerBlock.cleanup("UPPER_CASE") == "upper_case"
|
||||
|
||||
def test_cleanup_special_characters(self):
|
||||
"""Special characters should be replaced with underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("field@name!") == "field_name_"
|
||||
assert SmartDecisionMakerBlock.cleanup("value#1") == "value_1"
|
||||
assert SmartDecisionMakerBlock.cleanup("test$value") == "test_value"
|
||||
assert SmartDecisionMakerBlock.cleanup("a%b^c") == "a_b_c"
|
||||
|
||||
def test_cleanup_preserves_valid_characters(self):
|
||||
"""Valid characters (alphanumeric, underscore, hyphen) should be preserved."""
|
||||
assert SmartDecisionMakerBlock.cleanup("valid_name-123") == "valid_name-123"
|
||||
assert SmartDecisionMakerBlock.cleanup("abc123") == "abc123"
|
||||
|
||||
def test_cleanup_empty_string(self):
|
||||
"""Empty string should return empty string."""
|
||||
assert SmartDecisionMakerBlock.cleanup("") == ""
|
||||
|
||||
def test_cleanup_only_special_chars(self):
|
||||
"""String of only special characters should return underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("@#$%") == "____"
|
||||
|
||||
def test_cleanup_unicode_characters(self):
|
||||
"""Unicode characters should be replaced with underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("café") == "caf_"
|
||||
assert SmartDecisionMakerBlock.cleanup("日本語") == "___"
|
||||
|
||||
def test_cleanup_multiple_consecutive_spaces(self):
|
||||
"""Multiple consecutive spaces should become multiple underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup("a b") == "a___b"
|
||||
|
||||
def test_cleanup_leading_trailing_spaces(self):
|
||||
"""Leading/trailing spaces should become underscores."""
|
||||
assert SmartDecisionMakerBlock.cleanup(" name ") == "_name_"
|
||||
|
||||
def test_cleanup_realistic_field_names(self):
|
||||
"""Test realistic field names from actual use cases."""
|
||||
# From the reported bug
|
||||
assert SmartDecisionMakerBlock.cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"
|
||||
# Other realistic names
|
||||
assert SmartDecisionMakerBlock.cleanup("Search Query") == "search_query"
|
||||
assert SmartDecisionMakerBlock.cleanup("API Response (JSON)") == "api_response__json_"
|
||||
assert SmartDecisionMakerBlock.cleanup("User's Input") == "user_s_input"
|
||||
|
||||
|
||||
class TestFieldMappingCreation:
|
||||
"""Tests for field mapping creation in function signatures."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_mapping_with_spaces_in_names(self):
|
||||
"""Test that field mapping correctly maps clean names back to original names with spaces."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test description"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": ["Max Keyword Difficulty"]}
|
||||
)
|
||||
|
||||
def get_field_schema(field_name):
|
||||
if field_name == "Max Keyword Difficulty":
|
||||
return {"type": "integer", "description": "Maximum keyword difficulty (0-100)"}
|
||||
raise KeyError(f"Field {field_name} not found")
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
|
||||
mock_links = [
|
||||
Mock(
|
||||
source_name="tools_^_test_~_max_keyword_difficulty",
|
||||
sink_name="Max Keyword Difficulty", # Original name with spaces
|
||||
sink_id="test-node-id",
|
||||
source_id="smart_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
# Verify the cleaned name is used in properties
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
assert "max_keyword_difficulty" in properties
|
||||
|
||||
# Verify the field mapping maps back to original
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
assert field_mapping["max_keyword_difficulty"] == "Max Keyword Difficulty"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_mapping_with_multiple_special_char_names(self):
|
||||
"""Test field mapping with multiple fields containing special characters."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "SEO Tool"
|
||||
mock_node.block.description = "SEO analysis tool"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
|
||||
def get_field_schema(field_name):
|
||||
schemas = {
|
||||
"Max Keyword Difficulty": {"type": "integer", "description": "Max difficulty"},
|
||||
"Search Volume (Monthly)": {"type": "integer", "description": "Monthly volume"},
|
||||
"CPC ($)": {"type": "number", "description": "Cost per click"},
|
||||
"Target URL": {"type": "string", "description": "URL to analyze"},
|
||||
}
|
||||
if field_name in schemas:
|
||||
return schemas[field_name]
|
||||
raise KeyError(f"Field {field_name} not found")
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="Max Keyword Difficulty", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="Search Volume (Monthly)", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="CPC ($)", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="Target URL", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
|
||||
# Verify all cleaned names are in properties
|
||||
assert "max_keyword_difficulty" in properties
|
||||
assert "search_volume__monthly_" in properties
|
||||
assert "cpc____" in properties
|
||||
assert "target_url" in properties
|
||||
|
||||
# Verify field mappings
|
||||
assert field_mapping["max_keyword_difficulty"] == "Max Keyword Difficulty"
|
||||
assert field_mapping["search_volume__monthly_"] == "Search Volume (Monthly)"
|
||||
assert field_mapping["cpc____"] == "CPC ($)"
|
||||
assert field_mapping["target_url"] == "Target URL"
|
||||
|
||||
|
||||
class TestFieldNameCollision:
|
||||
"""Tests for detecting field name collisions after sanitization."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collision_detection_same_sanitized_name(self):
|
||||
"""Test behavior when two different names sanitize to the same value."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
# These two different names will sanitize to the same value
|
||||
name1 = "max keyword difficulty" # -> max_keyword_difficulty
|
||||
name2 = "Max Keyword Difficulty" # -> max_keyword_difficulty
|
||||
name3 = "MAX_KEYWORD_DIFFICULTY" # -> max_keyword_difficulty
|
||||
|
||||
assert SmartDecisionMakerBlock.cleanup(name1) == SmartDecisionMakerBlock.cleanup(name2)
|
||||
assert SmartDecisionMakerBlock.cleanup(name2) == SmartDecisionMakerBlock.cleanup(name3)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collision_in_function_signature(self):
|
||||
"""Test that collisions in sanitized names could cause issues."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test description"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": []}
|
||||
)
|
||||
|
||||
def get_field_schema(field_name):
|
||||
return {"type": "string", "description": f"Field: {field_name}"}
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
|
||||
# Two different fields that sanitize to the same name
|
||||
mock_links = [
|
||||
Mock(sink_name="Test Field", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="test field", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
|
||||
# Both sanitize to "test_field" - only one will be in properties
|
||||
assert "test_field" in properties
|
||||
# The field_mapping will have the last one written
|
||||
assert field_mapping["test_field"] in ["Test Field", "test field"]
|
||||
|
||||
|
||||
class TestOutputRouting:
|
||||
"""Tests for output routing with sanitized names."""
|
||||
|
||||
def test_emit_key_format_with_spaces(self):
|
||||
"""Test that emit keys use sanitized field names."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
original_field_name = "Max Keyword Difficulty"
|
||||
sink_node_id = "node-123"
|
||||
|
||||
sanitized_name = block.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_name}"
|
||||
|
||||
assert emit_key == "tools_^_node-123_~_max_keyword_difficulty"
|
||||
|
||||
def test_parse_execution_output_exact_match(self):
|
||||
"""Test parse_execution_output with exact matching names."""
|
||||
output_item = ("tools_^_node-123_~_max_keyword_difficulty", 50)
|
||||
|
||||
# When sink_pin_name matches the sanitized name, it should work
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="max_keyword_difficulty",
|
||||
)
|
||||
assert result == 50
|
||||
|
||||
def test_parse_execution_output_mismatch_original_vs_sanitized(self):
|
||||
"""
|
||||
CRITICAL TEST: This reproduces the exact bug reported.
|
||||
|
||||
When frontend creates a link with original name "Max Keyword Difficulty"
|
||||
but backend emits with sanitized name "max_keyword_difficulty",
|
||||
the tool call should still be routed correctly.
|
||||
|
||||
CURRENT BEHAVIOR (BUG): Returns None because names don't match
|
||||
EXPECTED BEHAVIOR: Should return the value (50) after sanitizing both names
|
||||
"""
|
||||
output_item = ("tools_^_node-123_~_max_keyword_difficulty", 50)
|
||||
|
||||
# This is what happens: sink_pin_name comes from frontend link (unsanitized)
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="Max Keyword Difficulty", # Original name with spaces
|
||||
)
|
||||
|
||||
# BUG: This currently returns None because:
|
||||
# - target_input_pin = "max_keyword_difficulty" (from emit key, sanitized)
|
||||
# - sink_pin_name = "Max Keyword Difficulty" (from link, original)
|
||||
# - They don't match, so routing fails
|
||||
#
|
||||
# TODO: When the bug is fixed, change this assertion to:
|
||||
# assert result == 50
|
||||
assert result is None # Current buggy behavior
|
||||
|
||||
def test_parse_execution_output_with_sanitized_sink_pin(self):
|
||||
"""Test that if sink_pin_name is pre-sanitized, routing works."""
|
||||
output_item = ("tools_^_node-123_~_max_keyword_difficulty", 50)
|
||||
|
||||
# If sink_pin_name is already sanitized, routing works
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="max_keyword_difficulty", # Pre-sanitized
|
||||
)
|
||||
assert result == 50
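
# --- Illustrative sketch (not part of the test suite) ---
# Until routing sanitizes names on both sides, a caller can normalise the link's sink pin
# before handing it to parse_execution_output. `resolve_tool_output` is hypothetical and
# assumes sanitize_pin_name matches SmartDecisionMakerBlock.cleanup for these names.
def resolve_tool_output(output_item, sink_node_id: str, link_sink_pin_name: str):
    return parse_execution_output(
        output_item,
        link_output_selector="tools",
        sink_node_id=sink_node_id,
        sink_pin_name=sanitize_pin_name(link_sink_pin_name),
    )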
|
||||
|
||||
|
||||
class TestProcessToolCallsMapping:
|
||||
"""Tests for _process_tool_calls method field mapping."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_tool_calls_maps_clean_to_original(self):
|
||||
"""Test that _process_tool_calls correctly maps clean names back to original."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_response = Mock()
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.function.name = "seo_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"max_keyword_difficulty": 50, # LLM uses clean name
|
||||
"search_query": "test query",
|
||||
})
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seo_tool",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"max_keyword_difficulty": {"type": "integer"},
|
||||
"search_query": {"type": "string"},
|
||||
},
|
||||
"required": ["max_keyword_difficulty", "search_query"],
|
||||
},
|
||||
"_sink_node_id": "test-sink-node",
|
||||
"_field_mapping": {
|
||||
"max_keyword_difficulty": "Max Keyword Difficulty", # Original name
|
||||
"search_query": "Search Query",
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
|
||||
assert len(processed) == 1
|
||||
tool_info = processed[0]
|
||||
|
||||
# Verify input_data uses ORIGINAL field names
|
||||
assert "Max Keyword Difficulty" in tool_info.input_data
|
||||
assert "Search Query" in tool_info.input_data
|
||||
assert tool_info.input_data["Max Keyword Difficulty"] == 50
|
||||
assert tool_info.input_data["Search Query"] == "test query"
|
||||
|
||||
|
||||
class TestToolOutputEmitting:
|
||||
"""Tests for the tool output emitting in traditional mode."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_keys_use_sanitized_names(self):
|
||||
"""Test that emit keys always use sanitized field names."""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "seo_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
"max_keyword_difficulty": 50,
|
||||
})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seo_tool",
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
"_field_mapping": {
|
||||
"max_keyword_difficulty": "Max Keyword Difficulty",
|
||||
},
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"max_keyword_difficulty": {"type": "integer"},
|
||||
},
|
||||
"required": ["max_keyword_difficulty"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.blocks.llm.llm_call",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
):
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Test prompt",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=0,
|
||||
)
|
||||
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# The emit key should use the sanitized field name
|
||||
# Even though the original was "Max Keyword Difficulty", emit uses sanitized
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_max_keyword_difficulty"] == 50
|
||||
|
||||
|
||||
class TestSanitizationConsistency:
|
||||
"""Tests for ensuring sanitization is consistent throughout the pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_round_trip_with_spaces(self):
|
||||
"""
|
||||
Test the full round-trip of a field name with spaces through the system.
|
||||
|
||||
This simulates:
|
||||
1. Frontend creates link with sink_name="Max Keyword Difficulty"
|
||||
2. Backend creates function signature with cleaned property name
|
||||
3. LLM responds with cleaned name
|
||||
4. Backend processes response and maps back to original
|
||||
5. Backend emits with sanitized name
|
||||
6. Routing should match (currently broken)
|
||||
"""
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
original_field_name = "Max Keyword Difficulty"
|
||||
cleaned_field_name = SmartDecisionMakerBlock.cleanup(original_field_name)
|
||||
|
||||
# Step 1: Simulate frontend link creation
|
||||
mock_link = Mock()
|
||||
mock_link.sink_name = original_field_name # Frontend uses original
|
||||
mock_link.sink_id = "test-sink-node-id"
|
||||
mock_link.source_id = "smart-node-id"
|
||||
|
||||
# Step 2: Create function signature
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-sink-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "SEO Tool"
|
||||
mock_node.block.description = "SEO analysis"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": [original_field_name]}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
return_value={"type": "integer", "description": "Max difficulty"}
|
||||
)
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, [mock_link])
|
||||
|
||||
# Verify cleaned name is in properties
|
||||
assert cleaned_field_name in signature["function"]["parameters"]["properties"]
|
||||
# Verify field mapping exists
|
||||
assert signature["function"]["_field_mapping"][cleaned_field_name] == original_field_name
|
||||
|
||||
# Step 3: Simulate LLM response using cleaned name
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function.name = "seo_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
cleaned_field_name: 50 # LLM uses cleaned name
|
||||
})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response = None
|
||||
mock_response.tool_calls = [mock_tool_call]
|
||||
mock_response.prompt_tokens = 50
|
||||
mock_response.completion_tokens = 25
|
||||
mock_response.reasoning = None
|
||||
mock_response.raw_response = {"role": "assistant", "content": None}
|
||||
|
||||
# Prepare tool_functions as they would be in run()
|
||||
tool_functions = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seo_tool",
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
"_field_mapping": signature["function"]["_field_mapping"],
|
||||
"parameters": signature["function"]["parameters"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Step 4: Process tool calls
|
||||
processed = block._process_tool_calls(mock_response, tool_functions)
|
||||
assert len(processed) == 1
|
||||
# Input data should have ORIGINAL name
|
||||
assert original_field_name in processed[0].input_data
|
||||
assert processed[0].input_data[original_field_name] == 50
|
||||
|
||||
# Step 5: Emit key generation (from run method logic)
|
||||
field_mapping = processed[0].field_mapping
|
||||
for clean_arg_name in signature["function"]["parameters"]["properties"]:
|
||||
original = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
sanitized_arg_name = block.cleanup(original)
|
||||
emit_key = f"tools_^_test-sink-node-id_~_{sanitized_arg_name}"
|
||||
|
||||
# Emit key uses sanitized name
|
||||
assert emit_key == f"tools_^_test-sink-node-id_~_{cleaned_field_name}"
|
||||
|
||||
# Step 6: Routing check (this is where the bug manifests)
|
||||
emit_key = f"tools_^_test-sink-node-id_~_{cleaned_field_name}"
|
||||
output_item = (emit_key, 50)
|
||||
|
||||
# Current routing uses original sink_name from link
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="test-sink-node-id",
|
||||
sink_pin_name=original_field_name, # Frontend's original name
|
||||
)
|
||||
|
||||
# BUG: This returns None because sanitized != original
|
||||
# When fixed, this should return 50
|
||||
assert result is None # Current broken behavior
|
||||
|
||||
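# Illustrative sketch, not from the patch above: a minimal, self-contained
# model of the emit-key scheme exercised by the pipeline test, assuming the
# "tools_^_<sink_node_id>_~_<pin>" layout shown in the assertions and a
# _cleanup() that mirrors SmartDecisionMakerBlock.cleanup(). Names here are
# local stand-ins, not the block's real helpers.
import re


def _cleanup(name: str) -> str:
    # Same character class as the block's cleanup: keep [a-zA-Z0-9_-], lowercase.
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()


def _make_emit_key(sink_node_id: str, pin_name: str) -> str:
    # The block emits with the sanitized pin name baked into the key.
    return f"tools_^_{sink_node_id}_~_{_cleanup(pin_name)}"


def _split_emit_key(emit_key: str) -> tuple[str, str]:
    # Strip the "tools_^_" prefix, then split node id from pin name.
    node_id, pin = emit_key[len("tools_^_"):].split("_~_", 1)
    return node_id, pin


if __name__ == "__main__":
    key = _make_emit_key("test-sink-node-id", "Max Keyword Difficulty")
    assert _split_emit_key(key) == ("test-sink-node-id", "max_keyword_difficulty")
    # An exact-match comparison against the link's original sink_name fails,
    # which is the routing bug documented above.
    assert _split_emit_key(key)[1] != "Max Keyword Difficulty"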
def test_sanitization_is_idempotent(self):
"""Test that sanitizing an already sanitized name gives the same result."""
original = "Max Keyword Difficulty"
first_clean = SmartDecisionMakerBlock.cleanup(original)
second_clean = SmartDecisionMakerBlock.cleanup(first_clean)

assert first_clean == second_clean
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Tests for edge cases in the sanitization pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_field_name(self):
|
||||
"""Test handling of empty field name."""
|
||||
assert SmartDecisionMakerBlock.cleanup("") == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_very_long_field_name(self):
|
||||
"""Test handling of very long field names."""
|
||||
long_name = "A" * 1000 + " " + "B" * 1000
|
||||
cleaned = SmartDecisionMakerBlock.cleanup(long_name)
|
||||
assert "_" in cleaned # Space was replaced
|
||||
assert len(cleaned) == len(long_name)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_name_with_newlines(self):
|
||||
"""Test handling of field names with newlines."""
|
||||
name_with_newline = "First Line\nSecond Line"
|
||||
cleaned = SmartDecisionMakerBlock.cleanup(name_with_newline)
|
||||
assert "\n" not in cleaned
|
||||
assert "_" in cleaned
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field_name_with_tabs(self):
|
||||
"""Test handling of field names with tabs."""
|
||||
name_with_tab = "First\tSecond"
|
||||
cleaned = SmartDecisionMakerBlock.cleanup(name_with_tab)
|
||||
assert "\t" not in cleaned
|
||||
assert "_" in cleaned
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_numeric_field_name(self):
|
||||
"""Test handling of purely numeric field names."""
|
||||
assert SmartDecisionMakerBlock.cleanup("123") == "123"
|
||||
assert SmartDecisionMakerBlock.cleanup("123 456") == "123_456"
|
||||
|
||||
@pytest.mark.asyncio
async def test_hyphenated_field_names(self):
"""Test that hyphens are preserved (they are allowed in tool function names)."""
assert SmartDecisionMakerBlock.cleanup("field-name") == "field-name"
assert SmartDecisionMakerBlock.cleanup("Field-Name") == "field-name"
|
||||
|
||||
|
||||
class TestDynamicFieldsWithSpaces:
|
||||
"""Tests for dynamic fields with spaces in their names."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_dict_field_with_spaces(self):
|
||||
"""Test dynamic dictionary fields where the key contains spaces."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "CreateDictionary"
|
||||
mock_node.block.description = "Creates a dictionary"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={"properties": {}, "required": ["values"]}
|
||||
)
|
||||
mock_node.block.input_schema.get_field_schema = Mock(
|
||||
side_effect=KeyError("not found")
|
||||
)
|
||||
|
||||
# Dynamic field with a key containing spaces
|
||||
mock_links = [
|
||||
Mock(
|
||||
sink_name="values_#_User Name", # Dict key with space
|
||||
sink_id="test-node-id",
|
||||
source_id="smart_node_id",
|
||||
),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
properties = signature["function"]["parameters"]["properties"]
|
||||
field_mapping = signature["function"]["_field_mapping"]
|
||||
|
||||
# The cleaned name should be in properties
|
||||
expected_clean = SmartDecisionMakerBlock.cleanup("values_#_User Name")
|
||||
assert expected_clean in properties
|
||||
|
||||
# Field mapping should map back to original
|
||||
assert field_mapping[expected_clean] == "values_#_User Name"
|
||||
|
||||
|
||||
class TestAgentModeWithSpaces:
|
||||
"""Tests for agent mode with field names containing spaces."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_mode_tool_execution_with_spaces(self):
|
||||
"""Test that agent mode correctly handles field names with spaces."""
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
import backend.blocks.llm as llm_module
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
original_field = "Max Keyword Difficulty"
|
||||
clean_field = SmartDecisionMakerBlock.cleanup(original_field)
|
||||
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.id = "call_1"
|
||||
mock_tool_call.function.name = "seo_tool"
|
||||
mock_tool_call.function.arguments = json.dumps({
|
||||
clean_field: 50 # LLM uses clean name
|
||||
})
|
||||
|
||||
mock_response_1 = MagicMock()
|
||||
mock_response_1.response = None
|
||||
mock_response_1.tool_calls = [mock_tool_call]
|
||||
mock_response_1.prompt_tokens = 50
|
||||
mock_response_1.completion_tokens = 25
|
||||
mock_response_1.reasoning = None
|
||||
mock_response_1.raw_response = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_1", "type": "function"}],
|
||||
}
|
||||
|
||||
mock_response_2 = MagicMock()
|
||||
mock_response_2.response = "Task completed"
|
||||
mock_response_2.tool_calls = []
|
||||
mock_response_2.prompt_tokens = 30
|
||||
mock_response_2.completion_tokens = 15
|
||||
mock_response_2.reasoning = None
|
||||
mock_response_2.raw_response = {"role": "assistant", "content": "Task completed"}
|
||||
|
||||
llm_call_mock = AsyncMock()
|
||||
llm_call_mock.side_effect = [mock_response_1, mock_response_2]
|
||||
|
||||
mock_tool_signatures = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seo_tool",
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
"_field_mapping": {
|
||||
clean_field: original_field,
|
||||
},
|
||||
"parameters": {
|
||||
"properties": {
|
||||
clean_field: {"type": "integer"},
|
||||
},
|
||||
"required": [clean_field],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
mock_db_client = AsyncMock()
|
||||
mock_node = MagicMock()
|
||||
mock_node.block_id = "test-block-id"
|
||||
mock_db_client.get_node.return_value = mock_node
|
||||
|
||||
mock_node_exec_result = MagicMock()
|
||||
mock_node_exec_result.node_exec_id = "test-tool-exec-id"
|
||||
|
||||
# The input data should use ORIGINAL field name
|
||||
mock_input_data = {original_field: 50}
|
||||
mock_db_client.upsert_execution_input.return_value = (
|
||||
mock_node_exec_result,
|
||||
mock_input_data,
|
||||
)
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
|
||||
"result": {"status": "success"}
|
||||
}
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
|
||||
), patch(
|
||||
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
|
||||
return_value=mock_db_client,
|
||||
):
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = AsyncMock()
|
||||
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
|
||||
mock_execution_processor.execution_stats = MagicMock()
|
||||
mock_execution_processor.execution_stats_lock = threading.Lock()
|
||||
|
||||
mock_node_stats = MagicMock()
|
||||
mock_node_stats.error = None
|
||||
mock_execution_processor.on_node_execution = AsyncMock(
|
||||
return_value=mock_node_stats
|
||||
)
|
||||
|
||||
input_data = SmartDecisionMakerBlock.Input(
|
||||
prompt="Analyze keywords",
|
||||
model=llm_module.DEFAULT_LLM_MODEL,
|
||||
credentials=llm_module.TEST_CREDENTIALS_INPUT,
|
||||
agent_mode_max_iterations=3,
|
||||
)
|
||||
|
||||
outputs = {}
|
||||
async for output_name, output_data in block.run(
|
||||
input_data,
|
||||
credentials=llm_module.TEST_CREDENTIALS,
|
||||
graph_id="test-graph-id",
|
||||
node_id="test-node-id",
|
||||
graph_exec_id="test-exec-id",
|
||||
node_exec_id="test-node-exec-id",
|
||||
user_id="test-user-id",
|
||||
graph_version=1,
|
||||
execution_context=mock_execution_context,
|
||||
execution_processor=mock_execution_processor,
|
||||
):
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify upsert was called with original field name
|
||||
upsert_calls = mock_db_client.upsert_execution_input.call_args_list
|
||||
assert len(upsert_calls) > 0
|
||||
# Check that the original field name was used
|
||||
for call in upsert_calls:
|
||||
input_name = call.kwargs.get("input_name") or call.args[2]
|
||||
# The input name should be the original (mapped back)
|
||||
assert input_name == original_field
|
||||
|
||||
|
||||
class TestRequiredFieldsWithSpaces:
|
||||
"""Tests for required field handling with spaces in names."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_fields_use_clean_names(self):
|
||||
"""Test that required fields array uses clean names for API compatibility."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
mock_node = Mock()
|
||||
mock_node.id = "test-node-id"
|
||||
mock_node.block = Mock()
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "Test"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.block.input_schema.jsonschema = Mock(
|
||||
return_value={
|
||||
"properties": {},
|
||||
"required": ["Max Keyword Difficulty", "Search Query"],
|
||||
}
|
||||
)
|
||||
|
||||
def get_field_schema(field_name):
|
||||
return {"type": "string", "description": f"Field: {field_name}"}
|
||||
|
||||
mock_node.block.input_schema.get_field_schema = get_field_schema
|
||||
|
||||
mock_links = [
|
||||
Mock(sink_name="Max Keyword Difficulty", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
Mock(sink_name="Search Query", sink_id="test-node-id", source_id="smart_node_id"),
|
||||
]
|
||||
|
||||
signature = await block._create_block_function_signature(mock_node, mock_links)
|
||||
|
||||
required = signature["function"]["parameters"]["required"]
|
||||
|
||||
# Required array should use CLEAN names for API compatibility
|
||||
assert "max_keyword_difficulty" in required
|
||||
assert "search_query" in required
|
||||
# Original names should NOT be in required
|
||||
assert "Max Keyword Difficulty" not in required
|
||||
assert "Search Query" not in required
|
||||
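# Illustrative sketch, not from the patch above: how the "_field_mapping"
# used in these tests is meant to round-trip names. The LLM only ever sees
# cleaned parameter names, and the mapping restores the original sink names
# before execution input is written. The helpers below are simplified
# stand-ins, not the block's actual implementation.
import re


def _cleanup(name: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()


def build_field_mapping(original_names: list[str]) -> dict[str, str]:
    # cleaned name -> original sink name
    return {_cleanup(name): name for name in original_names}


def restore_original_names(llm_arguments: dict, mapping: dict[str, str]) -> dict:
    # Map the LLM's cleaned argument names back to the frontend's originals.
    return {mapping.get(key, key): value for key, value in llm_arguments.items()}


if __name__ == "__main__":
    mapping = build_field_mapping(["Max Keyword Difficulty", "Search Query"])
    restored = restore_original_names({"max_keyword_difficulty": 50}, mapping)
    assert restored == {"Max Keyword Difficulty": 50}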
@@ -1,5 +1,5 @@
 from gravitasml.parser import Parser
-from gravitasml.token import tokenize
+from gravitasml.token import Token, tokenize

 from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
 from backend.data.model import SchemaField
|
||||
@@ -25,6 +25,38 @@ class XMLParserBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_tokens(tokens: list[Token]) -> None:
|
||||
"""Ensure the XML has a single root element and no stray text."""
|
||||
if not tokens:
|
||||
raise ValueError("XML input is empty.")
|
||||
|
||||
depth = 0
|
||||
root_seen = False
|
||||
|
||||
for token in tokens:
|
||||
if token.type == "TAG_OPEN":
|
||||
if depth == 0 and root_seen:
|
||||
raise ValueError("XML must have a single root element.")
|
||||
depth += 1
|
||||
if depth == 1:
|
||||
root_seen = True
|
||||
elif token.type == "TAG_CLOSE":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise SyntaxError("Unexpected closing tag in XML input.")
|
||||
elif token.type in {"TEXT", "ESCAPE"}:
|
||||
if depth == 0 and token.value:
|
||||
raise ValueError(
|
||||
"XML contains text outside the root element; "
|
||||
"wrap content in a single root tag."
|
||||
)
|
||||
|
||||
if depth != 0:
|
||||
raise SyntaxError("Unclosed tag detected in XML input.")
|
||||
if not root_seen:
|
||||
raise ValueError("XML must include a root element.")
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Security fix: Add size limits to prevent XML bomb attacks
|
||||
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
|
||||
@@ -35,7 +67,9 @@ class XMLParserBlock(Block):
|
||||
)
|
||||
|
||||
try:
|
||||
tokens = tokenize(input_data.input_xml)
|
||||
tokens = list(tokenize(input_data.input_xml))
|
||||
self._validate_tokens(tokens)
|
||||
|
||||
parser = Parser(tokens)
|
||||
parsed_result = parser.parse()
|
||||
yield "parsed_xml", parsed_result
|
||||
|
||||
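# Illustrative sketch, not from the patch above: the single-root check added
# in _validate_tokens, restated over plain (type, value) tuples so it can be
# run without gravitasml. The token type names ("TAG_OPEN", "TAG_CLOSE",
# "TEXT") follow the patch; the tuple shape is an assumption made only for
# this sketch.
def validate_token_stream(tokens: list[tuple[str, str]]) -> None:
    if not tokens:
        raise ValueError("XML input is empty.")
    depth = 0
    root_seen = False
    for token_type, value in tokens:
        if token_type == "TAG_OPEN":
            if depth == 0 and root_seen:
                raise ValueError("XML must have a single root element.")
            depth += 1
            if depth == 1:
                root_seen = True
        elif token_type == "TAG_CLOSE":
            depth -= 1
            if depth < 0:
                raise SyntaxError("Unexpected closing tag in XML input.")
        elif token_type == "TEXT" and depth == 0 and value:
            raise ValueError("XML contains text outside the root element.")
    if depth != 0:
        raise SyntaxError("Unclosed tag detected in XML input.")


if __name__ == "__main__":
    # One root element with text inside: accepted.
    validate_token_stream([("TAG_OPEN", "root"), ("TEXT", "hi"), ("TAG_CLOSE", "root")])
    try:
        # Two sibling roots: rejected.
        validate_token_stream(
            [("TAG_OPEN", "a"), ("TAG_CLOSE", "a"), ("TAG_OPEN", "b"), ("TAG_CLOSE", "b")]
        )
    except ValueError:
        pass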
@@ -111,6 +111,8 @@ class TranscribeYoutubeVideoBlock(Block):
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path[:3] == "/v/":
|
||||
return parsed_url.path.split("/")[2]
|
||||
if parsed_url.path.startswith("/shorts/"):
|
||||
return parsed_url.path.split("/")[2]
|
||||
raise ValueError(f"Invalid YouTube URL: {url}")
|
||||
|
||||
def get_transcript(
|
||||
|
||||
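# Illustrative sketch, not from the patch above: what the "/shorts/" branch
# added in TranscribeYoutubeVideoBlock does. A Shorts URL carries the video
# id as the second path segment, so the same split("/")[2] access works. The
# "/watch" and "/embed/" branches below are assumptions added to make this
# standalone helper runnable; only the "/v/" and "/shorts/" handling is shown
# in the diff itself.
from urllib.parse import parse_qs, urlparse


def extract_video_id(url: str) -> str:
    parsed = urlparse(url)
    if parsed.path == "/watch":
        return parse_qs(parsed.query)["v"][0]
    for prefix in ("/embed/", "/v/", "/shorts/"):
        if parsed.path.startswith(prefix):
            return parsed.path.split("/")[2]
    raise ValueError(f"Invalid YouTube URL: {url}")


if __name__ == "__main__":
    assert extract_video_id("https://www.youtube.com/shorts/abc123") == "abc123"
    assert extract_video_id("https://www.youtube.com/watch?v=abc123") == "abc123"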
@@ -50,6 +50,8 @@ from .model import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
from .graph import Link
|
||||
|
||||
app_config = Config()
|
||||
@@ -472,6 +474,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.requires_human_review: bool = False
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -614,7 +617,77 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
block_id=self.id,
|
||||
) from ex
|
||||
|
||||
async def is_block_exec_need_review(
|
||||
self,
|
||||
input_data: BlockInput,
|
||||
*,
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: "ExecutionContext",
|
||||
**kwargs,
|
||||
) -> tuple[bool, BlockInput]:
|
||||
"""
|
||||
Check if this block execution needs human review and handle the review process.
|
||||
|
||||
Returns:
|
||||
Tuple of (should_pause, input_data_to_use)
|
||||
- should_pause: True if execution should be paused for review
|
||||
- input_data_to_use: The input data to use (may be modified by reviewer)
|
||||
"""
|
||||
# Skip review if not required or safe mode is disabled
|
||||
if not self.requires_human_review or not execution_context.safe_mode:
|
||||
return False, input_data
|
||||
|
||||
from backend.blocks.helpers.review import HITLReviewHelper
|
||||
|
||||
# Handle the review request and get decision
|
||||
decision = await HITLReviewHelper.handle_review_decision(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
)
|
||||
|
||||
if decision is None:
|
||||
# We're awaiting review - pause execution
|
||||
return True, input_data
|
||||
|
||||
if not decision.should_proceed:
|
||||
# Review was rejected, raise an error to stop execution
|
||||
raise BlockExecutionError(
|
||||
message=f"Block execution rejected by reviewer: {decision.message}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Review was approved - use the potentially modified data
|
||||
# ReviewResult.data must be a dict for block inputs
|
||||
reviewed_data = decision.review_result.data
|
||||
if not isinstance(reviewed_data, dict):
|
||||
raise BlockExecutionError(
|
||||
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
return False, reviewed_data
|
||||
|
||||
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
# Check for review requirement and get potentially modified input data
|
||||
should_pause, input_data = await self.is_block_exec_need_review(
|
||||
input_data, **kwargs
|
||||
)
|
||||
if should_pause:
|
||||
return
|
||||
|
||||
# Validate the input data (original or reviewer-modified) once
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise BlockInputError(
|
||||
message=f"Unable to execute block with invalid input data: {error}",
|
||||
@@ -622,6 +695,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
# Use the validated input data
|
||||
async for output_name, output_data in self.run(
|
||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||
**kwargs,
|
||||
|
||||
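# Illustrative sketch, not from the patch above: the three outcomes of the
# review hook (still pending, rejected, approved with possibly edited input)
# restated as a tiny standalone function. "Decision" and "ReviewResult" are
# simplified stand-ins for the objects returned by
# HITLReviewHelper.handle_review_decision(); only the field names mirror the
# patch.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ReviewResult:
    data: Any


@dataclass
class Decision:
    should_proceed: bool
    message: str
    review_result: ReviewResult


def resolve_review(decision: Optional[Decision], input_data: dict) -> tuple[bool, dict]:
    """Return (should_pause, input_data_to_use), mirroring the patch's logic."""
    if decision is None:
        return True, input_data  # still awaiting review -> pause execution
    if not decision.should_proceed:
        raise RuntimeError(f"Rejected by reviewer: {decision.message}")
    reviewed = decision.review_result.data  # approved, possibly edited
    if not isinstance(reviewed, dict):
        raise RuntimeError("Review data must be a dict for block input")
    return False, reviewed


if __name__ == "__main__":
    assert resolve_review(None, {"x": 1}) == (True, {"x": 1})
    approved = Decision(True, "", ReviewResult({"x": 2}))
    assert resolve_review(approved, {"x": 1}) == (False, {"x": 2})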
@@ -59,12 +59,13 @@ from backend.integrations.credentials_store import (
|
||||
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3: 4,
|
||||
LlmModel.O3_MINI: 2, # $1.10 / $4.40
|
||||
LlmModel.O1: 16, # $15 / $60
|
||||
LlmModel.O3_MINI: 2,
|
||||
LlmModel.O1: 16,
|
||||
LlmModel.O1_MINI: 4,
|
||||
# GPT-5 models
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_2: 6,
|
||||
LlmModel.GPT5_1: 5,
|
||||
LlmModel.GPT5: 2,
|
||||
LlmModel.GPT5_MINI: 1,
|
||||
LlmModel.GPT5_NANO: 1,
|
||||
LlmModel.GPT5_CHAT: 5,
|
||||
@@ -87,7 +88,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.AIML_API_LLAMA3_3_70B: 1,
|
||||
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
|
||||
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
|
||||
LlmModel.LLAMA3_3_70B: 1,
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
|
||||
@@ -341,6 +341,19 @@ class UserCreditBase(ABC):
|
||||
|
||||
if result:
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if transaction.amount > 0 and transaction.type in [
|
||||
CreditTransactionType.GRANT,
|
||||
CreditTransactionType.TOP_UP,
|
||||
]:
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return result[0]["balance"]
|
||||
|
||||
async def _add_transaction(
|
||||
@@ -530,6 +543,22 @@ class UserCreditBase(ABC):
|
||||
if result:
|
||||
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
|
||||
# UserBalance is already updated by the CTE
|
||||
|
||||
# Clear insufficient funds notification flags when credits are added
|
||||
# so user can receive alerts again if they run out in the future.
|
||||
if (
|
||||
amount > 0
|
||||
and is_active
|
||||
and transaction_type
|
||||
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
|
||||
):
|
||||
# Lazy import to avoid circular dependency with executor.manager
|
||||
from backend.executor.manager import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
return new_balance, tx_key
|
||||
|
||||
# If no result, either user doesn't exist or insufficient balance
|
||||
|
||||
@@ -383,6 +383,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
self,
|
||||
execution_context: ExecutionContext,
|
||||
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
nodes_to_skip: Optional[set[str]] = None,
|
||||
):
|
||||
return GraphExecutionEntry(
|
||||
user_id=self.user_id,
|
||||
@@ -390,6 +391,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip or set(),
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
@@ -1145,6 +1147,8 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None
|
||||
nodes_to_skip: set[str] = Field(default_factory=set)
|
||||
"""Node IDs that should be skipped due to optional credentials not being configured."""
|
||||
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)
|
||||
|
||||
|
||||
|
||||
@@ -94,6 +94,15 @@ class Node(BaseDbModel):
|
||||
input_links: list[Link] = []
|
||||
output_links: list[Link] = []
|
||||
|
||||
@property
|
||||
def credentials_optional(self) -> bool:
|
||||
"""
|
||||
Whether credentials are optional for this node.
|
||||
When True and credentials are not configured, the node will be skipped
|
||||
during execution rather than causing a validation error.
|
||||
"""
|
||||
return self.metadata.get("credentials_optional", False)
|
||||
|
||||
@property
|
||||
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
|
||||
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
|
||||
@@ -235,7 +244,10 @@ class BaseGraph(BaseDbModel):
|
||||
return any(
|
||||
node.block_id
|
||||
for node in self.nodes
|
||||
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
if (
|
||||
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
or node.block.requires_human_review
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -326,7 +338,35 @@ class Graph(BaseGraph):
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
return self._credentials_input_schema.jsonschema()
|
||||
schema = self._credentials_input_schema.jsonschema()
|
||||
|
||||
# Determine which credential fields are required based on credentials_optional metadata
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
required_fields = []
|
||||
|
||||
# Build a map of node_id -> node for quick lookup
|
||||
all_nodes = {node.id: node for node in self.nodes}
|
||||
for sub_graph in self.sub_graphs:
|
||||
for node in sub_graph.nodes:
|
||||
all_nodes[node.id] = node
|
||||
|
||||
for field_key, (
|
||||
_field_info,
|
||||
node_field_pairs,
|
||||
) in graph_credentials_inputs.items():
|
||||
# A field is required if ANY node using it has credentials_optional=False
|
||||
is_required = False
|
||||
for node_id, _field_name in node_field_pairs:
|
||||
node = all_nodes.get(node_id)
|
||||
if node and not node.credentials_optional:
|
||||
is_required = True
|
||||
break
|
||||
|
||||
if is_required:
|
||||
required_fields.append(field_key)
|
||||
|
||||
schema["required"] = required_fields
|
||||
return schema
|
||||
|
||||
@property
|
||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||
|
||||
@@ -396,3 +396,58 @@ async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
)
|
||||
assert got_graph is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_node_credentials_optional_default():
|
||||
"""Test that credentials_optional defaults to False when not set in metadata."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_true():
|
||||
"""Test that credentials_optional returns True when explicitly set."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": True},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
|
||||
|
||||
def test_node_credentials_optional_false():
|
||||
"""Test that credentials_optional returns False when explicitly set to False."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={"credentials_optional": False},
|
||||
)
|
||||
assert node.credentials_optional is False
|
||||
|
||||
|
||||
def test_node_credentials_optional_with_other_metadata():
|
||||
"""Test that credentials_optional works correctly with other metadata present."""
|
||||
node = Node(
|
||||
id="test_node",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
metadata={
|
||||
"position": {"x": 100, "y": 200},
|
||||
"customized_name": "My Custom Node",
|
||||
"credentials_optional": True,
|
||||
},
|
||||
)
|
||||
assert node.credentials_optional is True
|
||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||
assert node.metadata["customized_name"] == "My Custom Node"
|
||||
|
||||
@@ -0,0 +1,513 @@
|
||||
"""
|
||||
Tests for dynamic fields edge cases and failure modes.
|
||||
|
||||
Covers failure modes:
|
||||
8. No Type Validation in Dynamic Field Merging
|
||||
17. No Validation of Dynamic Field Paths
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.dynamic_fields import (
|
||||
DICT_SPLIT,
|
||||
LIST_SPLIT,
|
||||
OBJC_SPLIT,
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
is_dynamic_field,
|
||||
is_tool_pin,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
sanitize_pin_name,
|
||||
)
|
||||
|
||||
|
||||
class TestDynamicFieldMergingTypeValidation:
|
||||
"""
|
||||
Tests for Failure Mode #8: No Type Validation in Dynamic Field Merging
|
||||
|
||||
When merging dynamic fields, there's no validation that intermediate
|
||||
structures have the correct type, leading to potential type coercion errors.
|
||||
"""
|
||||
|
||||
def test_merge_dict_field_creates_dict(self):
|
||||
"""Test that dictionary fields create dict structure."""
|
||||
data = {
|
||||
"values_#_name": "Alice",
|
||||
"values_#_age": 30,
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "values" in result
|
||||
assert isinstance(result["values"], dict)
|
||||
assert result["values"]["name"] == "Alice"
|
||||
assert result["values"]["age"] == 30
|
||||
|
||||
def test_merge_list_field_creates_list(self):
|
||||
"""Test that list fields create list structure."""
|
||||
data = {
|
||||
"items_$_0": "first",
|
||||
"items_$_1": "second",
|
||||
"items_$_2": "third",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "items" in result
|
||||
assert isinstance(result["items"], list)
|
||||
assert result["items"] == ["first", "second", "third"]
|
||||
|
||||
def test_merge_with_existing_primitive_type_conflict(self):
|
||||
"""
|
||||
Test behavior when merging into existing primitive value.
|
||||
|
||||
BUG: If the base field already exists as a primitive,
|
||||
merging a dynamic field may fail or corrupt data.
|
||||
"""
|
||||
# Pre-existing primitive value
|
||||
data = {
|
||||
"value": "I am a string", # Primitive
|
||||
"value_#_key": "dict value", # Dynamic dict field
|
||||
}
|
||||
|
||||
# This may raise an error or produce unexpected results
|
||||
# depending on merge order and implementation
|
||||
try:
|
||||
result = merge_execution_input(data)
|
||||
# If it succeeds, check what happened
|
||||
# The primitive may have been overwritten
|
||||
if isinstance(result.get("value"), dict):
|
||||
# Primitive was converted to dict - data loss!
|
||||
assert "key" in result["value"]
|
||||
else:
|
||||
# Or the dynamic field was ignored
|
||||
pass
|
||||
except (TypeError, AttributeError):
|
||||
# Expected error when trying to merge into primitive
|
||||
pass
|
||||
|
||||
def test_merge_list_with_gaps(self):
|
||||
"""Test merging list fields with non-contiguous indices."""
|
||||
data = {
|
||||
"items_$_0": "zero",
|
||||
"items_$_2": "two", # Gap at index 1
|
||||
"items_$_5": "five", # Larger gap
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "items" in result
|
||||
# Check how gaps are handled
|
||||
items = result["items"]
|
||||
assert items[0] == "zero"
|
||||
# Index 1 may be None or missing
|
||||
assert items[2] == "two"
|
||||
assert items[5] == "five"
|
||||
|
||||
def test_merge_nested_dynamic_fields(self):
|
||||
"""Test merging deeply nested dynamic fields."""
|
||||
data = {
|
||||
"data_#_users_$_0": "user1",
|
||||
"data_#_users_$_1": "user2",
|
||||
"data_#_config_#_enabled": True,
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
# Complex nested structures should be created
|
||||
assert "data" in result
|
||||
|
||||
def test_merge_object_field(self):
|
||||
"""Test merging object attribute fields."""
|
||||
data = {
|
||||
"user_@_name": "Alice",
|
||||
"user_@_email": "alice@example.com",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "user" in result
|
||||
# Object fields create dict-like structure
|
||||
assert result["user"]["name"] == "Alice"
|
||||
assert result["user"]["email"] == "alice@example.com"
|
||||
|
||||
def test_merge_mixed_field_types(self):
|
||||
"""Test merging mixed regular and dynamic fields."""
|
||||
data = {
|
||||
"regular": "value",
|
||||
"dict_field_#_key": "dict_value",
|
||||
"list_field_$_0": "list_item",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["regular"] == "value"
|
||||
assert result["dict_field"]["key"] == "dict_value"
|
||||
assert result["list_field"][0] == "list_item"
|
||||
|
||||
|
||||
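# Illustrative sketch, not from the patch above: a minimal stand-in for
# merge_execution_input() showing how the flat keys in the tests become
# nested structures. "_#_" addresses a dict key, "_$_" a list index, "_@_"
# an object attribute (modelled as a dict here). Only one level of nesting
# is handled; this is a reading aid, not the backend implementation.
from typing import Any


def merge_flat_input(data: dict[str, Any]) -> dict[str, Any]:
    merged: dict[str, Any] = {}
    for key, value in data.items():
        if "_#_" in key or "_@_" in key:
            base, sub = key.replace("_@_", "_#_").split("_#_", 1)
            merged.setdefault(base, {})[sub] = value
        elif "_$_" in key:
            base, index = key.split("_$_", 1)
            bucket = merged.setdefault(base, [])
            idx = int(index)
            bucket.extend([None] * (idx + 1 - len(bucket)))  # pad gaps with None
            bucket[idx] = value
        else:
            merged[key] = value
    return merged


if __name__ == "__main__":
    out = merge_flat_input({"values_#_name": "Alice", "items_$_1": "one", "plain": 1})
    assert out == {"values": {"name": "Alice"}, "items": [None, "one"], "plain": 1}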
class TestDynamicFieldPathValidation:
|
||||
"""
|
||||
Tests for Failure Mode #17: No Validation of Dynamic Field Paths
|
||||
|
||||
When traversing dynamic field paths, intermediate None values
|
||||
can cause TypeErrors instead of graceful failures.
|
||||
"""
|
||||
|
||||
def test_parse_output_with_none_intermediate(self):
|
||||
"""
|
||||
Test parse_execution_output with None intermediate value.
|
||||
|
||||
If data contains {"items": None} and we try to access items[0],
|
||||
it should return None gracefully, not raise TypeError.
|
||||
"""
|
||||
# Output with nested path
|
||||
output_item = ("data_$_0", "value")
|
||||
|
||||
# When the base is None, should return None
|
||||
# This tests the path traversal logic
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="data",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
# Should handle gracefully (return the value or None)
|
||||
# Not raise TypeError
|
||||
|
||||
def test_extract_base_field_name_with_multiple_delimiters(self):
|
||||
"""Test extracting base name with multiple delimiters."""
|
||||
# Multiple dict delimiters
|
||||
assert extract_base_field_name("a_#_b_#_c") == "a"
|
||||
|
||||
# Multiple list delimiters
|
||||
assert extract_base_field_name("a_$_0_$_1") == "a"
|
||||
|
||||
# Mixed delimiters
|
||||
assert extract_base_field_name("a_#_b_$_0") == "a"
|
||||
|
||||
def test_is_dynamic_field_edge_cases(self):
|
||||
"""Test is_dynamic_field with edge cases."""
|
||||
# Standard dynamic fields
|
||||
assert is_dynamic_field("values_#_key") is True
|
||||
assert is_dynamic_field("items_$_0") is True
|
||||
assert is_dynamic_field("obj_@_attr") is True
|
||||
|
||||
# Regular fields
|
||||
assert is_dynamic_field("regular") is False
|
||||
assert is_dynamic_field("with_underscore") is False
|
||||
|
||||
# Edge cases
|
||||
assert is_dynamic_field("") is False
|
||||
assert is_dynamic_field("_#_") is True # Just delimiter
|
||||
assert is_dynamic_field("a_#_") is True # Trailing delimiter
|
||||
|
||||
def test_sanitize_pin_name_with_tool_pins(self):
|
||||
"""Test sanitize_pin_name with various tool pin formats."""
|
||||
# Tool pins should return "tools"
|
||||
assert sanitize_pin_name("tools") == "tools"
|
||||
assert sanitize_pin_name("tools_^_node_~_field") == "tools"
|
||||
|
||||
# Dynamic fields should return base name
|
||||
assert sanitize_pin_name("values_#_key") == "values"
|
||||
assert sanitize_pin_name("items_$_0") == "items"
|
||||
|
||||
# Regular fields unchanged
|
||||
assert sanitize_pin_name("regular") == "regular"
|
||||
|
||||
|
||||
class TestDynamicFieldDescriptions:
|
||||
"""Tests for dynamic field description generation."""
|
||||
|
||||
def test_dict_field_description(self):
|
||||
"""Test description for dictionary fields."""
|
||||
desc = get_dynamic_field_description("values_#_user_name")
|
||||
|
||||
assert "Dictionary field" in desc
|
||||
assert "values['user_name']" in desc
|
||||
|
||||
def test_list_field_description(self):
|
||||
"""Test description for list fields."""
|
||||
desc = get_dynamic_field_description("items_$_0")
|
||||
|
||||
assert "List item 0" in desc
|
||||
assert "items[0]" in desc
|
||||
|
||||
def test_object_field_description(self):
|
||||
"""Test description for object fields."""
|
||||
desc = get_dynamic_field_description("user_@_email")
|
||||
|
||||
assert "Object attribute" in desc
|
||||
assert "user.email" in desc
|
||||
|
||||
def test_regular_field_description(self):
|
||||
"""Test description for regular (non-dynamic) fields."""
|
||||
desc = get_dynamic_field_description("regular_field")
|
||||
|
||||
assert desc == "Value for regular_field"
|
||||
|
||||
def test_description_with_numeric_key(self):
|
||||
"""Test description with numeric dictionary key."""
|
||||
desc = get_dynamic_field_description("values_#_123")
|
||||
|
||||
assert "Dictionary field" in desc
|
||||
assert "values['123']" in desc
|
||||
|
||||
|
||||
class TestParseExecutionOutputToolRouting:
|
||||
"""Tests for tool pin routing in parse_execution_output."""
|
||||
|
||||
def test_tool_pin_routing_exact_match(self):
|
||||
"""Test tool pin routing with exact match."""
|
||||
output_item = ("tools_^_node-123_~_field_name", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="field_name",
|
||||
)
|
||||
|
||||
assert result == "value"
|
||||
|
||||
def test_tool_pin_routing_node_mismatch(self):
|
||||
"""Test tool pin routing with node ID mismatch."""
|
||||
output_item = ("tools_^_node-123_~_field_name", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="different-node",
|
||||
sink_pin_name="field_name",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_tool_pin_routing_field_mismatch(self):
|
||||
"""Test tool pin routing with field name mismatch."""
|
||||
output_item = ("tools_^_node-123_~_field_name", "value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="different_field",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_tool_pin_missing_required_params(self):
|
||||
"""Test that tool pins require node_id and pin_name."""
|
||||
output_item = ("tools_^_node-123_~_field", "value")
|
||||
|
||||
with pytest.raises(ValueError, match="must be provided"):
|
||||
parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id=None,
|
||||
sink_pin_name="field",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="must be provided"):
|
||||
parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
|
||||
class TestParseExecutionOutputDynamicFields:
|
||||
"""Tests for dynamic field routing in parse_execution_output."""
|
||||
|
||||
def test_dict_field_extraction(self):
|
||||
"""Test extraction of dictionary field value."""
|
||||
# The output_item is (field_name, data_structure)
|
||||
data = {"key1": "value1", "key2": "value2"}
|
||||
output_item = ("values", data)
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="values_#_key1",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result == "value1"
|
||||
|
||||
def test_list_field_extraction(self):
|
||||
"""Test extraction of list item value."""
|
||||
data = ["zero", "one", "two"]
|
||||
output_item = ("items", data)
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="items_$_1",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result == "one"
|
||||
|
||||
def test_nested_field_extraction(self):
|
||||
"""Test extraction of nested field value."""
|
||||
data = {
|
||||
"users": [
|
||||
{"name": "Alice", "email": "alice@example.com"},
|
||||
{"name": "Bob", "email": "bob@example.com"},
|
||||
]
|
||||
}
|
||||
output_item = ("data", data)
|
||||
|
||||
# Access nested path
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="data_#_users",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result == data["users"]
|
||||
|
||||
def test_missing_key_returns_none(self):
|
||||
"""Test that missing keys return None."""
|
||||
data = {"existing": "value"}
|
||||
output_item = ("values", data)
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="values_#_nonexistent",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_index_out_of_bounds_returns_none(self):
|
||||
"""Test that out-of-bounds indices return None."""
|
||||
data = ["zero", "one"]
|
||||
output_item = ("items", data)
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="items_$_99",
|
||||
sink_node_id=None,
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestIsToolPin:
|
||||
"""Tests for is_tool_pin function."""
|
||||
|
||||
def test_tools_prefix(self):
|
||||
"""Test that 'tools_^_' prefix is recognized."""
|
||||
assert is_tool_pin("tools_^_node_~_field") is True
|
||||
assert is_tool_pin("tools_^_anything") is True
|
||||
|
||||
def test_tools_exact(self):
|
||||
"""Test that exact 'tools' is recognized."""
|
||||
assert is_tool_pin("tools") is True
|
||||
|
||||
def test_non_tool_pins(self):
|
||||
"""Test that non-tool pins are not recognized."""
|
||||
assert is_tool_pin("input") is False
|
||||
assert is_tool_pin("output") is False
|
||||
assert is_tool_pin("toolsomething") is False
|
||||
assert is_tool_pin("my_tools") is False
|
||||
assert is_tool_pin("") is False
|
||||
|
||||
|
||||
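# Illustrative sketch, not from the patch above: the pin-name rules the two
# classes above rely on, condensed into standalone helpers. A tool pin is
# exactly "tools" or anything prefixed "tools_^_", and sanitizing a pin keeps
# only its base field name. These are simplified stand-ins for the
# backend.data.dynamic_fields helpers, kept here as a reading aid.
def looks_like_tool_pin(name: str) -> bool:
    return name == "tools" or name.startswith("tools_^_")


def base_pin_name(name: str) -> str:
    if looks_like_tool_pin(name):
        return "tools"
    for delimiter in ("_#_", "_$_", "_@_"):
        if delimiter in name:
            return name.split(delimiter, 1)[0]
    return name


if __name__ == "__main__":
    assert looks_like_tool_pin("tools_^_node_~_field") and not looks_like_tool_pin("my_tools")
    assert base_pin_name("values_#_key") == "values"
    assert base_pin_name("tools_^_node_~_field") == "tools"
    assert base_pin_name("regular") == "regular"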
class TestMergeExecutionInputEdgeCases:
|
||||
"""Edge case tests for merge_execution_input."""
|
||||
|
||||
def test_empty_input(self):
|
||||
"""Test merging empty input."""
|
||||
result = merge_execution_input({})
|
||||
assert result == {}
|
||||
|
||||
def test_only_regular_fields(self):
|
||||
"""Test merging only regular fields (no dynamic)."""
|
||||
data = {"a": 1, "b": 2, "c": 3}
|
||||
result = merge_execution_input(data)
|
||||
assert result == data
|
||||
|
||||
def test_overwrite_behavior(self):
"""Test that a single dynamic assignment lands under its base key."""
# A dict literal cannot express the same key twice, so this only checks
# that one dynamic assignment is merged under its base key.
data = {
"values_#_key": "first",
}
result = merge_execution_input(data)
assert result["values"]["key"] == "first"
|
||||
|
||||
def test_numeric_string_keys(self):
|
||||
"""Test handling of numeric string keys in dict fields."""
|
||||
data = {
|
||||
"values_#_123": "numeric_key",
|
||||
"values_#_456": "another_numeric",
|
||||
}
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["values"]["123"] == "numeric_key"
|
||||
assert result["values"]["456"] == "another_numeric"
|
||||
|
||||
def test_special_characters_in_keys(self):
|
||||
"""Test handling of special characters in keys."""
|
||||
data = {
|
||||
"values_#_key-with-dashes": "value1",
|
||||
"values_#_key.with.dots": "value2",
|
||||
}
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["values"]["key-with-dashes"] == "value1"
|
||||
assert result["values"]["key.with.dots"] == "value2"
|
||||
|
||||
def test_deeply_nested_list(self):
|
||||
"""Test deeply nested list indices."""
|
||||
data = {
|
||||
"matrix_$_0_$_0": "0,0",
|
||||
"matrix_$_0_$_1": "0,1",
|
||||
"matrix_$_1_$_0": "1,0",
|
||||
"matrix_$_1_$_1": "1,1",
|
||||
}
|
||||
|
||||
# Note: Current implementation may not support this depth
|
||||
# Test documents expected behavior
|
||||
try:
|
||||
result = merge_execution_input(data)
|
||||
# If supported, verify structure
|
||||
except (KeyError, TypeError, IndexError):
|
||||
# Deep nesting may not be supported
|
||||
pass
|
||||
|
||||
def test_none_values(self):
|
||||
"""Test handling of None values in input."""
|
||||
data = {
|
||||
"regular": None,
|
||||
"dict_#_key": None,
|
||||
"list_$_0": None,
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["regular"] is None
|
||||
assert result["dict"]["key"] is None
|
||||
assert result["list"][0] is None
|
||||
|
||||
def test_complex_values(self):
|
||||
"""Test handling of complex values (dicts, lists)."""
|
||||
data = {
|
||||
"values_#_nested_dict": {"inner": "value"},
|
||||
"values_#_nested_list": [1, 2, 3],
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["values"]["nested_dict"] == {"inner": "value"}
|
||||
assert result["values"]["nested_list"] == [1, 2, 3]
|
||||
@@ -0,0 +1,463 @@
|
||||
"""
|
||||
Tests for dynamic field routing with sanitized names.
|
||||
|
||||
This test file specifically tests the parse_execution_output function
|
||||
which is responsible for routing tool outputs to the correct nodes.
|
||||
The critical bug this addresses is the mismatch between:
|
||||
- emit keys using sanitized names (e.g., "max_keyword_difficulty")
|
||||
- sink_pin_name using original names (e.g., "Max Keyword Difficulty")
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.dynamic_fields import (
|
||||
DICT_SPLIT,
|
||||
LIST_SPLIT,
|
||||
OBJC_SPLIT,
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
is_dynamic_field,
|
||||
is_tool_pin,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
sanitize_pin_name,
|
||||
)
|
||||
|
||||
|
||||
def cleanup(s: str) -> str:
|
||||
"""
|
||||
Simulate SmartDecisionMakerBlock.cleanup() for testing.
|
||||
Clean up names for use as tool function names.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
|
||||
|
||||
|
||||
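# Illustrative usage of the cleanup() stand-in above; the expected values are
# the same pairs exercised by the parametrized tests further down.
assert cleanup("Max Keyword Difficulty") == "max_keyword_difficulty"
assert cleanup("CPC ($)") == "cpc____"
assert cleanup(cleanup("User's Input")) == "user_s_input"  # idempotent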
class TestParseExecutionOutputToolRouting:
|
||||
"""Tests for tool pin routing in parse_execution_output."""
|
||||
|
||||
def test_exact_match_routes_correctly(self):
|
||||
"""When emit key field exactly matches sink_pin_name, routing works."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="query",
|
||||
)
|
||||
assert result == "test value"
|
||||
|
||||
def test_sanitized_emit_vs_original_sink_fails(self):
|
||||
"""
|
||||
CRITICAL BUG TEST: When emit key uses sanitized name but sink uses original,
|
||||
routing fails.
|
||||
"""
|
||||
# Backend emits with sanitized name
|
||||
sanitized_field = cleanup("Max Keyword Difficulty")
|
||||
output_item = (f"tools_^_node-123_~_{sanitized_field}", 50)
|
||||
|
||||
# Frontend link has original name
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="Max Keyword Difficulty", # Original name
|
||||
)
|
||||
|
||||
# BUG: This returns None because sanitized != original
|
||||
# Once fixed, change this to: assert result == 50
|
||||
assert result is None, "Expected None due to sanitization mismatch bug"
|
||||
|
||||
def test_node_id_mismatch_returns_none(self):
|
||||
"""When node IDs don't match, routing should return None."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="different-node", # Different node
|
||||
sink_pin_name="query",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_both_node_and_pin_must_match(self):
|
||||
"""Both node_id and pin_name must match for routing to succeed."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
# Wrong node, right pin
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="wrong-node",
|
||||
sink_pin_name="query",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Right node, wrong pin
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="wrong_pin",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Right node, right pin
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name="query",
|
||||
)
|
||||
assert result == "test value"
|
||||
|
||||
|
||||
class TestToolPinRoutingWithSpecialCharacters:
|
||||
"""Tests for tool pin routing with various special characters in names."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"original_name,sanitized_name",
|
||||
[
|
||||
("Max Keyword Difficulty", "max_keyword_difficulty"),
|
||||
("Search Volume (Monthly)", "search_volume__monthly_"),
|
||||
("CPC ($)", "cpc____"),
|
||||
("User's Input", "user_s_input"),
|
||||
("Query #1", "query__1"),
|
||||
("API.Response", "api_response"),
|
||||
("Field@Name", "field_name"),
|
||||
("Test\tTab", "test_tab"),
|
||||
("Test\nNewline", "test_newline"),
|
||||
],
|
||||
)
|
||||
def test_routing_mismatch_with_special_chars(self, original_name, sanitized_name):
|
||||
"""
|
||||
Test that various special characters cause routing mismatches.
|
||||
|
||||
This test documents the current buggy behavior where sanitized emit keys
|
||||
don't match original sink_pin_names.
|
||||
"""
|
||||
# Verify sanitization
|
||||
assert cleanup(original_name) == sanitized_name
|
||||
|
||||
# Backend emits with sanitized name
|
||||
output_item = (f"tools_^_node-123_~_{sanitized_name}", "value")
|
||||
|
||||
# Frontend link has original name
|
||||
result = parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name=original_name,
|
||||
)
|
||||
|
||||
# BUG: Returns None due to mismatch
|
||||
assert result is None, f"Routing should fail for '{original_name}' vs '{sanitized_name}'"
|
||||
|
||||
|
||||
class TestToolPinMissingParameters:
|
||||
"""Tests for missing required parameters in parse_execution_output."""
|
||||
|
||||
def test_missing_sink_node_id_raises_error(self):
|
||||
"""Missing sink_node_id should raise ValueError for tool pins."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
with pytest.raises(ValueError, match="sink_node_id and sink_pin_name must be provided"):
|
||||
parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id=None,
|
||||
sink_pin_name="query",
|
||||
)
|
||||
|
||||
def test_missing_sink_pin_name_raises_error(self):
|
||||
"""Missing sink_pin_name should raise ValueError for tool pins."""
|
||||
output_item = ("tools_^_node-123_~_query", "test value")
|
||||
|
||||
with pytest.raises(ValueError, match="sink_node_id and sink_pin_name must be provided"):
|
||||
parse_execution_output(
|
||||
output_item,
|
||||
link_output_selector="tools",
|
||||
sink_node_id="node-123",
|
||||
sink_pin_name=None,
|
||||
)
|
||||
|
||||
|
||||
class TestIsToolPin:
|
||||
"""Tests for is_tool_pin function."""
|
||||
|
||||
def test_tools_prefix_is_tool_pin(self):
|
||||
"""Names starting with 'tools_^_' are tool pins."""
|
||||
assert is_tool_pin("tools_^_node_~_field") is True
|
||||
assert is_tool_pin("tools_^_anything") is True
|
||||
|
||||
def test_tools_exact_is_tool_pin(self):
|
||||
"""Exact 'tools' is a tool pin."""
|
||||
assert is_tool_pin("tools") is True
|
||||
|
||||
def test_non_tool_pins(self):
|
||||
"""Non-tool pin names should return False."""
|
||||
assert is_tool_pin("input") is False
|
||||
assert is_tool_pin("output") is False
|
||||
assert is_tool_pin("my_tools") is False
|
||||
assert is_tool_pin("toolsomething") is False
|
||||
|
||||
|
||||
class TestSanitizePinName:
|
||||
"""Tests for sanitize_pin_name function."""
|
||||
|
||||
def test_extracts_base_from_dynamic_field(self):
|
||||
"""Should extract base field name from dynamic fields."""
|
||||
assert sanitize_pin_name("values_#_key") == "values"
|
||||
assert sanitize_pin_name("items_$_0") == "items"
|
||||
assert sanitize_pin_name("obj_@_attr") == "obj"
|
||||
|
||||
def test_returns_tools_for_tool_pins(self):
|
||||
"""Tool pins should be sanitized to 'tools'."""
|
||||
assert sanitize_pin_name("tools_^_node_~_field") == "tools"
|
||||
assert sanitize_pin_name("tools") == "tools"
|
||||
|
||||
def test_regular_field_unchanged(self):
|
||||
"""Regular field names should be unchanged."""
|
||||
assert sanitize_pin_name("query") == "query"
|
||||
assert sanitize_pin_name("max_difficulty") == "max_difficulty"
|
||||
|
||||
|
||||
class TestDynamicFieldDescriptions:
|
||||
"""Tests for dynamic field description generation."""
|
||||
|
||||
def test_dict_field_description_with_spaces_in_key(self):
|
||||
"""Dictionary field keys with spaces should generate correct descriptions."""
|
||||
# After cleanup, "User Name" becomes "user_name" in the field name
|
||||
# But the original key might have had spaces
|
||||
desc = get_dynamic_field_description("values_#_user_name")
|
||||
assert "Dictionary field" in desc
|
||||
assert "values['user_name']" in desc
|
||||
|
||||
def test_list_field_description(self):
|
||||
"""List field descriptions should include index."""
|
||||
desc = get_dynamic_field_description("items_$_0")
|
||||
assert "List item 0" in desc
|
||||
assert "items[0]" in desc
|
||||
|
||||
def test_object_field_description(self):
|
||||
"""Object field descriptions should include attribute."""
|
||||
desc = get_dynamic_field_description("user_@_email")
|
||||
assert "Object attribute" in desc
|
||||
assert "user.email" in desc
|
||||
|
||||
|
||||
class TestMergeExecutionInput:
|
||||
"""Tests for merge_execution_input function."""
|
||||
|
||||
def test_merges_dict_fields(self):
|
||||
"""Dictionary fields should be merged into nested structure."""
|
||||
data = {
|
||||
"values_#_name": "Alice",
|
||||
"values_#_age": 30,
|
||||
"other_field": "unchanged",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "values" in result
|
||||
assert result["values"]["name"] == "Alice"
|
||||
assert result["values"]["age"] == 30
|
||||
assert result["other_field"] == "unchanged"
|
||||
|
||||
def test_merges_list_fields(self):
|
||||
"""List fields should be merged into arrays."""
|
||||
data = {
|
||||
"items_$_0": "first",
|
||||
"items_$_1": "second",
|
||||
"items_$_2": "third",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert "items" in result
|
||||
assert result["items"] == ["first", "second", "third"]
|
||||
|
||||
def test_merges_mixed_fields(self):
|
||||
"""Mixed regular and dynamic fields should all be preserved."""
|
||||
data = {
|
||||
"regular": "value",
|
||||
"dict_#_key": "dict_value",
|
||||
"list_$_0": "list_item",
|
||||
}
|
||||
|
||||
result = merge_execution_input(data)
|
||||
|
||||
assert result["regular"] == "value"
|
||||
assert result["dict"]["key"] == "dict_value"
|
||||
assert result["list"] == ["list_item"]
|
||||
|
||||
|
||||
class TestExtractBaseFieldName:
|
||||
"""Tests for extract_base_field_name function."""
|
||||
|
||||
def test_extracts_from_dict_delimiter(self):
|
||||
"""Should extract base name before _#_ delimiter."""
|
||||
assert extract_base_field_name("values_#_name") == "values"
|
||||
assert extract_base_field_name("user_#_email_#_domain") == "user"
|
||||
|
||||
def test_extracts_from_list_delimiter(self):
|
||||
"""Should extract base name before _$_ delimiter."""
|
||||
assert extract_base_field_name("items_$_0") == "items"
|
||||
assert extract_base_field_name("data_$_1_$_nested") == "data"
|
||||
|
||||
def test_extracts_from_object_delimiter(self):
|
||||
"""Should extract base name before _@_ delimiter."""
|
||||
assert extract_base_field_name("obj_@_attr") == "obj"
|
||||
|
||||
def test_no_delimiter_returns_original(self):
|
||||
"""Names without delimiters should be returned unchanged."""
|
||||
assert extract_base_field_name("regular_field") == "regular_field"
|
||||
assert extract_base_field_name("query") == "query"
|
||||
|
||||
|
||||
class TestIsDynamicField:
|
||||
"""Tests for is_dynamic_field function."""
|
||||
|
||||
def test_dict_delimiter_is_dynamic(self):
|
||||
"""Fields with _#_ are dynamic."""
|
||||
assert is_dynamic_field("values_#_key") is True
|
||||
|
||||
def test_list_delimiter_is_dynamic(self):
|
        """Fields with _$_ are dynamic."""
        assert is_dynamic_field("items_$_0") is True

    def test_object_delimiter_is_dynamic(self):
        """Fields with _@_ are dynamic."""
        assert is_dynamic_field("obj_@_attr") is True

    def test_regular_fields_not_dynamic(self):
        """Regular field names without delimiters are not dynamic."""
        assert is_dynamic_field("regular_field") is False
        assert is_dynamic_field("query") is False
        assert is_dynamic_field("Max Keyword Difficulty") is False


class TestRoutingEndToEnd:
    """End-to-end tests for the full routing flow."""

    def test_successful_routing_without_spaces(self):
        """Full routing flow works when no spaces in names."""
        field_name = "query"
        node_id = "test-node-123"

        # Emit key (as created by SmartDecisionMaker)
        emit_key = f"tools_^_{node_id}_~_{cleanup(field_name)}"
        output_item = (emit_key, "search term")

        # Route (as called by executor)
        result = parse_execution_output(
            output_item,
            link_output_selector="tools",
            sink_node_id=node_id,
            sink_pin_name=field_name,
        )

        assert result == "search term"

    def test_failed_routing_with_spaces(self):
        """
        Full routing flow FAILS when names have spaces.

        This test documents the exact bug scenario:
        1. Frontend creates link with sink_name="Max Keyword Difficulty"
        2. SmartDecisionMaker emits with sanitized name in key
        3. Executor calls parse_execution_output with original sink_pin_name
        4. Routing fails because names don't match
        """
        original_field_name = "Max Keyword Difficulty"
        sanitized_field_name = cleanup(original_field_name)
        node_id = "test-node-123"

        # Step 1 & 2: SmartDecisionMaker emits with sanitized name
        emit_key = f"tools_^_{node_id}_~_{sanitized_field_name}"
        output_item = (emit_key, 50)

        # Step 3: Executor routes with original name from link
        result = parse_execution_output(
            output_item,
            link_output_selector="tools",
            sink_node_id=node_id,
            sink_pin_name=original_field_name,  # Original from link!
        )

        # Step 4: BUG - Returns None instead of 50
        assert result is None

        # This is what should happen after fix:
        # assert result == 50

    def test_multiple_fields_with_spaces(self):
        """Test routing multiple fields where some have spaces."""
        node_id = "test-node"

        fields = {
            "query": "test",  # No spaces - should work
            "Max Difficulty": 100,  # Spaces - will fail
            "min_volume": 1000,  # No spaces - should work
        }

        results = {}
        for original_name, value in fields.items():
            sanitized = cleanup(original_name)
            emit_key = f"tools_^_{node_id}_~_{sanitized}"
            output_item = (emit_key, value)

            result = parse_execution_output(
                output_item,
                link_output_selector="tools",
                sink_node_id=node_id,
                sink_pin_name=original_name,
            )
            results[original_name] = result

        # Fields without spaces work
        assert results["query"] == "test"
        assert results["min_volume"] == 1000

        # Fields with spaces fail
        assert results["Max Difficulty"] is None  # BUG!


class TestProposedFix:
    """
    Tests for the proposed fix.

    The fix should sanitize sink_pin_name before comparison in parse_execution_output.
    This class contains tests that will pass once the fix is implemented.
    """

    def test_routing_should_sanitize_both_sides(self):
        """
        PROPOSED FIX: parse_execution_output should sanitize sink_pin_name
        before comparing with the field from emit key.

        Current behavior: Direct string comparison
        Fixed behavior: Compare cleanup(target_input_pin) == cleanup(sink_pin_name)
        """
        original_field = "Max Keyword Difficulty"
        sanitized_field = cleanup(original_field)
        node_id = "node-123"

        emit_key = f"tools_^_{node_id}_~_{sanitized_field}"
        output_item = (emit_key, 50)

        # Extract the comparison being made
        selector = emit_key[8:]  # Remove "tools_^_"
        target_node_id, target_input_pin = selector.split("_~_", 1)

        # Current comparison (FAILS):
        current_comparison = (target_input_pin == original_field)
        assert current_comparison is False, "Current comparison fails"

        # Proposed fixed comparison (PASSES):
        # Either sanitize sink_pin_name, or sanitize both
        fixed_comparison = (target_input_pin == cleanup(original_field))
        assert fixed_comparison is True, "Fixed comparison should pass"

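The comparison that TestProposedFix exercises can be read as a small routing helper. The sketch below is illustrative only and not the shipped parse_execution_output: the helper name _route_tool_output and its early-return structure are assumptions, while the emit-key layout (tools_^_<node_id>_~_<field>) and the cleanup sanitizer are exactly what the tests above use.

# Hedged sketch of the proposed fix: sanitize both sides before comparing.
def _route_tool_output(output_item, link_output_selector, sink_node_id, sink_pin_name):
    emit_key, value = output_item
    prefix = f"{link_output_selector}_^_"
    if not emit_key.startswith(prefix):
        return None
    target_node_id, target_input_pin = emit_key[len(prefix):].split("_~_", 1)
    if target_node_id != sink_node_id:
        return None
    # The fix: compare sanitized names instead of raw strings, so a link sink
    # named "Max Keyword Difficulty" matches the sanitized field in the emit key.
    if cleanup(target_input_pin) != cleanup(sink_pin_name):
        return None
    return value
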
@@ -114,6 +114,40 @@ utilization_gauge = Gauge(
    "Ratio of active graph runs to max graph workers",
)

# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60


async def clear_insufficient_funds_notifications(user_id: str) -> int:
    """
    Clear all insufficient funds notification flags for a user.

    This should be called when a user tops up their credits, allowing
    Discord notifications to be sent again if they run out of funds.

    Args:
        user_id: The user ID to clear notifications for.

    Returns:
        The number of keys that were deleted.
    """
    try:
        redis_client = await redis.get_redis_async()
        pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
        keys = [key async for key in redis_client.scan_iter(match=pattern)]
        if keys:
            return await redis_client.delete(*keys)
        return 0
    except Exception as e:
        logger.warning(
            f"Failed to clear insufficient funds notification flags for user "
            f"{user_id}: {e}"
        )
        return 0


# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()

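For context, this is roughly how the clearing helper is meant to be wired into the credit top-up path. The on_credits_added hook below is hypothetical (the real integration points are _add_transaction and _enable_transaction in backend.data.credit, exercised by the tests further down); only the clear_insufficient_funds_notifications call itself is defined above.

# Hedged usage sketch, assuming some top-up hook exists in the credit path.
async def on_credits_added(user_id: str, amount: int) -> None:
    if amount > 0:
        deleted = await clear_insufficient_funds_notifications(user_id)
        logger.debug(f"Cleared {deleted} insufficient-funds flag(s) for {user_id}")
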
@@ -144,6 +178,7 @@ async def execute_node(
    execution_processor: "ExecutionProcessor",
    execution_stats: NodeExecutionStats | None = None,
    nodes_input_masks: Optional[NodesInputMasks] = None,
    nodes_to_skip: Optional[set[str]] = None,
) -> BlockOutput:
    """
    Execute a node in the graph. This will trigger a block execution on a node,
@@ -211,6 +246,7 @@ async def execute_node(
        "user_id": user_id,
        "execution_context": execution_context,
        "execution_processor": execution_processor,
        "nodes_to_skip": nodes_to_skip or set(),
    }

    # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -508,6 +544,7 @@ class ExecutionProcessor:
        node_exec_progress: NodeExecutionProgress,
        nodes_input_masks: Optional[NodesInputMasks],
        graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
        nodes_to_skip: Optional[set[str]] = None,
    ) -> NodeExecutionStats:
        log_metadata = LogMetadata(
            logger=_logger,
@@ -530,6 +567,7 @@ class ExecutionProcessor:
                db_client=db_client,
                log_metadata=log_metadata,
                nodes_input_masks=nodes_input_masks,
                nodes_to_skip=nodes_to_skip,
            )
            if isinstance(status, BaseException):
                raise status
@@ -575,6 +613,7 @@ class ExecutionProcessor:
        db_client: "DatabaseManagerAsyncClient",
        log_metadata: LogMetadata,
        nodes_input_masks: Optional[NodesInputMasks] = None,
        nodes_to_skip: Optional[set[str]] = None,
    ) -> ExecutionStatus:
        status = ExecutionStatus.RUNNING

@@ -611,6 +650,7 @@ class ExecutionProcessor:
            execution_processor=self,
            execution_stats=stats,
            nodes_input_masks=nodes_input_masks,
            nodes_to_skip=nodes_to_skip,
        ):
            await persist_output(output_name, output_data)

@@ -922,6 +962,21 @@ class ExecutionProcessor:

            queued_node_exec = execution_queue.get()

            # Check if this node should be skipped due to optional credentials
            if queued_node_exec.node_id in graph_exec.nodes_to_skip:
                log_metadata.info(
                    f"Skipping node execution {queued_node_exec.node_exec_id} "
                    f"for node {queued_node_exec.node_id} - optional credentials not configured"
                )
                # Mark the node as completed without executing
                # No outputs will be produced, so downstream nodes won't trigger
                update_node_execution_status(
                    db_client=db_client,
                    exec_id=queued_node_exec.node_exec_id,
                    status=ExecutionStatus.COMPLETED,
                )
                continue

            log_metadata.debug(
                f"Dispatching node execution {queued_node_exec.node_exec_id} "
                f"for node {queued_node_exec.node_id}",
@@ -982,6 +1037,7 @@ class ExecutionProcessor:
                        execution_stats,
                        execution_stats_lock,
                    ),
                    nodes_to_skip=graph_exec.nodes_to_skip,
                ),
                self.node_execution_loop,
            )
@@ -1261,12 +1317,40 @@ class ExecutionProcessor:
        graph_id: str,
        e: InsufficientBalanceError,
    ):
        # Check if we've already sent a notification for this user+agent combo.
        # We only send one notification per user per agent until they top up credits.
        redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
        try:
            redis_client = redis.get_redis()
            # SET NX returns True only if the key was newly set (didn't exist)
            is_new_notification = redis_client.set(
                redis_key,
                "1",
                nx=True,
                ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
            )
            if not is_new_notification:
                # Already notified for this user+agent, skip all notifications
                logger.debug(
                    f"Skipping duplicate insufficient funds notification for "
                    f"user={user_id}, graph={graph_id}"
                )
                return
        except Exception as redis_error:
            # If Redis fails, log and continue to send the notification
            # (better to occasionally duplicate than to never notify)
            logger.warning(
                f"Failed to check/set insufficient funds notification flag in Redis: "
                f"{redis_error}"
            )

        shortfall = abs(e.amount) - e.balance
        metadata = db_client.get_graph_metadata(graph_id)
        base_url = (
            settings.config.frontend_base_url or settings.config.platform_base_url
        )

        # Queue user email notification
        queue_notification(
            NotificationEventModel(
                user_id=user_id,
@@ -1280,6 +1364,7 @@ class ExecutionProcessor:
            )
        )

        # Send Discord system alert
        try:
            user_email = db_client.get_user_email_by_id(user_id)

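Taken together, the flag lifecycle is: SET NX with a 30-day TTL on the first failed charge for a user+agent pair, an early return on every later failure, and deletion of all of the user's flags when credits are added. A condensed sketch (redis_client, user_id and graph_id stand in for the values available inside the handler above):

# Hedged sketch of the dedup check in isolation.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
is_first_failure = redis_client.set(
    redis_key, "1", nx=True, ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS
)
# Truthy only once per user+agent until the key is cleared on top-up or the TTL lapses.

# Shortfall arithmetic with the figures used in the tests below (amounts in cents):
# balance=72 ($0.72), amount=-714 ($7.14)  ->  shortfall = abs(-714) - 72 = 642 ($6.42)
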
@@ -0,0 +1,560 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import ZeroBalanceData
|
||||
from backend.executor.manager import (
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
|
||||
ExecutionProcessor,
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
async def async_iter(items):
|
||||
"""Helper to create an async iterator from a list."""
|
||||
for item in items:
|
||||
yield item
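# Example: the production code iterates redis scan_iter with `async for`, so the
# mocks below can feed it a plain list through this helper, e.g.
#     keys = [key async for key in async_iter(["k1", "k2"])]
#     assert keys == ["k1", "k2"]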
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that the first insufficient funds notification sends a Discord alert."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72, # $0.72
|
||||
amount=-714, # Attempting to spend $7.14
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate first-time notification (set returns True)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = True # Key was newly set
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify notification was queued
|
||||
mock_queue_notif.assert_called_once()
|
||||
notification_call = mock_queue_notif.call_args[0][0]
|
||||
assert notification_call.type == NotificationType.ZERO_BALANCE
|
||||
assert notification_call.user_id == user_id
|
||||
assert isinstance(notification_call.data, ZeroBalanceData)
|
||||
assert notification_call.data.current_balance == 72
|
||||
|
||||
# Verify Redis was checked with correct key pattern
|
||||
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
mock_redis_client.set.assert_called_once()
|
||||
call_args = mock_redis_client.set.call_args
|
||||
assert call_args[0][0] == expected_key
|
||||
assert call_args[1]["nx"] is True
|
||||
|
||||
# Verify Discord alert was sent
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
discord_message = mock_client.discord_system_alert.call_args[0][0]
|
||||
assert "Insufficient Funds Alert" in discord_message
|
||||
assert "test@example.com" in discord_message
|
||||
assert "Test Agent" in discord_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_skips_duplicate_notifications(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to simulate duplicate notification (set returns False/None)
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.return_value = None # Key already existed
|
||||
|
||||
# Create mock database client
|
||||
mock_db_client = MagicMock()
|
||||
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was NOT queued (deduplication worked)
|
||||
mock_queue_notif.assert_not_called()
|
||||
|
||||
# Verify Discord alert was NOT sent (deduplication worked)
|
||||
mock_client.discord_system_alert.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that different agents for the same user get separate Discord alerts."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id_1 = "test-graph-111"
|
||||
graph_id_2 = "test-graph-222"
|
||||
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch("backend.executor.manager.queue_notification"), patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
# Both calls return True (first time for each agent)
|
||||
mock_redis_client.set.return_value = True
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# First agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_1,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Second agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_2,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify Discord alerts were sent for both agents
|
||||
assert mock_client.discord_system_alert.call_count == 2
|
||||
|
||||
# Verify Redis was called with different keys
|
||||
assert mock_redis_client.set.call_count == 2
|
||||
calls = mock_redis_client.set.call_args_list
|
||||
assert (
|
||||
calls[0][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
|
||||
)
|
||||
assert (
|
||||
calls[1][0][0]
|
||||
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
|
||||
"""Test that clearing notifications removes all keys for a user."""
|
||||
|
||||
user_id = "test-user-123"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return some keys as an async iterator
|
||||
mock_keys = [
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
|
||||
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
|
||||
]
|
||||
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
|
||||
# delete is awaited, so use AsyncMock
|
||||
mock_redis_client.delete = AsyncMock(return_value=3)
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify correct pattern was used
|
||||
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
|
||||
|
||||
# Verify delete was called with all keys
|
||||
mock_redis_client.delete.assert_called_once_with(*mock_keys)
|
||||
|
||||
# Verify return value
|
||||
assert result == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
|
||||
"""Test clearing notifications when there are no keys to clear."""
|
||||
|
||||
user_id = "test-user-no-notifications"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
# Mock scan_iter to return no keys as an async iterator
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
|
||||
# Clear notifications
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify delete was not called
|
||||
mock_redis_client.delete.assert_not_called()
|
||||
|
||||
# Verify return value
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_clear_insufficient_funds_notifications_handles_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that clearing notifications handles Redis errors gracefully."""
|
||||
|
||||
user_id = "test-user-redis-error"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock get_redis_async to raise an error
|
||||
mock_redis_module.get_redis_async = AsyncMock(
|
||||
side_effect=Exception("Redis connection failed")
|
||||
)
|
||||
|
||||
# Clear notifications should not raise, just return 0
|
||||
result = await clear_insufficient_funds_notifications(user_id)
|
||||
|
||||
# Verify it returned 0 (graceful failure)
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_insufficient_funds_continues_on_redis_error(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that both email and Discord notifications are still sent when Redis fails."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
message="Insufficient balance",
|
||||
user_id=user_id,
|
||||
balance=72,
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_settings.config.frontend_base_url = "https://test.com"
|
||||
|
||||
# Mock Redis to raise an error
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis.return_value = mock_redis_client
|
||||
mock_redis_client.set.side_effect = Exception("Redis connection error")
|
||||
|
||||
mock_db_client = MagicMock()
|
||||
mock_graph_metadata = MagicMock()
|
||||
mock_graph_metadata.name = "Test Agent"
|
||||
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
e=error,
|
||||
)
|
||||
|
||||
# Verify email notification was still queued despite Redis error
|
||||
mock_queue_notif.assert_called_once()
|
||||
|
||||
# Verify Discord alert was still sent despite Redis error
|
||||
mock_client.discord_system_alert.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-grant-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
# Create a concrete instance
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with GRANT type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500, # Positive amount
|
||||
transaction_type=CreditTransactionType.GRANT,
|
||||
is_active=True, # Active transaction
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
|
||||
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-topup-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter([])
|
||||
mock_redis_client.delete = AsyncMock(return_value=0)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with TOP_UP type (should clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=1000, # Positive amount
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was attempted
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_inactive_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-inactive"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with is_active=False (should NOT clear notifications)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=500,
|
||||
transaction_type=CreditTransactionType.TOP_UP,
|
||||
is_active=False, # Inactive - pending Stripe payment
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_transaction_skips_clearing_for_usage_transaction(
|
||||
server: SpinTestServer,
|
||||
):
|
||||
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-usage"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
|
||||
|
||||
# Mock async Redis
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _add_transaction with USAGE type (spending, should NOT clear)
|
||||
await credit_model._add_transaction(
|
||||
user_id=user_id,
|
||||
amount=-100, # Negative - spending credits
|
||||
transaction_type=CreditTransactionType.USAGE,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Verify notification clearing was NOT called
|
||||
mock_redis_module.get_redis_async.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
|
||||
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
|
||||
from prisma.enums import CreditTransactionType
|
||||
|
||||
from backend.data.credit import UserCredit
|
||||
|
||||
user_id = "test-user-enable"
|
||||
|
||||
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
|
||||
"backend.data.credit.query_raw_with_schema"
|
||||
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
|
||||
# Mock finding the pending transaction
|
||||
mock_transaction = MagicMock()
|
||||
mock_transaction.amount = 1000
|
||||
mock_transaction.type = CreditTransactionType.TOP_UP
|
||||
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
|
||||
return_value=mock_transaction
|
||||
)
|
||||
|
||||
# Mock the query to return updated balance
|
||||
mock_query.return_value = [{"balance": 1500}]
|
||||
|
||||
# Mock async Redis for notification clearing
|
||||
mock_redis_client = MagicMock()
|
||||
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
|
||||
mock_redis_client.scan_iter.return_value = async_iter(
|
||||
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
|
||||
)
|
||||
mock_redis_client.delete = AsyncMock(return_value=1)
|
||||
|
||||
credit_model = UserCredit()
|
||||
|
||||
# Call _enable_transaction (simulates Stripe checkout completion)
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
await credit_model._enable_transaction(
|
||||
transaction_key="cs_test_123",
|
||||
user_id=user_id,
|
||||
metadata=SafeJson({"payment": "completed"}),
|
||||
)
|
||||
|
||||
# Verify notification clearing was called
|
||||
mock_redis_module.get_redis_async.assert_called_once()
|
||||
mock_redis_client.scan_iter.assert_called_once_with(
|
||||
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
)
|
||||
@@ -239,14 +239,19 @@ async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
) -> tuple[dict[str, dict[str, str]], set[str]]:
|
||||
"""
|
||||
Checks all credentials for all nodes of the graph and returns structured errors.
|
||||
Checks all credentials for all nodes of the graph and returns structured errors
|
||||
and a set of nodes that should be skipped due to optional missing credentials.
|
||||
|
||||
Returns:
|
||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
|
||||
tuple[
|
||||
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
|
||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
||||
]
|
||||
"""
|
||||
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
|
||||
nodes_to_skip: set[str] = set()
|
||||
|
||||
for node in graph.nodes:
|
||||
block = node.block
|
||||
@@ -256,27 +261,46 @@ async def _validate_node_input_credentials(
|
||||
if not credentials_fields:
|
||||
continue
|
||||
|
||||
# Track if any credential field is missing for this node
|
||||
has_missing_credentials = False
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
try:
|
||||
# Check nodes_input_masks first, then input_default
|
||||
field_value = None
|
||||
if (
|
||||
nodes_input_masks
|
||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||
and field_name in node_input_mask
|
||||
):
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node_input_mask[field_name]
|
||||
)
|
||||
field_value = node_input_mask[field_name]
|
||||
elif field_name in node.input_default:
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
)
|
||||
else:
|
||||
# Missing credentials
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "These credentials are required"
|
||||
continue
|
||||
# For optional credentials, don't use input_default - treat as missing
|
||||
# This prevents stale credential IDs from failing validation
|
||||
if node.credentials_optional:
|
||||
field_value = None
|
||||
else:
|
||||
field_value = node.input_default[field_name]
|
||||
|
||||
# Check if credentials are missing (None, empty, or not present)
|
||||
if field_value is None or (
|
||||
isinstance(field_value, dict) and not field_value.get("id")
|
||||
):
|
||||
has_missing_credentials = True
|
||||
# If node has credentials_optional flag, mark for skipping instead of error
|
||||
if node.credentials_optional:
|
||||
continue # Don't add error, will be marked for skip after loop
|
||||
else:
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = "These credentials are required"
|
||||
continue
|
||||
|
||||
credentials_meta = credentials_meta_type.model_validate(field_value)
|
||||
|
||||
except ValidationError as e:
|
||||
# Validation error means credentials were provided but invalid
|
||||
# This should always be an error, even if optional
|
||||
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
|
||||
continue
|
||||
|
||||
@@ -287,6 +311,7 @@ async def _validate_node_input_credentials(
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any errors fetching credentials
|
||||
# If credentials were explicitly configured but unavailable, it's an error
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Credentials not available: {e}"
|
||||
@@ -313,7 +338,19 @@ async def _validate_node_input_credentials(
|
||||
] = "Invalid credentials: type/provider mismatch"
|
||||
continue
|
||||
|
||||
return credential_errors
|
||||
# If node has optional credentials and any are missing, mark for skipping
|
||||
# But only if there are no other errors for this node
|
||||
if (
|
||||
has_missing_credentials
|
||||
and node.credentials_optional
|
||||
and node.id not in credential_errors
|
||||
):
|
||||
nodes_to_skip.add(node.id)
|
||||
logger.info(
|
||||
f"Node #{node.id} will be skipped: optional credentials not configured"
|
||||
)
|
||||
|
||||
return credential_errors, nodes_to_skip
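# The decision above, summarised as a standalone sketch (illustrative only; the
# real logic is inline in _validate_node_input_credentials):
#   - credentials missing and node.credentials_optional  -> skip the node
#   - credentials missing and credentials required       -> "These credentials are required"
#   - credentials provided but invalid or unavailable    -> error, even when optional
def _classify_credential_state(missing: bool, invalid: bool, optional: bool) -> str:
    if invalid:
        return "error"
    if missing:
        return "skip" if optional else "error"
    return "ok"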
|
||||
|
||||
|
||||
def make_node_credentials_input_map(
|
||||
@@ -355,21 +392,25 @@ async def validate_graph_with_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> Mapping[str, Mapping[str, str]]:
|
||||
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
|
||||
"""
|
||||
Validate graph including credentials and return structured errors per node.
|
||||
Validate graph including credentials and return structured errors per node,
|
||||
along with a set of nodes that should be skipped due to optional missing credentials.
|
||||
|
||||
Returns:
|
||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node
|
||||
tuple[
|
||||
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
|
||||
set[node_id]: Nodes that should be skipped (optional credentials not configured)
|
||||
]
|
||||
"""
|
||||
# Get input validation errors
|
||||
node_input_errors = GraphModel.validate_graph_get_errors(
|
||||
graph, for_run=True, nodes_input_masks=nodes_input_masks
|
||||
)
|
||||
|
||||
# Get credential input/availability/validation errors
|
||||
node_credential_input_errors = await _validate_node_input_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
# Get credential input/availability/validation errors and nodes to skip
|
||||
node_credential_input_errors, nodes_to_skip = (
|
||||
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
|
||||
)
|
||||
|
||||
# Merge credential errors with structural errors
|
||||
@@ -378,7 +419,7 @@ async def validate_graph_with_credentials(
|
||||
node_input_errors[node_id] = {}
|
||||
node_input_errors[node_id].update(field_errors)
|
||||
|
||||
return node_input_errors
|
||||
return node_input_errors, nodes_to_skip
|
||||
|
||||
|
||||
async def _construct_starting_node_execution_input(
|
||||
@@ -386,7 +427,7 @@ async def _construct_starting_node_execution_input(
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
This function checks the graph for starting nodes, validates the input data
|
||||
@@ -400,11 +441,14 @@ async def _construct_starting_node_execution_input(
|
||||
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
|
||||
|
||||
Returns:
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
tuple[
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
|
||||
and the corresponding input data for that node.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured)
|
||||
]
|
||||
"""
|
||||
# Use new validation function that includes credentials
|
||||
validation_errors = await validate_graph_with_credentials(
|
||||
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
graph, user_id, nodes_input_masks
|
||||
)
|
||||
n_error_nodes = len(validation_errors)
|
||||
@@ -445,7 +489,7 @@ async def _construct_starting_node_execution_input(
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
return nodes_input
|
||||
return nodes_input, nodes_to_skip
|
||||
|
||||
|
||||
async def validate_and_construct_node_execution_input(
|
||||
@@ -456,7 +500,7 @@ async def validate_and_construct_node_execution_input(
|
||||
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
is_sub_graph: bool = False,
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
|
||||
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
|
||||
"""
|
||||
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
|
||||
This centralizes the logic used by both scheduler validation and actual execution.
|
||||
@@ -473,6 +517,7 @@ async def validate_and_construct_node_execution_input(
|
||||
GraphModel: Full graph object for the given `graph_id`.
|
||||
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
|
||||
dict[str, BlockInput]: Node input masks including all passed-in credentials.
|
||||
set[str]: Node IDs that should be skipped (optional credentials not configured).
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the graph is not found.
|
||||
@@ -514,14 +559,16 @@ async def validate_and_construct_node_execution_input(
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
|
||||
starting_nodes_input = await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
starting_nodes_input, nodes_to_skip = (
|
||||
await _construct_starting_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=graph_inputs,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
)
|
||||
|
||||
return graph, starting_nodes_input, nodes_input_masks
|
||||
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
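# Callers now unpack four values instead of three. A hedged caller-side sketch
# (add_graph_execution below shows the real execution-path usage; remaining
# keyword arguments are elided here):
#     graph, starting_nodes, masks, nodes_to_skip = (
#         await validate_and_construct_node_execution_input(
#             graph_id=graph_id, user_id=user_id, ...
#         )
#     )
#     if nodes_to_skip:
#         logger.info(f"{len(nodes_to_skip)} node(s) will be skipped at run time")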
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
@@ -779,6 +826,9 @@ async def add_graph_execution(
|
||||
|
||||
# Use existing execution's compiled input masks
|
||||
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
|
||||
# For resumed executions, nodes_to_skip was already determined at creation time
|
||||
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
|
||||
nodes_to_skip: set[str] = set()
|
||||
|
||||
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
|
||||
else:
|
||||
@@ -787,7 +837,7 @@ async def add_graph_execution(
|
||||
)
|
||||
|
||||
# Create new execution
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks = (
|
||||
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
|
||||
await validate_and_construct_node_execution_input(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
@@ -836,6 +886,7 @@ async def add_graph_execution(
|
||||
try:
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||
|
||||
@@ -367,10 +367,13 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
)
|
||||
|
||||
# Setup mock returns
|
||||
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
|
||||
nodes_to_skip: set[str] = set()
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip,
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
@@ -456,3 +459,212 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
# Both executions should succeed (though they create different objects)
|
||||
assert result1 == mock_graph_exec
|
||||
assert result2 == mock_graph_exec_2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Optional Credentials Feature
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns nodes_to_skip set
|
||||
for nodes with credentials_optional=True and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=True
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-optional-creds"
|
||||
mock_node.credentials_optional = True
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in nodes_to_skip, not in errors
|
||||
assert mock_node.id in nodes_to_skip
|
||||
assert mock_node.id not in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_required_missing_creds_error(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that _validate_node_input_credentials returns errors
|
||||
for nodes with credentials_optional=False and missing credentials.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
# Create a mock node with credentials_optional=False (required)
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-required-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {} # No credentials configured
|
||||
|
||||
# Create a mock block with credentials field
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
# Create mock graph
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Node should be in errors, not in nodes_to_skip
|
||||
assert mock_node.id in errors
|
||||
assert "credentials" in errors[mock_node.id]
|
||||
assert "required" in errors[mock_node.id]["credentials"].lower()
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
Test that validate_graph_with_credentials returns nodes_to_skip set
|
||||
from _validate_node_input_credentials.
|
||||
"""
|
||||
from backend.executor.utils import validate_graph_with_credentials
|
||||
|
||||
# Mock _validate_node_input_credentials to return specific values
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils._validate_node_input_credentials"
|
||||
)
|
||||
expected_errors = {"node1": {"field": "error"}}
|
||||
expected_nodes_to_skip = {"node2", "node3"}
|
||||
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
|
||||
|
||||
# Mock GraphModel with validate_graph_get_errors method
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.validate_graph_get_errors.return_value = {}
|
||||
|
||||
# Call the function
|
||||
errors, nodes_to_skip = await validate_graph_with_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user-id",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip is passed through
|
||||
assert nodes_to_skip == expected_nodes_to_skip
|
||||
assert "node1" in errors
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
"""
|
||||
Test that add_graph_execution properly passes nodes_to_skip
|
||||
to the graph execution entry.
|
||||
"""
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
# Mock data
|
||||
graph_id = "test-graph-id"
|
||||
user_id = "test-user-id"
|
||||
inputs = {"test_input": "test_value"}
|
||||
graph_version = 1
|
||||
|
||||
# Mock the graph object
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.version = graph_version
|
||||
|
||||
# Starting nodes and masks
|
||||
starting_nodes_input = [("node1", {"input1": "value1"})]
|
||||
compiled_nodes_input_masks = {}
|
||||
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
|
||||
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = []
|
||||
|
||||
# Track what's passed to to_graph_execution_entry
|
||||
captured_kwargs = {}
|
||||
|
||||
def capture_to_entry(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return mocker.MagicMock()
|
||||
|
||||
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
|
||||
|
||||
# Setup mocks
|
||||
mock_validate = mocker.patch(
|
||||
"backend.executor.utils.validate_and_construct_node_execution_input"
|
||||
)
|
||||
mock_edb = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_udb = mocker.patch("backend.executor.utils.user_db")
|
||||
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
|
||||
# Setup returns - include nodes_to_skip in the tuple
|
||||
mock_validate.return_value = (
|
||||
mock_graph,
|
||||
starting_nodes_input,
|
||||
compiled_nodes_input_masks,
|
||||
nodes_to_skip, # This should be passed through
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
||||
|
||||
mock_user = mocker.MagicMock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_settings = mocker.MagicMock()
|
||||
mock_settings.human_in_the_loop_safe_mode = True
|
||||
|
||||
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
|
||||
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
|
||||
mock_get_queue.return_value = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
|
||||
|
||||
# Call the function
|
||||
await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
|
||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
@@ -8,6 +8,7 @@ from .discord import DiscordOAuthHandler
|
||||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
from .reddit import RedditOAuthHandler
|
||||
from .twitter import TwitterOAuthHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -20,6 +21,7 @@ _ORIGINAL_HANDLERS = [
|
||||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
RedditOAuthHandler,
|
||||
TwitterOAuthHandler,
|
||||
TodoistOAuthHandler,
|
||||
]
|
||||
|
||||
autogpt_platform/backend/backend/integrations/oauth/reddit.py (new file, 208 lines)
@@ -0,0 +1,208 @@
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class RedditOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
Reddit OAuth 2.0 handler.
|
||||
|
||||
Based on the documentation at:
|
||||
- https://github.com/reddit-archive/reddit/wiki/OAuth2
|
||||
|
||||
Notes:
|
||||
- Reddit requires `duration=permanent` to get refresh tokens
|
||||
- Access tokens expire after 1 hour (3600 seconds)
|
||||
- Reddit requires HTTP Basic Auth for token requests
|
||||
- Reddit requires a unique User-Agent header
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName.REDDIT
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = [
|
||||
"identity", # Get username, verify auth
|
||||
"read", # Access posts and comments
|
||||
"submit", # Submit new posts and comments
|
||||
"edit", # Edit own posts and comments
|
||||
"history", # Access user's post history
|
||||
"privatemessages", # Access inbox and send private messages
|
||||
"flair", # Access and set flair on posts/subreddits
|
||||
]
|
||||
|
||||
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
|
||||
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
|
||||
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
|
||||
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
"""Generate Reddit OAuth 2.0 authorization URL"""
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": " ".join(scopes),
|
||||
"state": state,
|
||||
"duration": "permanent", # Required for refresh tokens
|
||||
}
|
||||
|
||||
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
"""Exchange authorization code for access tokens"""
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
|
||||
# Reddit requires HTTP Basic Auth for token requests
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = await Requests().post(
|
||||
self.TOKEN_URL, headers=headers, data=data, auth=auth
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
error_text = response.text()
|
||||
raise ValueError(
|
||||
f"Reddit token exchange failed: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
|
||||
|
||||
username = await self._get_username(tokens["access_token"])
|
||||
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=None,
|
||||
username=username,
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
||||
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
async def _get_username(self, access_token: str) -> str:
|
||||
"""Get the username from the access token"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
response = await Requests().get(self.USERNAME_URL, headers=headers)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"Failed to get Reddit username: {response.status}")
|
||||
|
||||
data = response.json()
|
||||
return data.get("name", "unknown")
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
"""Refresh access tokens using refresh token"""
|
||||
if not credentials.refresh_token:
|
||||
raise ValueError("No refresh token available")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
}
|
||||
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = await Requests().post(
|
||||
self.TOKEN_URL, headers=headers, data=data, auth=auth
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
error_text = response.text()
|
||||
raise ValueError(
|
||||
f"Reddit token refresh failed: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
|
||||
|
||||
username = await self._get_username(tokens["access_token"])
|
||||
|
||||
# Reddit may or may not return a new refresh token
|
||||
new_refresh_token = tokens.get("refresh_token")
|
||||
if new_refresh_token:
|
||||
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
|
||||
elif credentials.refresh_token:
|
||||
# Keep the existing refresh token
|
||||
refresh_token = credentials.refresh_token
|
||||
else:
|
||||
refresh_token = None
|
||||
|
||||
return OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=credentials.title,
|
||||
username=username,
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=refresh_token,
|
||||
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
|
||||
refresh_token_expires_at=None,
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
"""Revoke the access token"""
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": settings.config.reddit_user_agent,
|
||||
}
|
||||
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
}
|
||||
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = await Requests().post(
|
||||
self.REVOKE_URL, headers=headers, data=data, auth=auth
|
||||
)
|
||||
|
||||
# Reddit returns 204 No Content on successful revocation
|
||||
return response.ok
|
||||
@@ -264,7 +264,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
)
|
||||
|
||||
reddit_user_agent: str = Field(
|
||||
default="AutoGPT:1.0 (by /u/autogpt)",
|
||||
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
|
||||
description="The user agent for the Reddit API",
|
||||
)
|
||||
|
||||
|
||||
227
autogpt_platform/backend/gen_prisma_types_stub.py
Normal file
227
autogpt_platform/backend/gen_prisma_types_stub.py
Normal file
@@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate a lightweight stub for prisma/types.py that collapses all exported
|
||||
symbols to Any. This prevents Pyright from spending time/budget on Prisma's
|
||||
query DSL types while keeping runtime behavior unchanged.
|
||||
|
||||
Usage:
|
||||
poetry run gen-prisma-stub
|
||||
|
||||
This script automatically finds the prisma package location and generates
|
||||
the types.pyi stub file in the same directory as types.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
|
||||
def _iter_assigned_names(target: ast.expr) -> Iterable[str]:
|
||||
"""Extract names from assignment targets (handles tuple unpacking)."""
|
||||
if isinstance(target, ast.Name):
|
||||
yield target.id
|
||||
elif isinstance(target, (ast.Tuple, ast.List)):
|
||||
for elt in target.elts:
|
||||
yield from _iter_assigned_names(elt)
|
||||
|
||||
|
||||
def _is_private(name: str) -> bool:
|
||||
"""Check if a name is private (starts with _ but not __)."""
|
||||
return name.startswith("_") and not name.startswith("__")
|
||||
|
||||
|
||||
def _is_safe_type_alias(node: ast.Assign) -> bool:
|
||||
"""Check if an assignment is a safe type alias that shouldn't be stubbed.
|
||||
|
||||
Safe types are:
|
||||
- Literal types (don't cause type budget issues)
|
||||
- Simple type references (SortMode, SortOrder, etc.)
|
||||
- TypeVar definitions
|
||||
"""
|
||||
if not node.value:
|
||||
return False
|
||||
|
||||
# Check if it's a Subscript (like Literal[...], Union[...], TypeVar[...])
|
||||
if isinstance(node.value, ast.Subscript):
|
||||
# Get the base type name
|
||||
if isinstance(node.value.value, ast.Name):
|
||||
base_name = node.value.value.id
|
||||
# Literal types are safe
|
||||
if base_name == "Literal":
|
||||
return True
|
||||
# TypeVar is safe
|
||||
if base_name == "TypeVar":
|
||||
return True
|
||||
elif isinstance(node.value.value, ast.Attribute):
|
||||
# Handle typing_extensions.Literal etc.
|
||||
if node.value.value.attr == "Literal":
|
||||
return True
|
||||
|
||||
# Check if it's a simple Name reference (like SortMode = _types.SortMode)
|
||||
if isinstance(node.value, ast.Attribute):
|
||||
return True
|
||||
|
||||
# Check if it's a Call (like TypeVar(...))
|
||||
if isinstance(node.value, ast.Call):
|
||||
if isinstance(node.value.func, ast.Name):
|
||||
if node.value.func.id == "TypeVar":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def collect_top_level_symbols(
|
||||
tree: ast.Module, source_lines: list[str]
|
||||
) -> tuple[Set[str], Set[str], list[str], Set[str]]:
|
||||
"""Collect all top-level symbols from an AST module.
|
||||
|
||||
Returns:
|
||||
Tuple of (class_names, function_names, safe_variable_sources, unsafe_variable_names)
|
||||
safe_variable_sources contains the actual source code lines for safe variables
|
||||
"""
|
||||
classes: Set[str] = set()
|
||||
functions: Set[str] = set()
|
||||
safe_variable_sources: list[str] = []
|
||||
unsafe_variables: Set[str] = set()
|
||||
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef):
|
||||
if not _is_private(node.name):
|
||||
classes.add(node.name)
|
||||
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if not _is_private(node.name):
|
||||
functions.add(node.name)
|
||||
elif isinstance(node, ast.Assign):
|
||||
is_safe = _is_safe_type_alias(node)
|
||||
names = []
|
||||
for t in node.targets:
|
||||
for n in _iter_assigned_names(t):
|
||||
if not _is_private(n):
|
||||
names.append(n)
|
||||
if names:
|
||||
if is_safe:
|
||||
# Extract the source code for this assignment
|
||||
start_line = node.lineno - 1 # 0-indexed
|
||||
end_line = node.end_lineno if node.end_lineno else node.lineno
|
||||
source = "\n".join(source_lines[start_line:end_line])
|
||||
safe_variable_sources.append(source)
|
||||
else:
|
||||
unsafe_variables.update(names)
|
||||
elif isinstance(node, ast.AnnAssign) and node.target:
|
||||
# Annotated assignments are always stubbed
|
||||
for n in _iter_assigned_names(node.target):
|
||||
if not _is_private(n):
|
||||
unsafe_variables.add(n)
|
||||
|
||||
return classes, functions, safe_variable_sources, unsafe_variables
|
||||
|
||||
|
||||
def find_prisma_types_path() -> Path:
|
||||
"""Find the prisma types.py file in the installed package."""
|
||||
spec = importlib.util.find_spec("prisma")
|
||||
if spec is None or spec.origin is None:
|
||||
raise RuntimeError("Could not find prisma package. Is it installed?")
|
||||
|
||||
prisma_dir = Path(spec.origin).parent
|
||||
types_path = prisma_dir / "types.py"
|
||||
|
||||
if not types_path.exists():
|
||||
raise RuntimeError(f"prisma/types.py not found at {types_path}")
|
||||
|
||||
return types_path
|
||||
|
||||
|
||||
def generate_stub(src_path: Path, stub_path: Path) -> int:
|
||||
"""Generate the .pyi stub file from the source types.py."""
|
||||
code = src_path.read_text(encoding="utf-8", errors="ignore")
|
||||
source_lines = code.splitlines()
|
||||
tree = ast.parse(code, filename=str(src_path))
|
||||
classes, functions, safe_variable_sources, unsafe_variables = (
|
||||
collect_top_level_symbols(tree, source_lines)
|
||||
)
|
||||
|
||||
header = """\
|
||||
# -*- coding: utf-8 -*-
|
||||
# Auto-generated stub file - DO NOT EDIT
|
||||
# Generated by gen_prisma_types_stub.py
|
||||
#
|
||||
# This stub intentionally collapses complex Prisma query DSL types to Any.
|
||||
# Prisma's generated types can explode Pyright's type inference budgets
|
||||
# on large schemas. We collapse them to Any so the rest of the codebase
|
||||
# can remain strongly typed while keeping runtime behavior unchanged.
|
||||
#
|
||||
# Safe types (Literal, TypeVar, simple references) are preserved from the
|
||||
# original types.py to maintain proper type checking where possible.
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from typing_extensions import Literal
|
||||
|
||||
# Re-export commonly used typing constructs that may be imported from this module
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic, Union, Optional, List, Dict
|
||||
|
||||
# Base type alias for stubbed Prisma types - allows any dict structure
|
||||
_PrismaDict = dict[str, Any]
|
||||
|
||||
"""
|
||||
|
||||
lines = [header]
|
||||
|
||||
# Include safe variable definitions (Literal types, TypeVars, etc.)
|
||||
lines.append("# Safe type definitions preserved from original types.py")
|
||||
for source in safe_variable_sources:
|
||||
lines.append(source)
|
||||
lines.append("")
|
||||
|
||||
# Stub all classes and unsafe variables uniformly as dict[str, Any] aliases
|
||||
# This allows:
|
||||
# 1. Use in type annotations: x: SomeType
|
||||
# 2. Constructor calls: SomeType(...)
|
||||
# 3. Dict literal assignments: x: SomeType = {...}
|
||||
lines.append(
|
||||
"# Stubbed types (collapsed to dict[str, Any] to prevent type budget exhaustion)"
|
||||
)
|
||||
all_stubbed = sorted(classes | unsafe_variables)
|
||||
for name in all_stubbed:
|
||||
lines.append(f"{name} = _PrismaDict")
|
||||
|
||||
lines.append("")
|
||||
|
||||
# Stub functions
|
||||
for name in sorted(functions):
|
||||
lines.append(f"def {name}(*args: Any, **kwargs: Any) -> Any: ...")
|
||||
|
||||
lines.append("")
|
||||
|
||||
stub_path.write_text("\n".join(lines), encoding="utf-8")
|
||||
return (
|
||||
len(classes)
|
||||
+ len(functions)
|
||||
+ len(safe_variable_sources)
|
||||
+ len(unsafe_variables)
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point."""
|
||||
try:
|
||||
types_path = find_prisma_types_path()
|
||||
stub_path = types_path.with_suffix(".pyi")
|
||||
|
||||
print(f"Found prisma types.py at: {types_path}")
|
||||
print(f"Generating stub at: {stub_path}")
|
||||
|
||||
num_symbols = generate_stub(types_path, stub_path)
|
||||
print(f"Generated {stub_path.name} with {num_symbols} Any-typed symbols")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -25,6 +25,9 @@ def run(*command: str) -> None:
|
||||
|
||||
|
||||
def lint():
|
||||
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
|
||||
run("gen-prisma-stub")
|
||||
|
||||
lint_step_args: list[list[str]] = [
|
||||
["ruff", "check", *TARGET_DIRS, "--exit-zero"],
|
||||
["ruff", "format", "--diff", "--check", LIBS_DIR],
|
||||
@@ -49,4 +52,6 @@ def format():
|
||||
run("ruff", "format", LIBS_DIR)
|
||||
run("isort", "--profile", "black", BACKEND_DIR)
|
||||
run("black", BACKEND_DIR)
|
||||
# Generate Prisma types stub before running pyright to prevent type budget exhaustion
|
||||
run("gen-prisma-stub")
|
||||
run("pyright", *TARGET_DIRS)
|
||||
|
||||
24
autogpt_platform/backend/poetry.lock
generated
24
autogpt_platform/backend/poetry.lock
generated
@@ -1906,16 +1906,32 @@ httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
|
||||
pydantic = ">=1.10,<3"
|
||||
pyjwt = ">=2.10.1,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "gravitas-md2gdocs"
|
||||
version = "0.1.0"
|
||||
description = "Convert Markdown to Google Docs API requests"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "gravitas_md2gdocs-0.1.0-py3-none-any.whl", hash = "sha256:0cb0627779fdd65c1604818af4142eea1b25d055060183363de1bae4d9e46508"},
|
||||
{file = "gravitas_md2gdocs-0.1.0.tar.gz", hash = "sha256:bb3122fe9fa35c528f3f00b785d3f1398d350082d5d03f60f56c895bdcc68033"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["google-auth-oauthlib (>=1.0.0)", "pytest (>=7.0.0)", "pytest-cov (>=4.0.0)", "python-dotenv (>=1.0.0)", "ruff (>=0.1.0)"]
|
||||
google = ["google-api-python-client (>=2.0.0)", "google-auth (>=2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gravitasml"
|
||||
version = "0.1.3"
|
||||
version = "0.1.4"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "gravitasml-0.1.3-py3-none-any.whl", hash = "sha256:51ff98b4564b7a61f7796f18d5f2558b919d30b3722579296089645b7bc18b85"},
|
||||
{file = "gravitasml-0.1.3.tar.gz", hash = "sha256:04d240b9fa35878252d57a36032130b6516487468847fcdced1022c032a20f57"},
|
||||
{file = "gravitasml-0.1.4-py3-none-any.whl", hash = "sha256:671a18b11d3d8a0e270c6a80c72cd058458b18d5ef7560d00010e962ab1bca74"},
|
||||
{file = "gravitasml-0.1.4.tar.gz", hash = "sha256:35d0d9fec7431817482d53d9c976e375557c3e041d1eb6928e809324a8c866e3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -7279,4 +7295,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "13b191b2a1989d3321ff713c66ff6f5f4f3b82d15df4d407e0e5dbf87d7522c4"
|
||||
content-hash = "a93ba0cea3b465cb6ec3e3f258b383b09f84ea352ccfdbfa112902cde5653fc6"
|
||||
|
||||
@@ -27,7 +27,7 @@ google-api-python-client = "^2.177.0"
|
||||
google-auth-oauthlib = "^1.2.2"
|
||||
google-cloud-storage = "^3.2.0"
|
||||
googlemaps = "^4.10.0"
|
||||
gravitasml = "^0.1.3"
|
||||
gravitasml = "^0.1.4"
|
||||
groq = "^0.30.0"
|
||||
html2text = "^2024.2.26"
|
||||
jinja2 = "^3.1.6"
|
||||
@@ -82,6 +82,7 @@ firecrawl-py = "^4.3.6"
|
||||
exa-py = "^1.14.20"
|
||||
croniter = "^6.0.0"
|
||||
stagehand = "^0.5.1"
|
||||
gravitas-md2gdocs = "^0.1.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
@@ -116,6 +117,7 @@ lint = "linter:lint"
|
||||
test = "run_tests:test"
|
||||
load-store-agents = "test.load_store_agents:run"
|
||||
export-api-schema = "backend.cli.generate_openapi_json:main"
|
||||
gen-prisma-stub = "gen_prisma_types_stub:main"
|
||||
oauth-tool = "backend.cli.oauth_tool:cli"
|
||||
|
||||
[tool.isort]
|
||||
@@ -133,6 +135,9 @@ ignore_patterns = []
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
# Disable syrupy plugin to avoid conflict with pytest-snapshot
|
||||
# Both provide --snapshot-update argument causing ArgumentError
|
||||
addopts = "-p no:syrupy"
|
||||
filterwarnings = [
|
||||
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
||||
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
"created_at": "2025-09-04T13:37:00",
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
"type": "object"
|
||||
},
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
{
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
"type": "object"
|
||||
},
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"id": "test-agent-1",
|
||||
"graph_id": "test-agent-1",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
@@ -41,6 +42,7 @@
|
||||
"id": "test-agent-2",
|
||||
"graph_id": "test-agent-2",
|
||||
"graph_version": 1,
|
||||
"owner_user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"image_url": null,
|
||||
"creator_name": "Test Creator",
|
||||
"creator_image_url": "",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"submissions": [
|
||||
{
|
||||
"listing_id": "test-listing-id",
|
||||
"agent_id": "test-agent-id",
|
||||
"agent_version": 1,
|
||||
"name": "Test Agent",
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
from backend.blocks.google.docs import GoogleDocsFormatTextBlock
|
||||
|
||||
|
||||
def _make_mock_docs_service() -> Mock:
|
||||
service = Mock()
|
||||
# Ensure chained call exists: service.documents().batchUpdate(...).execute()
|
||||
service.documents.return_value.batchUpdate.return_value.execute.return_value = {}
|
||||
return service
|
||||
|
||||
|
||||
def test_format_text_parses_shorthand_hex_color():
|
||||
block = GoogleDocsFormatTextBlock()
|
||||
service = _make_mock_docs_service()
|
||||
|
||||
result = block._format_text(
|
||||
service,
|
||||
document_id="doc_1",
|
||||
start_index=1,
|
||||
end_index=2,
|
||||
bold=False,
|
||||
italic=False,
|
||||
underline=False,
|
||||
font_size=0,
|
||||
foreground_color="#FFF",
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
# Verify request body contains correct rgbColor for white.
|
||||
_, kwargs = service.documents.return_value.batchUpdate.call_args
|
||||
requests = kwargs["body"]["requests"]
|
||||
rgb = requests[0]["updateTextStyle"]["textStyle"]["foregroundColor"]["color"][
|
||||
"rgbColor"
|
||||
]
|
||||
assert rgb == {"red": 1.0, "green": 1.0, "blue": 1.0}
|
||||
|
||||
|
||||
def test_format_text_parses_full_hex_color():
|
||||
block = GoogleDocsFormatTextBlock()
|
||||
service = _make_mock_docs_service()
|
||||
|
||||
result = block._format_text(
|
||||
service,
|
||||
document_id="doc_1",
|
||||
start_index=1,
|
||||
end_index=2,
|
||||
bold=False,
|
||||
italic=False,
|
||||
underline=False,
|
||||
font_size=0,
|
||||
foreground_color="#FF0000",
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
_, kwargs = service.documents.return_value.batchUpdate.call_args
|
||||
requests = kwargs["body"]["requests"]
|
||||
rgb = requests[0]["updateTextStyle"]["textStyle"]["foregroundColor"]["color"][
|
||||
"rgbColor"
|
||||
]
|
||||
assert rgb == {"red": 1.0, "green": 0.0, "blue": 0.0}
|
||||
|
||||
|
||||
def test_format_text_ignores_invalid_color_when_other_fields_present():
|
||||
block = GoogleDocsFormatTextBlock()
|
||||
service = _make_mock_docs_service()
|
||||
|
||||
result = block._format_text(
|
||||
service,
|
||||
document_id="doc_1",
|
||||
start_index=1,
|
||||
end_index=2,
|
||||
bold=True,
|
||||
italic=False,
|
||||
underline=False,
|
||||
font_size=0,
|
||||
foreground_color="#GGG",
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert "warning" in result
|
||||
|
||||
# Should still apply bold, but should NOT include foregroundColor in textStyle.
|
||||
_, kwargs = service.documents.return_value.batchUpdate.call_args
|
||||
requests = kwargs["body"]["requests"]
|
||||
text_style = requests[0]["updateTextStyle"]["textStyle"]
|
||||
fields = requests[0]["updateTextStyle"]["fields"]
|
||||
|
||||
assert text_style == {"bold": True}
|
||||
assert fields == "bold"
|
||||
|
||||
|
||||
def test_format_text_invalid_color_only_does_not_call_api():
|
||||
block = GoogleDocsFormatTextBlock()
|
||||
service = _make_mock_docs_service()
|
||||
|
||||
result = block._format_text(
|
||||
service,
|
||||
document_id="doc_1",
|
||||
start_index=1,
|
||||
end_index=2,
|
||||
bold=False,
|
||||
italic=False,
|
||||
underline=False,
|
||||
font_size=0,
|
||||
foreground_color="#F",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Invalid foreground_color" in result["message"]
|
||||
service.documents.return_value.batchUpdate.assert_not_called()
|
||||
@@ -37,6 +37,18 @@ class TestTranscribeYoutubeVideoBlock:
|
||||
video_id = self.youtube_block.extract_video_id(url)
|
||||
assert video_id == "dQw4w9WgXcQ"
|
||||
|
||||
def test_extract_video_id_shorts_url(self):
|
||||
"""Test extracting video ID from YouTube Shorts URL."""
|
||||
url = "https://www.youtube.com/shorts/dtUqwMu3e-g"
|
||||
video_id = self.youtube_block.extract_video_id(url)
|
||||
assert video_id == "dtUqwMu3e-g"
|
||||
|
||||
def test_extract_video_id_shorts_url_with_params(self):
|
||||
"""Test extracting video ID from YouTube Shorts URL with query parameters."""
|
||||
url = "https://www.youtube.com/shorts/dtUqwMu3e-g?feature=share"
|
||||
video_id = self.youtube_block.extract_video_id(url)
|
||||
assert video_id == "dtUqwMu3e-g"
|
||||
|
||||
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
|
||||
def test_get_transcript_english_available(self, mock_api_class):
|
||||
"""Test getting transcript when English is available."""
|
||||
|
||||
146
autogpt_platform/cloudflare_worker.js
Normal file
146
autogpt_platform/cloudflare_worker.js
Normal file
@@ -0,0 +1,146 @@
|
||||
/**
|
||||
* Cloudflare Workers Script for docs.agpt.co → agpt.co/docs migration
|
||||
*
|
||||
* Deploy this script to handle all redirects with a single JavaScript file.
|
||||
* No rule limits, easy to maintain, handles all edge cases.
|
||||
*/
|
||||
|
||||
// URL mapping for special cases that don't follow patterns
|
||||
const SPECIAL_MAPPINGS = {
|
||||
// Root page
|
||||
'/': '/docs/platform',
|
||||
|
||||
// Special cases that don't follow standard patterns
|
||||
'/platform/d_id/': '/docs/integrations/block-integrations/d-id',
|
||||
'/platform/blocks/blocks/': '/docs/integrations',
|
||||
'/platform/blocks/decoder_block/': '/docs/integrations/block-integrations/text-decoder',
|
||||
'/platform/blocks/http': '/docs/integrations/block-integrations/send-web-request',
|
||||
'/platform/blocks/llm/': '/docs/integrations/block-integrations/ai-and-llm',
|
||||
'/platform/blocks/time_blocks': '/docs/integrations/block-integrations/time-and-date',
|
||||
'/platform/blocks/text_to_speech_block': '/docs/integrations/block-integrations/text-to-speech',
|
||||
'/platform/blocks/ai_shortform_video_block': '/docs/integrations/block-integrations/ai-shortform-video',
|
||||
'/platform/blocks/replicate_flux_advanced': '/docs/integrations/block-integrations/replicate-flux-advanced',
|
||||
'/platform/blocks/flux_kontext': '/docs/integrations/block-integrations/flux-kontext',
|
||||
'/platform/blocks/ai_condition/': '/docs/integrations/block-integrations/ai-condition',
|
||||
'/platform/blocks/email_block': '/docs/integrations/block-integrations/email',
|
||||
'/platform/blocks/google_maps': '/docs/integrations/block-integrations/google-maps',
|
||||
'/platform/blocks/google/gmail': '/docs/integrations/block-integrations/gmail',
|
||||
'/platform/blocks/github/issues/': '/docs/integrations/block-integrations/github-issues',
|
||||
'/platform/blocks/github/repo/': '/docs/integrations/block-integrations/github-repo',
|
||||
'/platform/blocks/github/pull_requests': '/docs/integrations/block-integrations/github-pull-requests',
|
||||
'/platform/blocks/twitter/twitter': '/docs/integrations/block-integrations/twitter',
|
||||
'/classic/setup/': '/docs/classic/setup/setting-up-autogpt-classic',
|
||||
'/code-of-conduct/': '/docs/classic/help-us-improve-autogpt/code-of-conduct',
|
||||
'/contributing/': '/docs/classic/contributing',
|
||||
'/contribute/': '/docs/contribute',
|
||||
'/forge/components/introduction/': '/docs/classic/forge/introduction'
|
||||
};
|
||||
|
||||
/**
|
||||
* Transform path by replacing underscores with hyphens and removing trailing slashes
|
||||
*/
|
||||
function transformPath(path) {
|
||||
return path.replace(/_/g, '-').replace(/\/$/, '');
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle docs.agpt.co redirects
|
||||
*/
|
||||
function handleDocsRedirect(url) {
|
||||
const pathname = url.pathname;
|
||||
|
||||
// Check special mappings first
|
||||
if (SPECIAL_MAPPINGS[pathname]) {
|
||||
return `https://agpt.co${SPECIAL_MAPPINGS[pathname]}`;
|
||||
}
|
||||
|
||||
// Pattern-based redirects
|
||||
|
||||
// Platform blocks: /platform/blocks/* → /docs/integrations/block-integrations/*
|
||||
if (pathname.startsWith('/platform/blocks/')) {
|
||||
const blockName = pathname.substring('/platform/blocks/'.length);
|
||||
const transformedName = transformPath(blockName);
|
||||
return `https://agpt.co/docs/integrations/block-integrations/${transformedName}`;
|
||||
}
|
||||
|
||||
// Platform contributing: /platform/contributing/* → /docs/platform/contributing/*
|
||||
if (pathname.startsWith('/platform/contributing/')) {
|
||||
const subPath = pathname.substring('/platform/contributing/'.length);
|
||||
return `https://agpt.co/docs/platform/contributing/${subPath}`;
|
||||
}
|
||||
|
||||
// Platform general: /platform/* → /docs/platform/* (with underscore→hyphen)
|
||||
if (pathname.startsWith('/platform/')) {
|
||||
const subPath = pathname.substring('/platform/'.length);
|
||||
const transformedPath = transformPath(subPath);
|
||||
return `https://agpt.co/docs/platform/${transformedPath}`;
|
||||
}
|
||||
|
||||
// Forge components: /forge/components/* → /docs/classic/forge/introduction/*
|
||||
if (pathname.startsWith('/forge/components/')) {
|
||||
const subPath = pathname.substring('/forge/components/'.length);
|
||||
return `https://agpt.co/docs/classic/forge/introduction/${subPath}`;
|
||||
}
|
||||
|
||||
// Forge general: /forge/* → /docs/classic/forge/*
|
||||
if (pathname.startsWith('/forge/')) {
|
||||
const subPath = pathname.substring('/forge/'.length);
|
||||
return `https://agpt.co/docs/classic/forge/${subPath}`;
|
||||
}
|
||||
|
||||
// Classic: /classic/* → /docs/classic/*
|
||||
if (pathname.startsWith('/classic/')) {
|
||||
const subPath = pathname.substring('/classic/'.length);
|
||||
return `https://agpt.co/docs/classic/${subPath}`;
|
||||
}
|
||||
|
||||
// Default fallback
|
||||
return 'https://agpt.co/docs/';
|
||||
}
|
||||
|
||||
/**
|
||||
* Main Worker function
|
||||
*/
|
||||
export default {
|
||||
async fetch(request, env, ctx) {
|
||||
const url = new URL(request.url);
|
||||
|
||||
// Only handle docs.agpt.co requests
|
||||
if (url.hostname === 'docs.agpt.co') {
|
||||
const redirectUrl = handleDocsRedirect(url);
|
||||
|
||||
return new Response(null, {
|
||||
status: 301,
|
||||
headers: {
|
||||
'Location': redirectUrl,
|
||||
'Cache-Control': 'max-age=300' // Cache redirects for 5 minutes
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// For non-docs requests, pass through or return 404
|
||||
return new Response('Not Found', { status: 404 });
|
||||
}
|
||||
};
|
||||
|
||||
// Test function for local development
|
||||
export function testRedirects() {
|
||||
const testCases = [
|
||||
'https://docs.agpt.co/',
|
||||
'https://docs.agpt.co/platform/getting-started/',
|
||||
'https://docs.agpt.co/platform/advanced_setup/',
|
||||
'https://docs.agpt.co/platform/blocks/basic/',
|
||||
'https://docs.agpt.co/platform/blocks/ai_condition/',
|
||||
'https://docs.agpt.co/classic/setup/',
|
||||
'https://docs.agpt.co/forge/components/agents/',
|
||||
'https://docs.agpt.co/contributing/',
|
||||
'https://docs.agpt.co/unknown-page'
|
||||
];
|
||||
|
||||
console.log('Testing redirects:');
|
||||
testCases.forEach(testUrl => {
|
||||
const url = new URL(testUrl);
|
||||
const result = handleDocsRedirect(url);
|
||||
console.log(`${testUrl} → ${result}`);
|
||||
});
|
||||
}
|
||||
@@ -37,7 +37,7 @@ services:
|
||||
context: ../
|
||||
dockerfile: autogpt_platform/backend/Dockerfile
|
||||
target: migrate
|
||||
command: ["sh", "-c", "poetry run prisma generate && poetry run prisma migrate deploy"]
|
||||
command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"]
|
||||
develop:
|
||||
watch:
|
||||
- path: ./
|
||||
|
||||
@@ -46,14 +46,15 @@
|
||||
"@radix-ui/react-scroll-area": "1.2.10",
|
||||
"@radix-ui/react-select": "2.2.6",
|
||||
"@radix-ui/react-separator": "1.1.7",
|
||||
"@radix-ui/react-slider": "1.3.6",
|
||||
"@radix-ui/react-slot": "1.2.3",
|
||||
"@radix-ui/react-switch": "1.2.6",
|
||||
"@radix-ui/react-tabs": "1.1.13",
|
||||
"@radix-ui/react-toast": "1.2.15",
|
||||
"@radix-ui/react-tooltip": "1.2.8",
|
||||
"@rjsf/core": "5.24.13",
|
||||
"@rjsf/utils": "5.24.13",
|
||||
"@rjsf/validator-ajv8": "5.24.13",
|
||||
"@rjsf/core": "6.1.2",
|
||||
"@rjsf/utils": "6.1.2",
|
||||
"@rjsf/validator-ajv8": "6.1.2",
|
||||
"@sentry/nextjs": "10.27.0",
|
||||
"@supabase/ssr": "0.7.0",
|
||||
"@supabase/supabase-js": "2.78.0",
|
||||
@@ -69,6 +70,7 @@
|
||||
"cmdk": "1.1.1",
|
||||
"cookie": "1.0.2",
|
||||
"date-fns": "4.1.0",
|
||||
"dexie": "4.2.1",
|
||||
"dotenv": "17.2.3",
|
||||
"elliptic": "6.6.1",
|
||||
"embla-carousel-react": "8.6.0",
|
||||
@@ -90,7 +92,6 @@
|
||||
"react-currency-input-field": "4.0.3",
|
||||
"react-day-picker": "9.11.1",
|
||||
"react-dom": "18.3.1",
|
||||
"react-drag-drop-files": "2.4.0",
|
||||
"react-hook-form": "7.66.0",
|
||||
"react-icons": "5.5.0",
|
||||
"react-markdown": "9.0.3",
|
||||
|
||||
3878
autogpt_platform/frontend/pnpm-lock.yaml
generated
3878
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
BIN
autogpt_platform/frontend/public/integrations/webshare_proxy.png
Normal file
BIN
autogpt_platform/frontend/public/integrations/webshare_proxy.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.6 KiB |
BIN
autogpt_platform/frontend/public/integrations/wordpress.png
Normal file
BIN
autogpt_platform/frontend/public/integrations/wordpress.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
@@ -1,4 +1,4 @@
|
||||
import { OAuthPopupResultMessage } from "@/components/renderers/input-renderer/fields/CredentialField/models/OAuthCredentialModal/useOAuthCredentialModal";
|
||||
import { OAuthPopupResultMessage } from "./types";
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
// This route is intended to be used as the callback for integration OAuth flows,
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
|
||||
| {
|
||||
success: true;
|
||||
code: string;
|
||||
state: string;
|
||||
}
|
||||
| {
|
||||
success: false;
|
||||
message: string;
|
||||
}
|
||||
);
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
SheetTitle,
|
||||
SheetTrigger,
|
||||
} from "@/components/__legacy__/ui/sheet";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
@@ -25,7 +26,6 @@ import {
|
||||
import { BookOpenIcon } from "@phosphor-icons/react";
|
||||
import { useMemo } from "react";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { BuilderActionButton } from "../BuilderActionButton";
|
||||
|
||||
export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
|
||||
const hasOutputs = useGraphStore(useShallow((state) => state.hasOutputs));
|
||||
@@ -76,9 +76,13 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => {
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<SheetTrigger asChild>
|
||||
<BuilderActionButton disabled={!flowID || !hasOutputs()}>
|
||||
<BookOpenIcon className="size-6" />
|
||||
</BuilderActionButton>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
disabled={!flowID || !hasOutputs()}
|
||||
>
|
||||
<BookOpenIcon className="size-4" />
|
||||
</Button>
|
||||
</SheetTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { ButtonProps } from "@/components/atoms/Button/helpers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CircleNotchIcon } from "@phosphor-icons/react";
|
||||
|
||||
export const BuilderActionButton = ({
|
||||
children,
|
||||
className,
|
||||
isLoading,
|
||||
...props
|
||||
}: ButtonProps & { isLoading?: boolean }) => {
|
||||
return (
|
||||
<Button
|
||||
variant="icon"
|
||||
size={"small"}
|
||||
className={cn(
|
||||
"relative h-12 w-12 min-w-0 text-lg",
|
||||
"bg-gradient-to-br from-zinc-50 to-zinc-200",
|
||||
"border border-zinc-200",
|
||||
"shadow-[inset_0_3px_0_0_rgba(255,255,255,0.5),0_2px_4px_0_rgba(0,0,0,0.2)]",
|
||||
"dark:shadow-[inset_0_1px_0_0_rgba(255,255,255,0.1),0_2px_4px_0_rgba(0,0,0,0.4)]",
|
||||
"hover:shadow-[inset_0_1px_0_0_rgba(255,255,255,0.5),0_1px_2px_0_rgba(0,0,0,0.2)]",
|
||||
"active:shadow-[inset_0_2px_4px_0_rgba(0,0,0,0.2)]",
|
||||
"transition-all duration-150",
|
||||
"disabled:cursor-not-allowed disabled:opacity-50",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{!isLoading ? (
|
||||
children
|
||||
) : (
|
||||
<CircleNotchIcon className="size-6 animate-spin" />
|
||||
)}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
@@ -1,12 +1,12 @@
|
||||
import { ShareIcon } from "@phosphor-icons/react";
|
||||
import { BuilderActionButton } from "../BuilderActionButton";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { usePublishToMarketplace } from "./usePublishToMarketplace";
|
||||
import { PublishAgentModal } from "@/components/contextual/PublishAgentModal/PublishAgentModal";
|
||||
import { ShareIcon } from "@phosphor-icons/react";
|
||||
import { usePublishToMarketplace } from "./usePublishToMarketplace";
|
||||
|
||||
export const PublishToMarketplace = ({ flowID }: { flowID: string | null }) => {
|
||||
const { handlePublishToMarketplace, publishState, handleStateChange } =
|
||||
@@ -16,12 +16,14 @@ export const PublishToMarketplace = ({ flowID }: { flowID: string | null }) => {
|
||||
<>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<BuilderActionButton
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={handlePublishToMarketplace}
|
||||
disabled={!flowID}
|
||||
>
|
||||
<ShareIcon className="size-6 drop-shadow-sm" />
|
||||
</BuilderActionButton>
|
||||
<ShareIcon className="size-4" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Publish to Marketplace</TooltipContent>
|
||||
</Tooltip>
|
||||
@@ -30,6 +32,7 @@ export const PublishToMarketplace = ({ flowID }: { flowID: string | null }) => {
|
||||
targetState={publishState}
|
||||
onStateChange={handleStateChange}
|
||||
preSelectedAgentId={flowID || undefined}
|
||||
showTrigger={false}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import { useRunGraph } from "./useRunGraph";
|
||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { PlayIcon, StopIcon } from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { RunInputDialog } from "../RunInputDialog/RunInputDialog";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { BuilderActionButton } from "../BuilderActionButton";
|
||||
import { PlayIcon, StopIcon } from "@phosphor-icons/react";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { RunInputDialog } from "../RunInputDialog/RunInputDialog";
|
||||
import { useRunGraph } from "./useRunGraph";
|
||||
|
||||
export const RunGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
const {
|
||||
@@ -29,21 +28,19 @@ export const RunGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
<>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<BuilderActionButton
|
||||
className={cn(
|
||||
isGraphRunning &&
|
||||
"border-red-500 bg-gradient-to-br from-red-400 to-red-500 shadow-[inset_0_2px_0_0_rgba(255,255,255,0.5),0_2px_4px_0_rgba(0,0,0,0.2)]",
|
||||
)}
|
||||
<Button
|
||||
size="icon"
|
||||
variant={isGraphRunning ? "destructive" : "primary"}
|
||||
onClick={isGraphRunning ? handleStopGraph : handleRunGraph}
|
||||
disabled={!flowID || isExecutingGraph || isTerminatingGraph}
|
||||
isLoading={isExecutingGraph || isTerminatingGraph || isSaving}
|
||||
loading={isExecutingGraph || isTerminatingGraph || isSaving}
|
||||
>
|
||||
{!isGraphRunning ? (
|
||||
<PlayIcon className="size-6 drop-shadow-sm" />
|
||||
<PlayIcon className="size-4" />
|
||||
) : (
|
||||
<StopIcon className="size-6 drop-shadow-sm" />
|
||||
<StopIcon className="size-4" />
|
||||
)}
|
||||
</BuilderActionButton>
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{isGraphRunning ? "Stop agent" : "Run agent"}
|
||||
|
||||
@@ -5,7 +5,7 @@ import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { ClockIcon, PlayIcon } from "@phosphor-icons/react";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { FormRenderer } from "@/components/renderers/input-renderer/FormRenderer";
|
||||
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
|
||||
import { useRunInputDialog } from "./useRunInputDialog";
|
||||
import { CronSchedulerDialog } from "../CronSchedulerDialog/CronSchedulerDialog";
|
||||
|
||||
@@ -66,6 +66,7 @@ export const RunInputDialog = ({
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "large",
|
||||
showOptionalToggle: false,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
|
||||
import { useMemo, useState } from "react";
|
||||
import { uiSchema } from "../../../FlowEditor/nodes/uiSchema";
|
||||
import { isCredentialFieldSchema } from "@/components/renderers/input-renderer/fields/CredentialField/helpers";
|
||||
import { isCredentialFieldSchema } from "@/components/renderers/InputRenderer/custom/CredentialField/helpers";
|
||||
|
||||
export const useRunInputDialog = ({
|
||||
setIsOpen,
|
||||
@@ -66,7 +66,7 @@ export const useRunInputDialog = ({
|
||||
if (isCredentialFieldSchema(fieldSchema)) {
|
||||
dynamicUiSchema[fieldName] = {
|
||||
...dynamicUiSchema[fieldName],
|
||||
"ui:field": "credentials",
|
||||
"ui:field": "custom/credential_field",
|
||||
};
|
||||
}
|
||||
});
|
||||
@@ -76,12 +76,18 @@ export const useRunInputDialog = ({
|
||||
}, [credentialsSchema]);
|
||||
|
||||
const handleManualRun = async () => {
|
||||
// Filter out incomplete credentials (those without a valid id)
|
||||
// RJSF auto-populates const values (provider, type) but not id field
|
||||
const validCredentials = Object.fromEntries(
|
||||
Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id),
|
||||
);
|
||||
|
||||
await executeGraph({
|
||||
graphId: flowID ?? "",
|
||||
graphVersion: flowVersion || null,
|
||||
data: {
|
||||
inputs: inputValues,
|
||||
credentials_inputs: credentialValues,
|
||||
credentials_inputs: validCredentials,
|
||||
source: "builder",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { ClockIcon } from "@phosphor-icons/react";
|
||||
import { RunInputDialog } from "../RunInputDialog/RunInputDialog";
|
||||
import { useScheduleGraph } from "./useScheduleGraph";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { ClockIcon } from "@phosphor-icons/react";
|
||||
import { CronSchedulerDialog } from "../CronSchedulerDialog/CronSchedulerDialog";
|
||||
import { BuilderActionButton } from "../BuilderActionButton";
|
||||
import { RunInputDialog } from "../RunInputDialog/RunInputDialog";
|
||||
import { useScheduleGraph } from "./useScheduleGraph";
|
||||
|
||||
export const ScheduleGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
const {
|
||||
@@ -23,12 +23,14 @@ export const ScheduleGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<BuilderActionButton
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={handleScheduleGraph}
|
||||
disabled={!flowID}
|
||||
>
|
||||
<ClockIcon className="size-6" />
|
||||
</BuilderActionButton>
|
||||
<ClockIcon className="size-4" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p>Schedule Graph</p>
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { ClockCounterClockwiseIcon, XIcon } from "@phosphor-icons/react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { formatTimeAgo } from "@/lib/utils/time";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { useDraftRecoveryPopup } from "./useDraftRecoveryPopup";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { DraftDiff } from "@/lib/dexie/draft-utils";
|
||||
|
||||
interface DraftRecoveryPopupProps {
|
||||
isInitialLoadComplete: boolean;
|
||||
}
|
||||
|
||||
function formatDiffSummary(diff: DraftDiff | null): string {
|
||||
if (!diff) return "";
|
||||
|
||||
const parts: string[] = [];
|
||||
|
||||
// Node changes
|
||||
const nodeChanges: string[] = [];
|
||||
if (diff.nodes.added > 0) nodeChanges.push(`+${diff.nodes.added}`);
|
||||
if (diff.nodes.removed > 0) nodeChanges.push(`-${diff.nodes.removed}`);
|
||||
if (diff.nodes.modified > 0) nodeChanges.push(`~${diff.nodes.modified}`);
|
||||
|
||||
if (nodeChanges.length > 0) {
|
||||
parts.push(
|
||||
`${nodeChanges.join("/")} block${diff.nodes.added + diff.nodes.removed + diff.nodes.modified !== 1 ? "s" : ""}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Edge changes
|
||||
const edgeChanges: string[] = [];
|
||||
if (diff.edges.added > 0) edgeChanges.push(`+${diff.edges.added}`);
|
||||
if (diff.edges.removed > 0) edgeChanges.push(`-${diff.edges.removed}`);
|
||||
if (diff.edges.modified > 0) edgeChanges.push(`~${diff.edges.modified}`);
|
||||
|
||||
if (edgeChanges.length > 0) {
|
||||
parts.push(
|
||||
`${edgeChanges.join("/")} connection${diff.edges.added + diff.edges.removed + diff.edges.modified !== 1 ? "s" : ""}`,
|
||||
);
|
||||
}
|
||||
|
||||
return parts.join(", ");
|
||||
}
|
||||
|
||||
export function DraftRecoveryPopup({
|
||||
isInitialLoadComplete,
|
||||
}: DraftRecoveryPopupProps) {
|
||||
const {
|
||||
isOpen,
|
||||
popupRef,
|
||||
nodeCount,
|
||||
edgeCount,
|
||||
diff,
|
||||
savedAt,
|
||||
onLoad,
|
||||
onDiscard,
|
||||
} = useDraftRecoveryPopup(isInitialLoadComplete);
|
||||
|
||||
const diffSummary = formatDiffSummary(diff);
|
||||
|
||||
return (
|
||||
<AnimatePresence>
|
||||
{isOpen && (
|
||||
<motion.div
|
||||
ref={popupRef}
|
||||
className={cn("absolute left-1/2 top-4 z-50")}
|
||||
initial={{
|
||||
opacity: 0,
|
||||
x: "-50%",
|
||||
y: "-150%",
|
||||
scale: 0.5,
|
||||
filter: "blur(20px)",
|
||||
}}
|
||||
animate={{
|
||||
opacity: 1,
|
||||
x: "-50%",
|
||||
y: "0%",
|
||||
scale: 1,
|
||||
filter: "blur(0px)",
|
||||
}}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
y: "-150%",
|
||||
scale: 0.5,
|
||||
filter: "blur(20px)",
|
||||
transition: { duration: 0.4, type: "spring", bounce: 0.2 },
|
||||
}}
|
||||
transition={{ duration: 0.2, type: "spring", bounce: 0.2 }}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center gap-3 rounded-xlarge border border-amber-200 bg-amber-50 px-4 py-3 shadow-lg",
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2 text-amber-700 dark:text-amber-300">
|
||||
<ClockCounterClockwiseIcon className="h-5 w-5" weight="fill" />
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col">
|
||||
<Text
|
||||
variant="small-medium"
|
||||
className="text-amber-900 dark:text-amber-100"
|
||||
>
|
||||
Unsaved changes found
|
||||
</Text>
|
||||
<Text
|
||||
variant="small"
|
||||
className="text-amber-700 dark:text-amber-400"
|
||||
>
|
||||
{diffSummary ||
|
||||
`${nodeCount} block${nodeCount !== 1 ? "s" : ""}, ${edgeCount} connection${edgeCount !== 1 ? "s" : ""}`}{" "}
|
||||
• {formatTimeAgo(new Date(savedAt).toISOString())}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<div className="ml-2 flex items-center gap-2">
|
||||
<Tooltip delayDuration={10}>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onLoad}
|
||||
className="aspect-square min-w-0 p-1.5"
|
||||
>
|
||||
<ClockCounterClockwiseIcon size={20} weight="fill" />
|
||||
<span className="sr-only">Restore changes</span>
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Restore changes</TooltipContent>
|
||||
</Tooltip>
|
||||
<Tooltip delayDuration={10}>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="icon"
|
||||
onClick={onDiscard}
|
||||
aria-label="Discard changes"
|
||||
className="aspect-square min-w-0 p-1.5"
|
||||
>
|
||||
<XIcon size={20} />
|
||||
<span className="sr-only">Discard changes</span>
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Discard changes</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
import { useEffect, useRef } from "react";
|
||||
import { useDraftManager } from "../FlowEditor/Flow/useDraftManager";
|
||||
|
||||
export const useDraftRecoveryPopup = (isInitialLoadComplete: boolean) => {
|
||||
const popupRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const {
|
||||
isRecoveryOpen: isOpen,
|
||||
savedAt,
|
||||
nodeCount,
|
||||
edgeCount,
|
||||
diff,
|
||||
loadDraft: onLoad,
|
||||
discardDraft: onDiscard,
|
||||
} = useDraftManager(isInitialLoadComplete);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isOpen) return;
|
||||
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
if (
|
||||
popupRef.current &&
|
||||
!popupRef.current.contains(event.target as Node)
|
||||
) {
|
||||
onDiscard();
|
||||
}
|
||||
};
|
||||
|
||||
const timeoutId = setTimeout(() => {
|
||||
document.addEventListener("mousedown", handleClickOutside);
|
||||
}, 100);
|
||||
|
||||
return () => {
|
||||
clearTimeout(timeoutId);
|
||||
document.removeEventListener("mousedown", handleClickOutside);
|
||||
};
|
||||
}, [isOpen, onDiscard]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isOpen) return;
|
||||
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === "Escape") {
|
||||
onDiscard();
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener("keydown", handleKeyDown);
|
||||
return () => {
|
||||
document.removeEventListener("keydown", handleKeyDown);
|
||||
};
|
||||
}, [isOpen, onDiscard]);
|
||||
return {
|
||||
popupRef,
|
||||
isOpen,
|
||||
nodeCount,
|
||||
edgeCount,
|
||||
diff,
|
||||
savedAt,
|
||||
onLoad,
|
||||
onDiscard,
|
||||
};
|
||||
};
|
||||
@@ -1,26 +1,27 @@
|
||||
import { ReactFlow, Background } from "@xyflow/react";
|
||||
import NewControlPanel from "../../NewControlPanel/NewControlPanel";
|
||||
import CustomEdge from "../edges/CustomEdge";
|
||||
import { useFlow } from "./useFlow";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { useMemo, useEffect, useCallback } from "react";
|
||||
import { CustomNode } from "../nodes/CustomNode/CustomNode";
|
||||
import { useCustomEdge } from "../edges/useCustomEdge";
|
||||
import { useFlowRealtime } from "./useFlowRealtime";
|
||||
import { GraphLoadingBox } from "./components/GraphLoadingBox";
|
||||
import { BuilderActions } from "../../BuilderActions/BuilderActions";
|
||||
import { RunningBackground } from "./components/RunningBackground";
|
||||
import { useGraphStore } from "../../../stores/graphStore";
|
||||
import { useCopyPaste } from "./useCopyPaste";
|
||||
import { FloatingReviewsPanel } from "@/components/organisms/FloatingReviewsPanel/FloatingReviewsPanel";
|
||||
import { parseAsString, useQueryStates } from "nuqs";
|
||||
import { CustomControls } from "./components/CustomControl";
|
||||
import { useGetV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { FloatingReviewsPanel } from "@/components/organisms/FloatingReviewsPanel/FloatingReviewsPanel";
|
||||
import { Background, ReactFlow } from "@xyflow/react";
|
||||
import { parseAsString, useQueryStates } from "nuqs";
|
||||
import { useCallback, useMemo } from "react";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { useGraphStore } from "../../../stores/graphStore";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { BuilderActions } from "../../BuilderActions/BuilderActions";
|
||||
import { DraftRecoveryPopup } from "../../DraftRecoveryDialog/DraftRecoveryPopup";
|
||||
import { FloatingSafeModeToggle } from "../../FloatingSafeModeToogle";
|
||||
import NewControlPanel from "../../NewControlPanel/NewControlPanel";
|
||||
import CustomEdge from "../edges/CustomEdge";
|
||||
import { useCustomEdge } from "../edges/useCustomEdge";
|
||||
import { CustomNode } from "../nodes/CustomNode/CustomNode";
|
||||
import { CustomControls } from "./components/CustomControl";
|
||||
import { GraphLoadingBox } from "./components/GraphLoadingBox";
|
||||
import { RunningBackground } from "./components/RunningBackground";
|
||||
import { TriggerAgentBanner } from "./components/TriggerAgentBanner";
|
||||
import { resolveCollisions } from "./helpers/resolve-collision";
|
||||
import { FloatingSafeModeToggle } from "../../FloatingSafeModeToogle";
|
||||
import { useCopyPaste } from "./useCopyPaste";
|
||||
import { useFlow } from "./useFlow";
|
||||
import { useFlowRealtime } from "./useFlowRealtime";
|
||||
|
||||
export const Flow = () => {
|
||||
const [{ flowID, flowExecutionID }] = useQueryStates({
|
||||
@@ -41,14 +42,18 @@ export const Flow = () => {
|
||||
|
||||
const nodes = useNodeStore(useShallow((state) => state.nodes));
|
||||
const setNodes = useNodeStore(useShallow((state) => state.setNodes));
|
||||
|
||||
const onNodesChange = useNodeStore(
|
||||
useShallow((state) => state.onNodesChange),
|
||||
);
|
||||
|
||||
const hasWebhookNodes = useNodeStore(
|
||||
useShallow((state) => state.hasWebhookNodes()),
|
||||
);
|
||||
|
||||
const nodeTypes = useMemo(() => ({ custom: CustomNode }), []);
|
||||
const edgeTypes = useMemo(() => ({ custom: CustomEdge }), []);
|
||||
|
||||
const onNodeDragStop = useCallback(() => {
|
||||
setNodes(
|
||||
resolveCollisions(nodes, {
|
||||
@@ -60,29 +65,26 @@ export const Flow = () => {
|
||||
}, [setNodes, nodes]);
|
||||
const { edges, onConnect, onEdgesChange } = useCustomEdge();
|
||||
|
||||
// We use this hook to load the graph and convert them into custom nodes and edges.
|
||||
const { onDragOver, onDrop, isFlowContentLoading, isLocked, setIsLocked } =
|
||||
useFlow();
|
||||
// for loading purpose
|
||||
const {
|
||||
onDragOver,
|
||||
onDrop,
|
||||
isFlowContentLoading,
|
||||
isInitialLoadComplete,
|
||||
isLocked,
|
||||
setIsLocked,
|
||||
} = useFlow();
|
||||
|
||||
// This hook is used for websocket realtime updates.
|
||||
useFlowRealtime();
|
||||
|
||||
// Copy/paste functionality
|
||||
const handleCopyPaste = useCopyPaste();
|
||||
useCopyPaste();
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
handleCopyPaste(event);
|
||||
};
|
||||
|
||||
window.addEventListener("keydown", handleKeyDown);
|
||||
return () => {
|
||||
window.removeEventListener("keydown", handleKeyDown);
|
||||
};
|
||||
}, [handleCopyPaste]);
|
||||
const isGraphRunning = useGraphStore(
|
||||
useShallow((state) => state.isGraphRunning),
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex h-full w-full dark:bg-slate-900">
|
||||
<div className="relative flex-1">
|
||||
@@ -95,6 +97,9 @@ export const Flow = () => {
|
||||
onConnect={onConnect}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onNodeDragStop={onNodeDragStop}
|
||||
onNodeContextMenu={(event) => {
|
||||
event.preventDefault();
|
||||
}}
|
||||
maxZoom={2}
|
||||
minZoom={0.1}
|
||||
onDragOver={onDragOver}
|
||||
@@ -102,6 +107,7 @@ export const Flow = () => {
|
||||
nodesDraggable={!isLocked}
|
||||
nodesConnectable={!isLocked}
|
||||
elementsSelectable={!isLocked}
|
||||
deleteKeyCode={["Backspace", "Delete"]}
|
||||
>
|
||||
<Background />
|
||||
<CustomControls setIsLocked={setIsLocked} isLocked={isLocked} />
|
||||
@@ -115,6 +121,7 @@ export const Flow = () => {
|
||||
className="right-2 top-32 p-2"
|
||||
/>
|
||||
)}
|
||||
<DraftRecoveryPopup isInitialLoadComplete={isInitialLoadComplete} />
|
||||
</ReactFlow>
|
||||
</div>
|
||||
{/* TODO: Need to update it in future - also do not send executionId as prop - rather use useQueryState inside the component */}
|
||||
|
||||
@@ -48,8 +48,6 @@ export const resolveCollisions: CollisionAlgorithm = (
|
||||
const width = (node.width ?? node.measured?.width ?? 0) + margin * 2;
|
||||
const height = (node.height ?? node.measured?.height ?? 0) + margin * 2;
|
||||
|
||||
console.log("width", width);
|
||||
console.log("height", height);
|
||||
const x = node.position.x - margin;
|
||||
const y = node.position.y - margin;
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useCallback } from "react";
|
||||
import { useCallback, useEffect } from "react";
|
||||
import { useReactFlow } from "@xyflow/react";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
@@ -151,5 +151,16 @@ export function useCopyPaste() {
|
||||
[getViewport, toast],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
handleCopyPaste(event);
|
||||
};
|
||||
|
||||
window.addEventListener("keydown", handleKeyDown);
|
||||
return () => {
|
||||
window.removeEventListener("keydown", handleKeyDown);
|
||||
};
|
||||
}, [handleCopyPaste]);
|
||||
|
||||
return handleCopyPaste;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,319 @@
|
||||
import { useState, useCallback, useEffect, useRef } from "react";
|
||||
import { parseAsString, parseAsInteger, useQueryStates } from "nuqs";
|
||||
import {
|
||||
draftService,
|
||||
getTempFlowId,
|
||||
getOrCreateTempFlowId,
|
||||
DraftData,
|
||||
} from "@/services/builder-draft/draft-service";
|
||||
import { BuilderDraft } from "@/lib/dexie/db";
|
||||
import {
|
||||
cleanNodes,
|
||||
cleanEdges,
|
||||
calculateDraftDiff,
|
||||
DraftDiff,
|
||||
} from "@/lib/dexie/draft-utils";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { useEdgeStore } from "../../../stores/edgeStore";
|
||||
import { useGraphStore } from "../../../stores/graphStore";
|
||||
import { useHistoryStore } from "../../../stores/historyStore";
|
||||
import isEqual from "lodash/isEqual";
|
||||
|
||||
const AUTO_SAVE_INTERVAL_MS = 15000; // 15 seconds
interface DraftRecoveryState {
  isOpen: boolean;
  draft: BuilderDraft | null;
  diff: DraftDiff | null;
}

/**
 * Consolidated hook for draft persistence and recovery
 * - Auto-saves builder state every 15 seconds
 * - Saves on beforeunload event
 * - Checks for and manages unsaved drafts on load
 */
export function useDraftManager(isInitialLoadComplete: boolean) {
  const [state, setState] = useState<DraftRecoveryState>({
    isOpen: false,
    draft: null,
    diff: null,
  });

  const [{ flowID, flowVersion }] = useQueryStates({
    flowID: parseAsString,
    flowVersion: parseAsInteger,
  });

  const lastSavedStateRef = useRef<DraftData | null>(null);
  const saveTimeoutRef = useRef<NodeJS.Timeout | null>(null);
  const isDirtyRef = useRef(false);
  const hasCheckedForDraft = useRef(false);

  const getEffectiveFlowId = useCallback((): string => {
    return flowID || getOrCreateTempFlowId();
  }, [flowID]);

  const getCurrentState = useCallback((): DraftData => {
    const nodes = useNodeStore.getState().nodes;
    const edges = useEdgeStore.getState().edges;
    const nodeCounter = useNodeStore.getState().nodeCounter;
    const graphStore = useGraphStore.getState();

    return {
      nodes,
      edges,
      graphSchemas: {
        input: graphStore.inputSchema,
        credentials: graphStore.credentialsInputSchema,
        output: graphStore.outputSchema,
      },
      nodeCounter,
      flowVersion: flowVersion ?? undefined,
    };
  }, [flowVersion]);

  const cleanStateForComparison = useCallback((stateData: DraftData) => {
    return {
      nodes: cleanNodes(stateData.nodes),
      edges: cleanEdges(stateData.edges),
    };
  }, []);

  const hasChanges = useCallback((): boolean => {
    const currentState = getCurrentState();

    if (!lastSavedStateRef.current) {
      return currentState.nodes.length > 0;
    }

    const currentClean = cleanStateForComparison(currentState);
    const lastClean = cleanStateForComparison(lastSavedStateRef.current);

    return !isEqual(currentClean, lastClean);
  }, [getCurrentState, cleanStateForComparison]);

  const saveDraft = useCallback(async () => {
    const effectiveFlowId = getEffectiveFlowId();
    const currentState = getCurrentState();

    if (currentState.nodes.length === 0 && currentState.edges.length === 0) {
      return;
    }

    if (!hasChanges()) {
      return;
    }

    try {
      await draftService.saveDraft(effectiveFlowId, currentState);
      lastSavedStateRef.current = currentState;
      isDirtyRef.current = false;
    } catch (error) {
      console.error("[DraftPersistence] Failed to save draft:", error);
    }
  }, [getEffectiveFlowId, getCurrentState, hasChanges]);

  const scheduleSave = useCallback(() => {
    isDirtyRef.current = true;

    if (saveTimeoutRef.current) {
      clearTimeout(saveTimeoutRef.current);
    }

    saveTimeoutRef.current = setTimeout(() => {
      saveDraft();
    }, AUTO_SAVE_INTERVAL_MS);
  }, [saveDraft]);

  useEffect(() => {
    const unsubscribeNodes = useNodeStore.subscribe((storeState, prevState) => {
      if (storeState.nodes !== prevState.nodes) {
        scheduleSave();
      }
    });

    const unsubscribeEdges = useEdgeStore.subscribe((storeState, prevState) => {
      if (storeState.edges !== prevState.edges) {
        scheduleSave();
      }
    });

    return () => {
      unsubscribeNodes();
      unsubscribeEdges();
    };
  }, [scheduleSave]);

  useEffect(() => {
    const handleBeforeUnload = () => {
      if (isDirtyRef.current) {
        const effectiveFlowId = getEffectiveFlowId();
        const currentState = getCurrentState();

        if (
          currentState.nodes.length === 0 &&
          currentState.edges.length === 0
        ) {
          return;
        }

        draftService.saveDraft(effectiveFlowId, currentState).catch(() => {
          // Ignore errors on unload
        });
      }
    };

    window.addEventListener("beforeunload", handleBeforeUnload);
    return () => {
      window.removeEventListener("beforeunload", handleBeforeUnload);
    };
  }, [getEffectiveFlowId, getCurrentState]);

  useEffect(() => {
    return () => {
      if (saveTimeoutRef.current) {
        clearTimeout(saveTimeoutRef.current);
      }
      if (isDirtyRef.current) {
        saveDraft();
      }
    };
  }, [saveDraft]);

  useEffect(() => {
    draftService.cleanupExpired().catch((error) => {
      console.error(
        "[DraftPersistence] Failed to cleanup expired drafts:",
        error,
      );
    });
  }, []);

  const checkForDraft = useCallback(async () => {
    const effectiveFlowId = flowID || getTempFlowId();

    if (!effectiveFlowId) {
      return;
    }

    try {
      const draft = await draftService.loadDraft(effectiveFlowId);

      if (!draft) {
        return;
      }

      const currentNodes = useNodeStore.getState().nodes;
      const currentEdges = useEdgeStore.getState().edges;

      const isDifferent = draftService.isDraftDifferent(
        draft,
        currentNodes,
        currentEdges,
      );

      if (isDifferent && (draft.nodes.length > 0 || draft.edges.length > 0)) {
        const diff = calculateDraftDiff(
          draft.nodes,
          draft.edges,
          currentNodes,
          currentEdges,
        );
        setState({
          isOpen: true,
          draft,
          diff,
        });
      } else {
        await draftService.deleteDraft(effectiveFlowId);
      }
    } catch (error) {
      console.error("[DraftRecovery] Failed to check for draft:", error);
    }
  }, [flowID]);

  useEffect(() => {
    if (isInitialLoadComplete && !hasCheckedForDraft.current) {
      hasCheckedForDraft.current = true;
      checkForDraft();
    }
  }, [isInitialLoadComplete, checkForDraft]);

  useEffect(() => {
    hasCheckedForDraft.current = false;
    setState({
      isOpen: false,
      draft: null,
      diff: null,
    });
  }, [flowID]);

  const loadDraft = useCallback(async () => {
    if (!state.draft) return;

    const { draft } = state;

    try {
      useNodeStore.getState().setNodes(draft.nodes);
      useEdgeStore.getState().setEdges(draft.edges);
      draft.nodes.forEach((node) => {
        useNodeStore.getState().syncHardcodedValuesWithHandleIds(node.id);
      });

      if (draft.nodeCounter !== undefined) {
        useNodeStore.setState({ nodeCounter: draft.nodeCounter });
      }

      if (draft.graphSchemas) {
        useGraphStore
          .getState()
          .setGraphSchemas(
            draft.graphSchemas.input as Record<string, unknown> | null,
            draft.graphSchemas.credentials as Record<string, unknown> | null,
            draft.graphSchemas.output as Record<string, unknown> | null,
          );
      }

      setTimeout(() => {
        useHistoryStore.getState().initializeHistory();
      }, 100);

      await draftService.deleteDraft(draft.id);

      setState({
        isOpen: false,
        draft: null,
        diff: null,
      });
    } catch (error) {
      console.error("[DraftRecovery] Failed to load draft:", error);
    }
  }, [state.draft]);

  const discardDraft = useCallback(async () => {
    if (!state.draft) {
      setState({ isOpen: false, draft: null, diff: null });
      return;
    }

    try {
      await draftService.deleteDraft(state.draft.id);
    } catch (error) {
      console.error("[DraftRecovery] Failed to discard draft:", error);
    }

    setState({ isOpen: false, draft: null, diff: null });
  }, [state.draft]);

  return {
    // Recovery popup props
    isRecoveryOpen: state.isOpen,
    savedAt: state.draft?.savedAt ?? 0,
    nodeCount: state.draft?.nodes.length ?? 0,
    edgeCount: state.draft?.edges.length ?? 0,
    diff: state.diff,
    loadDraft,
    discardDraft,
  };
}
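A minimal sketch of how this hook might be wired into the builder page, assuming a hypothetical DraftRecoveryDialog component; only the hook's parameter and return values are taken from the code above, the surrounding component and prop names are illustrative:

// Illustrative usage only; DraftRecoveryDialog and its props are assumptions.
function BuilderPage() {
  const { isInitialLoadComplete } = useFlow();
  const {
    isRecoveryOpen,
    savedAt,
    nodeCount,
    edgeCount,
    diff,
    loadDraft,
    discardDraft,
  } = useDraftManager(isInitialLoadComplete);

  return (
    <>
      {/* ...canvas, toolbars, etc... */}
      {isRecoveryOpen && (
        <DraftRecoveryDialog
          savedAt={savedAt}
          nodeCount={nodeCount}
          edgeCount={edgeCount}
          diff={diff}
          onRestore={loadDraft}
          onDiscard={discardDraft}
        />
      )}
    </>
  );
}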
@@ -21,6 +21,7 @@ import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecut
export const useFlow = () => {
  const [isLocked, setIsLocked] = useState(false);
  const [hasAutoFramed, setHasAutoFramed] = useState(false);
  const [isInitialLoadComplete, setIsInitialLoadComplete] = useState(false);
  const addNodes = useNodeStore(useShallow((state) => state.addNodes));
  const addLinks = useEdgeStore(useShallow((state) => state.addLinks));
  const updateNodeStatus = useNodeStore(
@@ -120,6 +121,14 @@ export const useFlow = () => {
    if (customNodes.length > 0) {
      useNodeStore.getState().setNodes([]);
      addNodes(customNodes);

      // Sync hardcoded values with handle IDs.
      // If a key–value field has a key without a value, the backend omits it from hardcoded values.
      // But if a handleId exists for that key, it causes inconsistency.
      // This ensures hardcoded values stay in sync with handle IDs.
      customNodes.forEach((node) => {
        useNodeStore.getState().syncHardcodedValuesWithHandleIds(node.id);
      });
    }
  }, [customNodes, addNodes]);

@@ -174,11 +183,23 @@ export const useFlow = () => {
    if (customNodes.length > 0 && graph?.links) {
      const timer = setTimeout(() => {
        useHistoryStore.getState().initializeHistory();
        // Mark initial load as complete after history is initialized
        setIsInitialLoadComplete(true);
      }, 100);
      return () => clearTimeout(timer);
    }
  }, [customNodes, graph?.links]);

  // Also mark as complete for new flows (no flowID) after a short delay
  useEffect(() => {
    if (!flowID && !isGraphLoading && !isBlocksLoading) {
      const timer = setTimeout(() => {
        setIsInitialLoadComplete(true);
      }, 200);
      return () => clearTimeout(timer);
    }
  }, [flowID, isGraphLoading, isBlocksLoading]);

  useEffect(() => {
    return () => {
      useNodeStore.getState().setNodes([]);
@@ -217,6 +238,7 @@ export const useFlow = () => {

  useEffect(() => {
    setHasAutoFramed(false);
    setIsInitialLoadComplete(false);
  }, [flowID, flowVersion]);

  // Drag and drop block from block menu
@@ -253,6 +275,7 @@ export const useFlow = () => {

  return {
    isFlowContentLoading: isGraphLoading || isBlocksLoading,
    isInitialLoadComplete,
    onDragOver,
    onDrop,
    isLocked,
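The sync comments above describe a shape mismatch along these lines; the field names are purely illustrative, and the restore behaviour is inferred from the comment rather than from the store implementation:

// Illustrative shapes only: a key-value entry saved with an empty value.
// The backend omits it from hardcodedValues...
const hardcodedValues = { headers: {} };
// ...but the node still has a handle registered for that key:
const handleIds = ["headers_#_Authorization"];
// syncHardcodedValuesWithHandleIds(node.id) presumably restores the missing
// key so the form values and the rendered handles agree again.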
@@ -1,12 +1,17 @@
import { Connection as RFConnection, EdgeChange } from "@xyflow/react";
import {
  Connection as RFConnection,
  EdgeChange,
  applyEdgeChanges,
} from "@xyflow/react";
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
import { useCallback } from "react";
import { useNodeStore } from "../../../stores/nodeStore";
import { CustomEdge } from "./CustomEdge";

export const useCustomEdge = () => {
  const edges = useEdgeStore((s) => s.edges);
  const addEdge = useEdgeStore((s) => s.addEdge);
  const removeEdge = useEdgeStore((s) => s.removeEdge);
  const setEdges = useEdgeStore((s) => s.setEdges);

  const onConnect = useCallback(
    (conn: RFConnection) => {
@@ -45,14 +50,10 @@ export const useCustomEdge = () => {
  );

  const onEdgesChange = useCallback(
    (changes: EdgeChange[]) => {
      changes.forEach((change) => {
        if (change.type === "remove") {
          removeEdge(change.id);
        }
      });
    (changes: EdgeChange<CustomEdge>[]) => {
      setEdges(applyEdgeChanges(changes, edges));
    },
    [removeEdge],
    [edges, setEdges],
  );

  return { edges, onConnect, onEdgesChange };
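The second hunk above swaps the manual forEach over removal changes for @xyflow/react's applyEdgeChanges, which applies every incoming change type to the edge array in one pass. A minimal illustration of the call the new handler makes; the change object is an example, not taken from the diff:

// applyEdgeChanges returns the next edge array for any EdgeChange[],
// so removals, selections, and replacements are all handled uniformly.
const next = applyEdgeChanges(
  [{ type: "remove", id: "edge-1" }], // example change; the id is illustrative
  edges,
);
setEdges(next);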
@@ -1,26 +1,32 @@
import { CircleIcon } from "@phosphor-icons/react";
import { Handle, Position } from "@xyflow/react";
import { useEdgeStore } from "../../../stores/edgeStore";
import { cleanUpHandleId } from "@/components/renderers/InputRenderer/helpers";
import { cn } from "@/lib/utils";

const NodeHandle = ({
const InputNodeHandle = ({
  handleId,
  isConnected,
  side,
  nodeId,
}: {
  handleId: string;
  isConnected: boolean;
  side: "left" | "right";
  nodeId: string;
}) => {
  const cleanedHandleId = cleanUpHandleId(handleId);
  const isInputConnected = useEdgeStore((state) =>
    state.isInputConnected(nodeId ?? "", cleanedHandleId),
  );

  return (
    <Handle
      type={side === "left" ? "target" : "source"}
      position={side === "left" ? Position.Left : Position.Right}
      id={handleId}
      className={side === "left" ? "-ml-4 mr-2" : "-mr-2 ml-2"}
      type={"target"}
      position={Position.Left}
      id={cleanedHandleId}
      className={"-ml-6 mr-2"}
    >
      <div className="pointer-events-none">
        <CircleIcon
          size={16}
          weight={isConnected ? "fill" : "duotone"}
          weight={isInputConnected ? "fill" : "duotone"}
          className={"text-gray-400 opacity-100"}
        />
      </div>
@@ -28,4 +34,35 @@ const NodeHandle = ({
  );
};

export default NodeHandle;
const OutputNodeHandle = ({
  field_name,
  nodeId,
  hexColor,
}: {
  field_name: string;
  nodeId: string;
  hexColor: string;
}) => {
  const isOutputConnected = useEdgeStore((state) =>
    state.isOutputConnected(nodeId, field_name),
  );
  return (
    <Handle
      type={"source"}
      position={Position.Right}
      id={field_name}
      className={"-mr-2 ml-2"}
    >
      <div className="pointer-events-none">
        <CircleIcon
          size={16}
          weight={"duotone"}
          color={isOutputConnected ? hexColor : "gray"}
          className={cn("text-gray-400 opacity-100")}
        />
      </div>
    </Handle>
  );
};

export { InputNodeHandle, OutputNodeHandle };
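OutputNodeHandle is new in this hunk and reads its connection state from the edge store rather than from a prop, so a caller supplies only identity and a color. A minimal sketch; the field name and color value are illustrative:

// Connection state comes from useEdgeStore's isOutputConnected(nodeId, field_name),
// so no isConnected prop is passed. The values below are examples only.
<OutputNodeHandle field_name="result" nodeId={nodeId} hexColor="#6366f1" />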
@@ -1,31 +1,4 @@
/**
 * Handle ID Types for different input structures
 *
 * Examples:
 * SIMPLE: "message"
 * NESTED: "config.api_key"
 * ARRAY: "items_$_0", "items_$_1"
 * KEY_VALUE: "headers_#_Authorization", "params_#_limit"
 *
 * Note: All handle IDs are sanitized to remove spaces and special characters.
 * Spaces become underscores, and special characters are removed.
 * Example: "user name" becomes "user_name", "email@domain.com" becomes "emaildomaincom"
 */
export enum HandleIdType {
  SIMPLE = "SIMPLE",
  NESTED = "NESTED",
  ARRAY = "ARRAY",
  KEY_VALUE = "KEY_VALUE",
}

const fromRjsfId = (id: string): string => {
  if (!id) return "";
  const parts = id.split("_");
  const filtered = parts.filter(
    (p) => p !== "root" && p !== "properties" && p.length > 0,
  );
  return filtered.join("_") || "";
};
// Here we are handling single level of nesting, if need more in future then i will update it

const sanitizeForHandleId = (str: string): string => {
  if (!str) return "";
@@ -38,51 +11,53 @@ const sanitizeForHandleId = (str: string): string => {
    .replace(/^_|_$/g, ""); // Remove leading/trailing underscores
};

export const generateHandleId = (
const cleanTitleId = (id: string): string => {
  if (!id) return "";

  if (id.endsWith("_title")) {
    id = id.slice(0, -6);
  }
  const parts = id.split("_");
  const filtered = parts.filter(
    (p) => p !== "root" && p !== "properties" && p.length > 0,
  );
  const filtered_id = filtered.join("_") || "";
  return filtered_id;
};

export const generateHandleIdFromTitleId = (
  fieldKey: string,
  nestedValues: string[] = [],
  type: HandleIdType = HandleIdType.SIMPLE,
  {
    isObjectProperty,
    isAdditionalProperty,
    isArrayItem,
  }: {
    isArrayItem?: boolean;
    isObjectProperty?: boolean;
    isAdditionalProperty?: boolean;
  } = {
    isArrayItem: false,
    isObjectProperty: false,
    isAdditionalProperty: false,
  },
): string => {
  if (!fieldKey) return "";

  fieldKey = fromRjsfId(fieldKey);
  fieldKey = sanitizeForHandleId(fieldKey);
  const filteredKey = cleanTitleId(fieldKey);
  if (isAdditionalProperty || isArrayItem) {
    return filteredKey;
  }
  const cleanedKey = sanitizeForHandleId(filteredKey);

  if (type === HandleIdType.SIMPLE || nestedValues.length === 0) {
    return fieldKey;
  if (isObjectProperty) {
    // "config_api_key" -> "config.api_key"
    const parts = cleanedKey.split("_");
    if (parts.length >= 2) {
      const baseName = parts[0];
      const propertyName = parts.slice(1).join("_");
      return `${baseName}.${propertyName}`;
    }
  }

  const sanitizedNestedValues = nestedValues.map((value) =>
    sanitizeForHandleId(value),
  );

  switch (type) {
    case HandleIdType.NESTED:
      return [fieldKey, ...sanitizedNestedValues].join(".");

    case HandleIdType.ARRAY:
      return [fieldKey, ...sanitizedNestedValues].join("_$_");

    case HandleIdType.KEY_VALUE:
      return [fieldKey, ...sanitizedNestedValues].join("_#_");

    default:
      return fieldKey;
  }
};

export const parseKeyValueHandleId = (
  handleId: string,
  type: HandleIdType,
): string => {
  if (type === HandleIdType.KEY_VALUE) {
    return handleId.split("_#_")[1];
  } else if (type === HandleIdType.ARRAY) {
    return handleId.split("_$_")[1];
  } else if (type === HandleIdType.NESTED) {
    return handleId.split(".")[1];
  } else if (type === HandleIdType.SIMPLE) {
    return handleId.split("_")[1];
  }
  return "";
  return cleanedKey;
};
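Read against the doc comment at the top of the file, cleanTitleId appears to strip RJSF scaffolding from a field's title id before handle generation. A worked example inferred from the code above; the inputs are illustrative and the results are not verified against tests:

// Inferred behaviour of the new helpers; values are examples only.
cleanTitleId("root_message_title");                   // "message"
cleanTitleId("root_properties_config_api_key_title"); // "config_api_key"
// With isObjectProperty set, generateHandleIdFromTitleId then rewrites
// "config_api_key" into "config.api_key" (first segment becomes the base name).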
@@ -1,24 +1,25 @@
import React from "react";
import { Node as XYNode, NodeProps } from "@xyflow/react";
import { RJSFSchema } from "@rjsf/utils";
import { BlockUIType } from "../../../types";
import { StickyNoteBlock } from "./components/StickyNoteBlock";
import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem";
import { BlockCost } from "@/app/api/__generated__/models/blockCost";
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
import { BlockCost } from "@/app/api/__generated__/models/blockCost";
import { BlockInfoCategoriesItem } from "@/app/api/__generated__/models/blockInfoCategoriesItem";
import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult";
import { NodeContainer } from "./components/NodeContainer";
import { NodeHeader } from "./components/NodeHeader";
import { FormCreator } from "../FormCreator";
import { preprocessInputSchema } from "@/components/renderers/input-renderer/utils/input-schema-pre-processor";
import { OutputHandler } from "../OutputHandler";
import { NodeAdvancedToggle } from "./components/NodeAdvancedToggle";
import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
import { NodeExecutionBadge } from "./components/NodeExecutionBadge";
import { cn } from "@/lib/utils";
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";
import { AyrshareConnectButton } from "./components/AyrshareConnectButton";
import { NodeModelMetadata } from "@/app/api/__generated__/models/nodeModelMetadata";
import { preprocessInputSchema } from "@/components/renderers/InputRenderer/utils/input-schema-pre-processor";
import { cn } from "@/lib/utils";
import { RJSFSchema } from "@rjsf/utils";
import { NodeProps, Node as XYNode } from "@xyflow/react";
import React from "react";
import { BlockUIType } from "../../../types";
import { FormCreator } from "../FormCreator";
import { OutputHandler } from "../OutputHandler";
import { AyrshareConnectButton } from "./components/AyrshareConnectButton";
import { NodeAdvancedToggle } from "./components/NodeAdvancedToggle";
import { NodeContainer } from "./components/NodeContainer";
import { NodeExecutionBadge } from "./components/NodeExecutionBadge";
import { NodeHeader } from "./components/NodeHeader";
import { NodeDataRenderer } from "./components/NodeOutput/NodeOutput";
import { NodeRightClickMenu } from "./components/NodeRightClickMenu";
import { StickyNoteBlock } from "./components/StickyNoteBlock";
import { WebhookDisclaimer } from "./components/WebhookDisclaimer";

export type CustomNodeData = {
  hardcodedValues: {
@@ -88,7 +89,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(

    // Currently all blockTypes design are similar - that's why i am using the same component for all of them
    // If in future - if we need some drastic change in some blockTypes design - we can create separate components for them
    return (
    const node = (
      <NodeContainer selected={selected} nodeId={nodeId} hasErrors={hasErrors}>
        <div className="rounded-xlarge bg-white">
          <NodeHeader data={data} nodeId={nodeId} />
@@ -99,7 +100,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
            nodeId={nodeId}
            uiType={data.uiType}
            className={cn(
              "bg-white pr-6",
              "bg-white px-4",
              isWebhook && "pointer-events-none opacity-50",
            )}
            showHandles={showHandles}
@@ -117,6 +118,15 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
        <NodeExecutionBadge nodeId={nodeId} />
      </NodeContainer>
    );

    return (
      <NodeRightClickMenu
        nodeId={nodeId}
        subGraphID={data.hardcodedValues?.graph_id}
      >
        {node}
      </NodeRightClickMenu>
    );
  },
);
@@ -8,7 +8,7 @@ export const NodeAdvancedToggle = ({ nodeId }: { nodeId: string }) => {
  );
  const setShowAdvanced = useNodeStore((state) => state.setShowAdvanced);
  return (
    <div className="flex items-center justify-between gap-2 rounded-b-xlarge border-t border-slate-200/50 bg-white px-5 py-3.5">
    <div className="flex items-center justify-between gap-2 rounded-b-xlarge border-t border-zinc-200 bg-white px-5 py-3.5">
      <Text variant="body" className="font-medium text-slate-700">
        Advanced
      </Text>

@@ -22,7 +22,7 @@ export const NodeContainer = ({
  return (
    <div
      className={cn(
        "z-12 max-w-[370px] rounded-xlarge ring-1 ring-slate-200/60",
        "z-12 w-[350px] rounded-xlarge ring-1 ring-slate-200/60",
        selected && "shadow-lg ring-2 ring-slate-200",
        status && nodeStyleBasedOnStatus[status],
        hasErrors ? nodeStyleBasedOnStatus[AgentExecutionStatus.FAILED] : "",
@@ -1,26 +1,31 @@
import { Separator } from "@/components/__legacy__/ui/separator";
import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import {
  DropdownMenu,
  DropdownMenuContent,
  DropdownMenuItem,
  DropdownMenuTrigger,
} from "@/components/molecules/DropdownMenu/DropdownMenu";
import { DotsThreeOutlineVerticalIcon } from "@phosphor-icons/react";
import { Copy, Trash2, ExternalLink } from "lucide-react";
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
import { useCopyPasteStore } from "@/app/(platform)/build/stores/copyPasteStore";
import {
  SecondaryDropdownMenuContent,
  SecondaryDropdownMenuItem,
  SecondaryDropdownMenuSeparator,
} from "@/components/molecules/SecondaryMenu/SecondaryMenu";
import {
  ArrowSquareOutIcon,
  CopyIcon,
  DotsThreeOutlineVerticalIcon,
  TrashIcon,
} from "@phosphor-icons/react";
import { useReactFlow } from "@xyflow/react";

export const NodeContextMenu = ({
  nodeId,
  subGraphID,
}: {
type Props = {
  nodeId: string;
  subGraphID?: string;
}) => {
};

export const NodeContextMenu = ({ nodeId, subGraphID }: Props) => {
  const { deleteElements } = useReactFlow();

  const handleCopy = () => {
  function handleCopy() {
    useNodeStore.setState((state) => ({
      nodes: state.nodes.map((node) => ({
        ...node,
@@ -30,47 +35,47 @@ export const NodeContextMenu = ({

    useCopyPasteStore.getState().copySelectedNodes();
    useCopyPasteStore.getState().pasteNodes();
  };
  }

  const handleDelete = () => {
  function handleDelete() {
    deleteElements({ nodes: [{ id: nodeId }] });
  };
  }

  return (
    <DropdownMenu>
      <DropdownMenuTrigger className="py-2">
        <DotsThreeOutlineVerticalIcon size={16} weight="fill" />
      </DropdownMenuTrigger>
      <DropdownMenuContent
        side="right"
        align="start"
        className="rounded-xlarge"
      >
        <DropdownMenuItem onClick={handleCopy} className="hover:rounded-xlarge">
          <Copy className="mr-2 h-4 w-4" />
          Copy Node
        </DropdownMenuItem>
      <SecondaryDropdownMenuContent side="right" align="start">
        <SecondaryDropdownMenuItem onClick={handleCopy}>
          <CopyIcon size={20} className="mr-2 dark:text-gray-100" />
          <span className="dark:text-gray-100">Copy</span>
        </SecondaryDropdownMenuItem>
        <SecondaryDropdownMenuSeparator />

        {subGraphID && (
          <DropdownMenuItem
            onClick={() => window.open(`/build?flowID=${subGraphID}`)}
            className="hover:rounded-xlarge"
          >
            <ExternalLink className="mr-2 h-4 w-4" />
            Open Agent
          </DropdownMenuItem>
          <>
            <SecondaryDropdownMenuItem
              onClick={() => window.open(`/build?flowID=${subGraphID}`)}
            >
              <ArrowSquareOutIcon
                size={20}
                className="mr-2 dark:text-gray-100"
              />
              <span className="dark:text-gray-100">Open agent</span>
            </SecondaryDropdownMenuItem>
            <SecondaryDropdownMenuSeparator />
          </>
        )}

        <Separator className="my-2" />

        <DropdownMenuItem
          onClick={handleDelete}
          className="text-red-600 hover:rounded-xlarge"
        >
          <Trash2 className="mr-2 h-4 w-4" />
          Delete
        </DropdownMenuItem>
      </DropdownMenuContent>
        <SecondaryDropdownMenuItem variant="destructive" onClick={handleDelete}>
          <TrashIcon
            size={20}
            className="mr-2 text-red-500 dark:text-red-400"
          />
          <span className="dark:text-red-400">Delete</span>
        </SecondaryDropdownMenuItem>
      </SecondaryDropdownMenuContent>
    </DropdownMenu>
  );
};
Some files were not shown because too many files have changed in this diff.