mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-09 06:15:41 -05:00
Compare commits
6 Commits
feat/copit
...
refactor/a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
562cf04ab6 | ||
|
|
90b3b5ba16 | ||
|
|
f4f81bc4fc | ||
|
|
c5abc01f25 | ||
|
|
8b7053c1de | ||
|
|
e00c1202ad |
2
.github/workflows/classic-frontend-ci.yml
vendored
2
.github/workflows/classic-frontend-ci.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
|
||||
- name: Get CI failure details
|
||||
id: failure_details
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const run = await github.rest.actions.getWorkflowRun({
|
||||
|
||||
9
.github/workflows/claude-dependabot.yml
vendored
9
.github/workflows/claude-dependabot.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -91,7 +91,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -124,7 +124,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
@@ -309,7 +309,6 @@ jobs:
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
allowed_bots: "dependabot[bot]"
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
|
||||
8
.github/workflows/claude.yml
vendored
8
.github/workflows/claude.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -107,7 +107,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -140,7 +140,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
|
||||
8
.github/workflows/copilot-setup-steps.yml
vendored
8
.github/workflows/copilot-setup-steps.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -89,7 +89,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -132,7 +132,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
|
||||
2
.github/workflows/docs-block-sync.yml
vendored
2
.github/workflows/docs-block-sync.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
2
.github/workflows/docs-claude-review.yml
vendored
2
.github/workflows/docs-claude-review.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
2
.github/workflows/docs-enhance.yml
vendored
2
.github/workflows/docs-enhance.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -88,7 +88,7 @@ jobs:
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Check comment permissions and deployment status
|
||||
id: check_status
|
||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const commentBody = context.payload.comment.body.trim();
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
|
||||
- name: Post permission denied comment
|
||||
if: steps.check_status.outputs.permission_denied == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
- name: Get PR details for deployment
|
||||
id: pr_details
|
||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const pr = await github.rest.pulls.get({
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
|
||||
- name: Post deploy success comment
|
||||
if: steps.check_status.outputs.should_deploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -126,7 +126,7 @@ jobs:
|
||||
|
||||
- name: Post undeploy success comment
|
||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
- name: Check deployment status on PR close
|
||||
id: check_pr_close
|
||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const comments = await github.rest.issues.listComments({
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
github.event_name == 'pull_request' &&
|
||||
github.event.action == 'closed' &&
|
||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
|
||||
22
.github/workflows/platform-frontend-ci.yml
vendored
22
.github/workflows/platform-frontend-ci.yml
vendored
@@ -42,7 +42,7 @@ jobs:
|
||||
- 'autogpt_platform/frontend/src/components/**'
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -54,7 +54,7 @@ jobs:
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -82,7 +82,7 @@ jobs:
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
@@ -112,7 +112,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
@@ -153,7 +153,7 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -176,7 +176,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -231,7 +231,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
@@ -282,7 +282,7 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -290,7 +290,7 @@ jobs:
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
|
||||
8
.github/workflows/platform-fullstack-ci.yml
vendored
8
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
|
||||
1652
autogpt_platform/autogpt_libs/poetry.lock
generated
1652
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^46.0"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.128.0"
|
||||
google-cloud-logging = "^3.13.0"
|
||||
launchdarkly-server-sdk = "^9.14.1"
|
||||
pydantic = "^2.12.5"
|
||||
pydantic-settings = "^2.12.0"
|
||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.27.2"
|
||||
uvicorn = "^0.40.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pyright = "^1.1.408"
|
||||
pyright = "^1.1.404"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.15.0"
|
||||
ruff = "^0.12.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -27,20 +27,12 @@ class ChatConfig(BaseSettings):
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
# Note: When using Claude Agent SDK, context management is handled automatically
|
||||
# via the SDK's built-in compaction. This is mainly used for the fallback path.
|
||||
max_context_messages: int = Field(
|
||||
default=100,
|
||||
ge=1,
|
||||
le=500,
|
||||
description="Max context messages (SDK handles compaction automatically)",
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
)
|
||||
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
)
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
@@ -101,12 +93,6 @@ class ChatConfig(BaseSettings):
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
description="Use Claude Agent SDK for chat completions",
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
@@ -146,17 +132,6 @@ class ChatConfig(BaseSettings):
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -273,8 +273,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
|
||||
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
@@ -316,9 +317,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.debug(
|
||||
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
|
||||
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
@@ -369,9 +372,10 @@ async def _save_session_to_db(
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
|
||||
f"roles={[m['role'] for m in messages_data]}"
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
@@ -411,7 +415,7 @@ async def get_chat_session(
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.debug(f"Session {session_id} not in cache, checking database")
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
@@ -428,6 +432,7 @@ async def get_chat_session(
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
@@ -598,19 +603,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
# This prevents race conditions where cache invalidation causes
|
||||
# the frontend to see stale DB data while streaming is still in progress.
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
try:
|
||||
cached = await _get_session_from_cache(session_id)
|
||||
if cached:
|
||||
cached.title = title
|
||||
await _cache_session(cached)
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
)
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -17,17 +16,8 @@ from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||
from .sdk import service as sdk_service
|
||||
from .tracking import track_user_message
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -219,10 +209,6 @@ async def get_session(
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
@@ -279,29 +265,9 @@ async def stream_chat_post(
|
||||
containing the task_id for reconnection.
|
||||
|
||||
"""
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
import asyncio
|
||||
|
||||
# Add user message to session BEFORE creating task to avoid race condition
|
||||
# where GET_SESSION sees the task as "running" but the message isn't saved yet
|
||||
if request.message:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(
|
||||
f"[STREAM] Saving user message to session {session_id}, "
|
||||
f"msg_count={len(session.messages)}"
|
||||
)
|
||||
session = await upsert_chat_session(session)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
@@ -317,38 +283,24 @@ async def stream_chat_post(
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
chunk_count = 0
|
||||
try:
|
||||
# Emit a start event with task_id for reconnection
|
||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||
|
||||
# Choose service based on configuration
|
||||
use_sdk = config.use_claude_agent_sdk
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else chat_service.stream_chat_completion
|
||||
)
|
||||
# Pass message=None since we already added it to the session above
|
||||
async for chunk in stream_fn(
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
None, # Message already in session
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass session with message already added
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
chunk_count += 1
|
||||
# Write to Redis (subscribers will receive via XREAD)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
logger.info(
|
||||
f"[BG_TASK] AI generation completed for session {session_id}: {chunk_count} chunks, marking task {task_id} as completed"
|
||||
)
|
||||
# Mark task as completed (also publishes StreamFinish)
|
||||
completed = await stream_registry.mark_task_completed(task_id, "completed")
|
||||
logger.info(f"[BG_TASK] mark_task_completed returned: {completed}")
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in background AI generation for session {session_id}: {e}"
|
||||
@@ -363,7 +315,7 @@ async def stream_chat_post(
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
subscriber_queue = None
|
||||
try:
|
||||
# Subscribe to the task stream (replays + live updates)
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
@@ -371,7 +323,6 @@ async def stream_chat_post(
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
logger.warning(f"Failed to subscribe to task {task_id}")
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
@@ -390,11 +341,11 @@ async def stream_chat_post(
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
pass # Client disconnected - normal behavior
|
||||
pass # Client disconnected - background task continues
|
||||
except Exception as e:
|
||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
@@ -449,21 +400,35 @@ async def stream_chat_get(
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
# Choose service based on configuration
|
||||
use_sdk = config.use_claude_agent_sdk
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else chat_service.stream_chat_completion
|
||||
)
|
||||
async for chunk in stream_fn(
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@@ -585,6 +550,8 @@ async def stream_task(
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
"""Claude Agent SDK integration for CoPilot.
|
||||
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
"""
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
@@ -1,348 +0,0 @@
|
||||
"""Anthropic SDK fallback implementation.
|
||||
|
||||
This module provides the fallback streaming implementation using the Anthropic SDK
|
||||
directly when the Claude Agent SDK is not available.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from ..model import ChatMessage, ChatSession
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from .tool_adapter import get_tool_definitions, get_tool_handlers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def stream_with_anthropic(
|
||||
session: ChatSession,
|
||||
system_prompt: str,
|
||||
text_block_id: str,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream using Anthropic SDK directly with tool calling support.
|
||||
|
||||
This function accumulates messages into the session for persistence.
|
||||
The caller should NOT yield an additional StreamFinish - this function handles it.
|
||||
"""
|
||||
import anthropic
|
||||
|
||||
# Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
yield StreamError(
|
||||
errorText="ANTHROPIC_API_KEY not configured for fallback",
|
||||
code="config_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
tool_definitions = get_tool_definitions()
|
||||
tool_handlers = get_tool_handlers()
|
||||
|
||||
anthropic_tools = [
|
||||
{
|
||||
"name": t["name"],
|
||||
"description": t["description"],
|
||||
"input_schema": t["inputSchema"],
|
||||
}
|
||||
for t in tool_definitions
|
||||
]
|
||||
|
||||
anthropic_messages = _convert_session_to_anthropic(session)
|
||||
|
||||
if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
|
||||
anthropic_messages.append(
|
||||
{"role": "user", "content": "Continue with the task."}
|
||||
)
|
||||
|
||||
has_started_text = False
|
||||
max_iterations = 10
|
||||
accumulated_text = ""
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for _ in range(max_iterations):
|
||||
try:
|
||||
async with client.messages.stream(
|
||||
model="claude-sonnet-4-20250514",
|
||||
max_tokens=4096,
|
||||
system=system_prompt,
|
||||
messages=cast(Any, anthropic_messages),
|
||||
tools=cast(Any, anthropic_tools) if anthropic_tools else [],
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content_block_start":
|
||||
block = event.content_block
|
||||
if hasattr(block, "type"):
|
||||
if block.type == "text" and not has_started_text:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
has_started_text = True
|
||||
elif block.type == "tool_use":
|
||||
yield StreamToolInputStart(
|
||||
toolCallId=block.id, toolName=block.name
|
||||
)
|
||||
|
||||
elif event.type == "content_block_delta":
|
||||
delta = event.delta
|
||||
if hasattr(delta, "type") and delta.type == "text_delta":
|
||||
accumulated_text += delta.text
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.text)
|
||||
|
||||
final_message = await stream.get_final_message()
|
||||
|
||||
if final_message.stop_reason == "tool_use":
|
||||
if has_started_text:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
has_started_text = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
tool_results = []
|
||||
assistant_content: list[dict[str, Any]] = []
|
||||
|
||||
for block in final_message.content:
|
||||
if block.type == "text":
|
||||
assistant_content.append(
|
||||
{"type": "text", "text": block.text}
|
||||
)
|
||||
elif block.type == "tool_use":
|
||||
assistant_content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"input": block.input,
|
||||
}
|
||||
)
|
||||
|
||||
# Track tool call for session persistence
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.name,
|
||||
"arguments": json.dumps(
|
||||
block.input
|
||||
if isinstance(block.input, dict)
|
||||
else {}
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=block.name,
|
||||
input=(
|
||||
block.input if isinstance(block.input, dict) else {}
|
||||
),
|
||||
)
|
||||
|
||||
output, is_error = await _execute_tool(
|
||||
block.name, block.input, tool_handlers
|
||||
)
|
||||
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=block.name,
|
||||
output=output,
|
||||
success=not is_error,
|
||||
)
|
||||
|
||||
# Save tool result to session
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=output,
|
||||
tool_call_id=block.id,
|
||||
)
|
||||
)
|
||||
|
||||
tool_results.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": block.id,
|
||||
"content": output,
|
||||
"is_error": is_error,
|
||||
}
|
||||
)
|
||||
|
||||
# Save assistant message with tool calls to session
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=accumulated_text or None,
|
||||
tool_calls=(
|
||||
accumulated_tool_calls
|
||||
if accumulated_tool_calls
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
# Reset for next iteration
|
||||
accumulated_text = ""
|
||||
accumulated_tool_calls = []
|
||||
|
||||
anthropic_messages.append(
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
)
|
||||
anthropic_messages.append({"role": "user", "content": tool_results})
|
||||
continue
|
||||
|
||||
else:
|
||||
if has_started_text:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
|
||||
# Save final assistant response to session
|
||||
if accumulated_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=accumulated_text)
|
||||
)
|
||||
|
||||
yield StreamUsage(
|
||||
promptTokens=final_message.usage.input_tokens,
|
||||
completionTokens=final_message.usage.output_tokens,
|
||||
totalTokens=final_message.usage.input_tokens
|
||||
+ final_message.usage.output_tokens,
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="anthropic_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
|
||||
yield StreamFinish()
|
||||
|
||||
|
||||
def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
|
||||
"""Convert session messages to Anthropic format.
|
||||
|
||||
Handles merging consecutive same-role messages (Anthropic requires alternating roles).
|
||||
"""
|
||||
messages: list[dict[str, Any]] = []
|
||||
|
||||
for msg in session.messages:
|
||||
if msg.role == "user":
|
||||
new_msg = {"role": "user", "content": msg.content or ""}
|
||||
elif msg.role == "assistant":
|
||||
content: list[dict[str, Any]] = []
|
||||
if msg.content:
|
||||
content.append({"type": "text", "text": msg.content})
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id", str(uuid.uuid4())),
|
||||
"name": func.get("name", ""),
|
||||
"input": args,
|
||||
}
|
||||
)
|
||||
if content:
|
||||
new_msg = {"role": "assistant", "content": content}
|
||||
else:
|
||||
continue # Skip empty assistant messages
|
||||
elif msg.role == "tool":
|
||||
new_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.tool_call_id or "",
|
||||
"content": msg.content or "",
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
continue
|
||||
|
||||
messages.append(new_msg)
|
||||
|
||||
# Merge consecutive same-role messages (Anthropic requires alternating roles)
|
||||
return _merge_consecutive_roles(messages)
|
||||
|
||||
|
||||
def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Merge consecutive messages with the same role.
|
||||
|
||||
Anthropic API requires alternating user/assistant roles.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
merged: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
if merged and merged[-1]["role"] == msg["role"]:
|
||||
# Merge with previous message
|
||||
prev_content = merged[-1]["content"]
|
||||
new_content = msg["content"]
|
||||
|
||||
# Normalize both to list-of-blocks form
|
||||
if isinstance(prev_content, str):
|
||||
prev_content = [{"type": "text", "text": prev_content}]
|
||||
if isinstance(new_content, str):
|
||||
new_content = [{"type": "text", "text": new_content}]
|
||||
|
||||
# Ensure both are lists
|
||||
if not isinstance(prev_content, list):
|
||||
prev_content = [prev_content]
|
||||
if not isinstance(new_content, list):
|
||||
new_content = [new_content]
|
||||
|
||||
merged[-1]["content"] = prev_content + new_content
|
||||
else:
|
||||
merged.append(msg)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
async def _execute_tool(
|
||||
tool_name: str, tool_input: Any, handlers: dict[str, Any]
|
||||
) -> tuple[str, bool]:
|
||||
"""Execute a tool and return (output, is_error)."""
|
||||
handler = handlers.get(tool_name)
|
||||
if not handler:
|
||||
return f"Unknown tool: {tool_name}", True
|
||||
|
||||
try:
|
||||
result = await handler(tool_input)
|
||||
# Safely extract output - handle empty or missing content
|
||||
content = result.get("content") or []
|
||||
if content and isinstance(content, list) and len(content) > 0:
|
||||
first_item = content[0]
|
||||
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
|
||||
else:
|
||||
output = ""
|
||||
is_error = result.get("isError", False)
|
||||
return output, is_error
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}", True
|
||||
@@ -1,320 +0,0 @@
|
||||
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This module provides the adapter layer that converts streaming messages from
|
||||
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
|
||||
the frontend expects.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from backend.api.features.chat.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SDKResponseAdapter:
|
||||
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This class maintains state during a streaming session to properly track
|
||||
text blocks, tool calls, and message lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self, message_id: str | None = None):
|
||||
"""Initialize the adapter.
|
||||
|
||||
Args:
|
||||
message_id: Optional message ID. If not provided, one will be generated.
|
||||
"""
|
||||
self.message_id = message_id or str(uuid.uuid4())
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_started_text = False
|
||||
self.has_ended_text = False
|
||||
self.current_tool_calls: dict[str, dict[str, Any]] = {}
|
||||
self.task_id: str | None = None
|
||||
|
||||
def set_task_id(self, task_id: str) -> None:
|
||||
"""Set the task ID for reconnection support."""
|
||||
self.task_id = task_id
|
||||
|
||||
def convert_message(self, sdk_message: Any) -> list[StreamBaseResponse]:
|
||||
"""Convert a single SDK message to Vercel AI SDK format.
|
||||
|
||||
Args:
|
||||
sdk_message: A message from the Claude Agent SDK.
|
||||
|
||||
Returns:
|
||||
List of StreamBaseResponse objects (may be empty or multiple).
|
||||
"""
|
||||
responses: list[StreamBaseResponse] = []
|
||||
|
||||
# Handle different SDK message types - use class name since SDK uses dataclasses
|
||||
class_name = type(sdk_message).__name__
|
||||
msg_subtype = getattr(sdk_message, "subtype", None)
|
||||
|
||||
if class_name == "SystemMessage":
|
||||
if msg_subtype == "init":
|
||||
# Session initialization - emit start
|
||||
responses.append(
|
||||
StreamStart(
|
||||
messageId=self.message_id,
|
||||
taskId=self.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif class_name == "AssistantMessage":
|
||||
# Assistant message with content blocks
|
||||
content = getattr(sdk_message, "content", [])
|
||||
for block in content:
|
||||
# Check block type by class name (SDK uses dataclasses) or dict type
|
||||
block_class = type(block).__name__
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
|
||||
if block_class == "TextBlock" or block_type == "text":
|
||||
# Text content
|
||||
text = getattr(block, "text", None) or (
|
||||
block.get("text") if isinstance(block, dict) else ""
|
||||
)
|
||||
|
||||
if text:
|
||||
# Start text block if needed (or restart after tool calls)
|
||||
if not self.has_started_text or self.has_ended_text:
|
||||
# Generate new text block ID for text after tools
|
||||
if self.has_ended_text:
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_ended_text = False
|
||||
responses.append(StreamTextStart(id=self.text_block_id))
|
||||
self.has_started_text = True
|
||||
|
||||
# Emit text delta
|
||||
responses.append(
|
||||
StreamTextDelta(
|
||||
id=self.text_block_id,
|
||||
delta=text,
|
||||
)
|
||||
)
|
||||
|
||||
elif block_class == "ToolUseBlock" or block_type == "tool_use":
|
||||
# Tool call
|
||||
tool_id_raw = getattr(block, "id", None) or (
|
||||
block.get("id") if isinstance(block, dict) else None
|
||||
)
|
||||
tool_id: str = (
|
||||
str(tool_id_raw) if tool_id_raw else str(uuid.uuid4())
|
||||
)
|
||||
|
||||
tool_name_raw = getattr(block, "name", None) or (
|
||||
block.get("name") if isinstance(block, dict) else None
|
||||
)
|
||||
tool_name: str = str(tool_name_raw) if tool_name_raw else "unknown"
|
||||
|
||||
tool_input = getattr(block, "input", None) or (
|
||||
block.get("input") if isinstance(block, dict) else {}
|
||||
)
|
||||
|
||||
# End text block if we were streaming text
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
# Emit tool input start
|
||||
responses.append(
|
||||
StreamToolInputStart(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Emit tool input available with full input
|
||||
responses.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
input=tool_input if isinstance(tool_input, dict) else {},
|
||||
)
|
||||
)
|
||||
|
||||
# Track the tool call
|
||||
self.current_tool_calls[tool_id] = {
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
}
|
||||
|
||||
elif class_name in ("ToolResultMessage", "UserMessage"):
|
||||
# Tool result - check for tool_result content
|
||||
content = getattr(sdk_message, "content", [])
|
||||
|
||||
for block in content:
|
||||
block_class = type(block).__name__
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
|
||||
if block_class == "ToolResultBlock" or block_type == "tool_result":
|
||||
tool_use_id = getattr(block, "tool_use_id", None) or (
|
||||
block.get("tool_use_id") if isinstance(block, dict) else None
|
||||
)
|
||||
result_content = getattr(block, "content", None) or (
|
||||
block.get("content") if isinstance(block, dict) else ""
|
||||
)
|
||||
is_error = getattr(block, "is_error", False) or (
|
||||
block.get("is_error", False)
|
||||
if isinstance(block, dict)
|
||||
else False
|
||||
)
|
||||
|
||||
if tool_use_id:
|
||||
tool_info = self.current_tool_calls.get(tool_use_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
# Format the output
|
||||
if isinstance(result_content, list):
|
||||
# Extract text from content blocks
|
||||
output_text = ""
|
||||
for item in result_content:
|
||||
if (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "text"
|
||||
):
|
||||
output_text += item.get("text", "")
|
||||
elif hasattr(item, "text"):
|
||||
output_text += getattr(item, "text", "")
|
||||
if output_text:
|
||||
output = output_text
|
||||
else:
|
||||
try:
|
||||
output = json.dumps(result_content)
|
||||
except (TypeError, ValueError):
|
||||
output = str(result_content)
|
||||
elif isinstance(result_content, str):
|
||||
output = result_content
|
||||
else:
|
||||
try:
|
||||
output = json.dumps(result_content)
|
||||
except (TypeError, ValueError):
|
||||
output = str(result_content)
|
||||
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_use_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=not is_error,
|
||||
)
|
||||
)
|
||||
|
||||
elif class_name == "ResultMessage":
|
||||
# Final result
|
||||
if msg_subtype == "success":
|
||||
# End text block if still open
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
# Emit finish
|
||||
responses.append(StreamFinish())
|
||||
|
||||
elif msg_subtype in ("error", "error_during_execution"):
|
||||
error_msg = getattr(sdk_message, "error", "Unknown error")
|
||||
responses.append(
|
||||
StreamError(
|
||||
errorText=str(error_msg),
|
||||
code="sdk_error",
|
||||
)
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
elif class_name == "ErrorMessage":
|
||||
# Error message
|
||||
error_msg = getattr(sdk_message, "message", None) or getattr(
|
||||
sdk_message, "error", "Unknown error"
|
||||
)
|
||||
responses.append(
|
||||
StreamError(
|
||||
errorText=str(error_msg),
|
||||
code="sdk_error",
|
||||
)
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled SDK message type: {class_name}")
|
||||
|
||||
return responses
|
||||
|
||||
def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat:
|
||||
"""Create a heartbeat response."""
|
||||
return StreamHeartbeat(toolCallId=tool_call_id)
|
||||
|
||||
def create_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
) -> StreamUsage:
|
||||
"""Create a usage statistics response."""
|
||||
return StreamUsage(
|
||||
promptTokens=prompt_tokens,
|
||||
completionTokens=completion_tokens,
|
||||
totalTokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
|
||||
async def adapt_sdk_stream(
|
||||
sdk_stream: AsyncGenerator[Any, None],
|
||||
message_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Adapt a Claude Agent SDK stream to Vercel AI SDK format.
|
||||
|
||||
Args:
|
||||
sdk_stream: The async generator from the Claude Agent SDK.
|
||||
message_id: Optional message ID for the response.
|
||||
task_id: Optional task ID for reconnection support.
|
||||
|
||||
Yields:
|
||||
StreamBaseResponse objects in Vercel AI SDK format.
|
||||
"""
|
||||
adapter = SDKResponseAdapter(message_id=message_id)
|
||||
if task_id:
|
||||
adapter.set_task_id(task_id)
|
||||
|
||||
# Emit start immediately
|
||||
yield StreamStart(messageId=adapter.message_id, taskId=task_id)
|
||||
|
||||
finished = False
|
||||
try:
|
||||
async for sdk_message in sdk_stream:
|
||||
responses = adapter.convert_message(sdk_message)
|
||||
for response in responses:
|
||||
# Skip duplicate start messages
|
||||
if isinstance(response, StreamStart):
|
||||
continue
|
||||
if isinstance(response, StreamFinish):
|
||||
finished = True
|
||||
yield response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in SDK stream: {e}", exc_info=True)
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Ensure terminal StreamFinish if SDK stream ended without one
|
||||
if not finished:
|
||||
yield StreamFinish()
|
||||
@@ -1,281 +0,0 @@
|
||||
"""Security hooks for Claude Agent SDK integration.
|
||||
|
||||
This module provides security hooks that validate tool calls before execution,
|
||||
ensuring multi-user isolation and preventing unauthorized operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that are blocked entirely (CLI/system access)
|
||||
BLOCKED_TOOLS = {
|
||||
"Bash",
|
||||
"bash",
|
||||
"shell",
|
||||
"exec",
|
||||
"terminal",
|
||||
"command",
|
||||
"Read", # Block raw file read - use workspace tools instead
|
||||
"Write", # Block raw file write - use workspace tools instead
|
||||
"Edit", # Block raw file edit - use workspace tools instead
|
||||
"Glob", # Block raw file glob - use workspace tools instead
|
||||
"Grep", # Block raw file grep - use workspace tools instead
|
||||
}
|
||||
|
||||
# Dangerous patterns in tool inputs
|
||||
DANGEROUS_PATTERNS = [
|
||||
r"sudo",
|
||||
r"rm\s+-rf",
|
||||
r"dd\s+if=",
|
||||
r"/etc/passwd",
|
||||
r"/etc/shadow",
|
||||
r"chmod\s+777",
|
||||
r"curl\s+.*\|.*sh",
|
||||
r"wget\s+.*\|.*sh",
|
||||
r"eval\s*\(",
|
||||
r"exec\s*\(",
|
||||
r"__import__",
|
||||
r"os\.system",
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
|
||||
def _validate_tool_access(tool_name: str, tool_input: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate that a tool call is allowed.
|
||||
|
||||
Returns:
|
||||
Empty dict to allow, or dict with hookSpecificOutput to deny
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": (
|
||||
f"Tool '{tool_name}' is not available. "
|
||||
"Use the CoPilot-specific tools instead."
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
# Check for dangerous patterns in tool input
|
||||
input_str = str(tool_input)
|
||||
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, input_str, re.IGNORECASE):
|
||||
logger.warning(
|
||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Input contains blocked pattern",
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _validate_user_isolation(
|
||||
tool_name: str, tool_input: dict[str, Any], user_id: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that tool calls respect user isolation."""
|
||||
# For workspace file tools, ensure path doesn't escape
|
||||
if "workspace" in tool_name.lower():
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path:
|
||||
# Check for path traversal
|
||||
if ".." in path or path.startswith("/"):
|
||||
logger.warning(
|
||||
f"Blocked path traversal attempt: {path} by user {user_id}"
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Path traversal not allowed",
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def create_security_hooks(user_id: str | None) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
Includes security validation and observability hooks:
|
||||
- PreToolUse: Security validation before tool execution
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import HookMatcher
|
||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||
|
||||
async def pre_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Combined pre-tool-use validation hook."""
|
||||
_ = context # unused but required by signature
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||
|
||||
# Strip MCP prefix for consistent validation
|
||||
clean_name = tool_name.removeprefix("mcp__copilot__")
|
||||
|
||||
# Validate basic tool access
|
||||
result = _validate_tool_access(clean_name, tool_input)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
# Validate user isolation
|
||||
result = _validate_user_isolation(clean_name, tool_input, user_id)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log successful tool executions for observability."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_failure_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log failed tool executions for debugging."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
error = input_data.get("error", "Unknown error")
|
||||
logger.warning(
|
||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def pre_compact_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log when SDK triggers context compaction.
|
||||
|
||||
The SDK automatically compacts conversation history when it grows too large.
|
||||
This hook provides visibility into when compaction happens.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
logger.info(
|
||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
return {
|
||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||
"PostToolUseFailure": [
|
||||
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
||||
],
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
}
|
||||
except ImportError:
|
||||
# Fallback for when SDK isn't available - return empty hooks
|
||||
return {}
|
||||
|
||||
|
||||
def create_strict_security_hooks(
|
||||
user_id: str | None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create strict security hooks that only allow specific tools.
|
||||
|
||||
Args:
|
||||
user_id: Current user ID
|
||||
allowed_tools: List of allowed tool names (defaults to CoPilot tools)
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import HookMatcher
|
||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||
|
||||
from .tool_adapter import RAW_TOOL_NAMES
|
||||
|
||||
tools_list = allowed_tools if allowed_tools is not None else RAW_TOOL_NAMES
|
||||
allowed_set = set(tools_list)
|
||||
|
||||
async def strict_pre_tool_use(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Strict validation that only allows whitelisted tools."""
|
||||
_ = context # unused but required by signature
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||
|
||||
# Remove MCP prefix if present
|
||||
clean_name = tool_name.removeprefix("mcp__copilot__")
|
||||
|
||||
if clean_name not in allowed_set:
|
||||
logger.warning(f"Blocked non-whitelisted tool: {tool_name}")
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
{
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": (
|
||||
f"Tool '{tool_name}' is not in the allowed list"
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Run standard validations using clean_name for consistent checks
|
||||
result = _validate_tool_access(clean_name, tool_input)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
result = _validate_user_isolation(clean_name, tool_input, user_id)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
logger.debug(
|
||||
f"[SDK Audit] Tool call: tool={tool_name}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
return {
|
||||
"PreToolUse": [
|
||||
HookMatcher(matcher="*", hooks=[strict_pre_tool_use]),
|
||||
],
|
||||
}
|
||||
except ImportError:
|
||||
return {}
|
||||
@@ -1,475 +0,0 @@
|
||||
"""Claude Agent SDK service layer for CoPilot chat completions."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
|
||||
from backend.data.understanding import (
|
||||
format_understanding_for_prompt,
|
||||
get_business_understanding,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from ..config import ChatConfig
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from ..tracking import track_user_message
|
||||
from .anthropic_fallback import stream_with_anthropic
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .security_hooks import create_security_hooks
|
||||
from .tool_adapter import (
|
||||
COPILOT_TOOL_NAMES,
|
||||
create_copilot_mcp_server,
|
||||
set_execution_context,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
|
||||
|
||||
Here is everything you know about the current user from previous interactions:
|
||||
|
||||
<users_information>
|
||||
{users_information}
|
||||
</users_information>
|
||||
|
||||
## YOUR CORE MANDATE
|
||||
|
||||
You are action-oriented. Your success is measured by:
|
||||
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
|
||||
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
|
||||
- **Time Saved**: Focus on tangible efficiency gains
|
||||
- **Quality Output**: Deliver results that meet or exceed expectations
|
||||
|
||||
## YOUR WORKFLOW
|
||||
|
||||
Adapt flexibly to the conversation context. Not every interaction requires all stages:
|
||||
|
||||
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
|
||||
|
||||
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
|
||||
|
||||
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
|
||||
|
||||
4. **Discover or Create Agents**:
|
||||
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
|
||||
- Search the marketplace with `find_agent` for pre-built automations
|
||||
- Find reusable components with `find_block`
|
||||
- Create custom solutions with `create_agent` if nothing suitable exists
|
||||
- Modify existing library agents with `edit_agent`
|
||||
|
||||
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
|
||||
|
||||
6. **Show Results**: Display outputs using `agent_output`.
|
||||
|
||||
## BEHAVIORAL GUIDELINES
|
||||
|
||||
**Be Concise:**
|
||||
- Target 2-5 short lines maximum
|
||||
- Make every word count—no repetition or filler
|
||||
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
|
||||
- Avoid jargon (blocks, slugs, cron) unless the user asks
|
||||
|
||||
**Be Proactive:**
|
||||
- Suggest next steps before being asked
|
||||
- Anticipate needs based on conversation context and user information
|
||||
- Look for opportunities to expand scope when relevant
|
||||
- Reveal capabilities through action, not explanation
|
||||
|
||||
**Use Tools Effectively:**
|
||||
- Select the right tool for each task
|
||||
- **Always check `find_library_agent` before searching the marketplace**
|
||||
- Use `add_understanding` to capture valuable business context
|
||||
- When tool calls fail, try alternative approaches
|
||||
|
||||
## CRITICAL REMINDER
|
||||
|
||||
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
|
||||
|
||||
|
||||
async def _build_system_prompt(
|
||||
user_id: str | None, has_conversation_history: bool = False
|
||||
) -> tuple[str, Any]:
|
||||
"""Build the system prompt with user's business understanding context.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to fetch understanding for.
|
||||
has_conversation_history: Whether there's existing conversation history.
|
||||
If True, we don't tell the model to greet/introduce (since they're
|
||||
already in a conversation).
|
||||
"""
|
||||
understanding = None
|
||||
if user_id:
|
||||
try:
|
||||
understanding = await get_business_understanding(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch business understanding: {e}")
|
||||
|
||||
if understanding:
|
||||
context = format_understanding_for_prompt(understanding)
|
||||
elif has_conversation_history:
|
||||
# Don't tell model to greet if there's conversation history
|
||||
context = "No prior understanding saved yet. Continue the existing conversation naturally."
|
||||
else:
|
||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||
|
||||
return DEFAULT_SYSTEM_PROMPT.replace("{users_information}", context), understanding
|
||||
|
||||
|
||||
def _format_conversation_history(session: ChatSession) -> str:
|
||||
"""Format conversation history as a prompt context.
|
||||
|
||||
The SDK handles context compaction automatically, but we apply
|
||||
max_context_messages as a safety guard to limit initial prompt size.
|
||||
"""
|
||||
if not session.messages:
|
||||
return ""
|
||||
|
||||
# Get all messages except the last user message (which will be the prompt)
|
||||
messages = session.messages[:-1] if session.messages else []
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Apply max_context_messages limit as a safety guard
|
||||
# (SDK handles compaction, but this prevents excessively large initial prompts)
|
||||
max_messages = config.max_context_messages
|
||||
if len(messages) > max_messages:
|
||||
messages = messages[-max_messages:]
|
||||
|
||||
history_parts = ["<conversation_history>"]
|
||||
|
||||
for msg in messages:
|
||||
if msg.role == "user":
|
||||
history_parts.append(f"User: {msg.content or ''}")
|
||||
elif msg.role == "assistant":
|
||||
# Pass full content - SDK handles compaction automatically
|
||||
history_parts.append(f"Assistant: {msg.content or ''}")
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
func = tc.get("function", {})
|
||||
history_parts.append(
|
||||
f" [Called tool: {func.get('name', 'unknown')}]"
|
||||
)
|
||||
elif msg.role == "tool":
|
||||
# Truncate large tool results to avoid blowing context window
|
||||
tool_content = msg.content or ""
|
||||
if len(tool_content) > 500:
|
||||
tool_content = tool_content[:500] + "... (truncated)"
|
||||
history_parts.append(f" [Tool result: {tool_content}]")
|
||||
|
||||
history_parts.append("</conversation_history>")
|
||||
history_parts.append("")
|
||||
history_parts.append(
|
||||
"Continue this conversation. Respond to the user's latest message:"
|
||||
)
|
||||
history_parts.append("")
|
||||
|
||||
return "\n".join(history_parts)
|
||||
|
||||
|
||||
async def _generate_session_title(
|
||||
message: str,
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""Generate a concise title for a chat session."""
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
try:
|
||||
# Build extra_body for OpenRouter tracing
|
||||
extra_body: dict[str, Any] = {
|
||||
"posthogProperties": {"environment": settings.config.app_env.value},
|
||||
}
|
||||
if user_id:
|
||||
extra_body["user"] = user_id[:128]
|
||||
extra_body["posthogDistinctId"] = user_id
|
||||
if session_id:
|
||||
extra_body["session_id"] = session_id[:128]
|
||||
|
||||
client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
response = await client.chat.completions.create(
|
||||
model=config.title_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Generate a very short title (3-6 words) for a chat conversation based on the user's first message. Return ONLY the title, no quotes or punctuation.",
|
||||
},
|
||||
{"role": "user", "content": message[:500]},
|
||||
],
|
||||
max_tokens=20,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
title = response.choices[0].message.content
|
||||
if title:
|
||||
title = title.strip().strip("\"'")
|
||||
return title[:47] + "..." if len(title) > 50 else title
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate session title: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None, # noqa: ARG001
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0, # noqa: ARG001
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
|
||||
Drop-in replacement for stream_chat_completion with improved reliability.
|
||||
"""
|
||||
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
if message:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if is_user_message else "assistant", content=message
|
||||
)
|
||||
)
|
||||
if is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions (first user message)
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
if first_message:
|
||||
task = asyncio.create_task(
|
||||
_update_title_async(session_id, first_message, user_id)
|
||||
)
|
||||
# Store reference to prevent garbage collection
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Check if there's conversation history (more than just the current message)
|
||||
has_history = len(session.messages) > 1
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=has_history
|
||||
)
|
||||
set_execution_context(user_id, session, None)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||
|
||||
# Track whether the stream completed normally via ResultMessage
|
||||
stream_completed = False
|
||||
|
||||
try:
|
||||
try:
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
# Create MCP server with CoPilot tools
|
||||
mcp_server = create_copilot_mcp_server()
|
||||
|
||||
options = ClaudeAgentOptions(
|
||||
system_prompt=system_prompt,
|
||||
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
|
||||
allowed_tools=COPILOT_TOOL_NAMES,
|
||||
hooks=create_security_hooks(user_id), # type: ignore[arg-type]
|
||||
continue_conversation=True, # Enable conversation continuation
|
||||
)
|
||||
|
||||
adapter = SDKResponseAdapter(message_id=message_id)
|
||||
adapter.set_task_id(task_id)
|
||||
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
# Build prompt with conversation history for context
|
||||
# The SDK doesn't support replaying full conversation history,
|
||||
# so we include it as context in the prompt
|
||||
current_message = message or ""
|
||||
if not current_message and session.messages:
|
||||
last_user = [m for m in session.messages if m.role == "user"]
|
||||
if last_user:
|
||||
current_message = last_user[-1].content or ""
|
||||
|
||||
# Include conversation history if there are prior messages
|
||||
if len(session.messages) > 1:
|
||||
history_context = _format_conversation_history(session)
|
||||
prompt = f"{history_context}{current_message}"
|
||||
else:
|
||||
prompt = current_message
|
||||
|
||||
# Guard against empty prompts
|
||||
if not prompt.strip():
|
||||
yield StreamError(
|
||||
errorText="Message cannot be empty.",
|
||||
code="empty_prompt",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
await client.query(prompt, session_id=session_id)
|
||||
|
||||
# Track assistant response to save to session
|
||||
# We may need multiple assistant messages if text comes after tool results
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False # Track if we've received tool results
|
||||
|
||||
# Receive messages from the SDK
|
||||
async for sdk_msg in client.receive_messages():
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
continue
|
||||
yield response
|
||||
|
||||
# Accumulate text deltas into assistant response
|
||||
if isinstance(response, StreamTextDelta):
|
||||
delta = response.delta or ""
|
||||
# After tool results, create new assistant message for post-tool text
|
||||
if has_tool_results and has_appended_assistant:
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant", content=delta
|
||||
)
|
||||
accumulated_tool_calls = [] # Reset for new message
|
||||
session.messages.append(assistant_response)
|
||||
has_tool_results = False
|
||||
else:
|
||||
assistant_response.content = (
|
||||
assistant_response.content or ""
|
||||
) + delta
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
# Track tool calls on the assistant message
|
||||
elif isinstance(response, StreamToolInputAvailable):
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": response.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": response.toolName,
|
||||
"arguments": json.dumps(response.input or {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
# Update assistant message with tool calls
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
# Append assistant message if not already (tool-only response)
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
|
||||
# Break out of the message loop if we received finish signal
|
||||
if stream_completed:
|
||||
break
|
||||
|
||||
# Ensure assistant response is saved even if no text deltas
|
||||
# (e.g., only tool calls were made)
|
||||
if (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
|
||||
)
|
||||
async for response in stream_with_anthropic(
|
||||
session, system_prompt, text_block_id
|
||||
):
|
||||
if isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
yield response
|
||||
|
||||
# Save the session with accumulated messages
|
||||
await upsert_chat_session(session)
|
||||
logger.debug(
|
||||
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
||||
)
|
||||
# Yield StreamFinish to signal completion to the caller (routes.py)
|
||||
# Only if one hasn't already been yielded by the stream
|
||||
if not stream_completed:
|
||||
yield StreamFinish()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
||||
# Save session even on error to preserve any partial response
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as save_err:
|
||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||
# Sanitize error message to avoid exposing internal details
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="sdk_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None = None
|
||||
) -> None:
|
||||
"""Background task to update session title."""
|
||||
try:
|
||||
title = await _generate_session_title(
|
||||
message, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||
@@ -1,217 +0,0 @@
|
||||
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
||||
|
||||
This module provides the adapter layer that converts existing BaseTool implementations
|
||||
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import TOOL_REGISTRY
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variables to pass user/session info to tool execution
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
_current_tool_call_id: ContextVar[str | None] = ContextVar(
|
||||
"current_tool_call_id", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
This must be called before streaming begins to ensure tools have access
|
||||
to user_id and session information.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_tool_call_id.set(tool_call_id)
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
_current_user_id.get(),
|
||||
_current_session.get(),
|
||||
_current_tool_call_id.get(),
|
||||
)
|
||||
|
||||
|
||||
def create_tool_handler(base_tool: BaseTool):
|
||||
"""Create an async handler function for a BaseTool.
|
||||
|
||||
This wraps the existing BaseTool._execute method to be compatible
|
||||
with the Claude Agent SDK MCP tool format.
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||
user_id, session, tool_call_id = get_execution_context()
|
||||
|
||||
if session is None:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
{
|
||||
"error": "No session context available",
|
||||
"type": "error",
|
||||
}
|
||||
),
|
||||
}
|
||||
],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
try:
|
||||
# Call the existing tool's execute method
|
||||
# Generate unique tool_call_id per invocation for proper correlation
|
||||
effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=effective_id,
|
||||
**args,
|
||||
)
|
||||
|
||||
# The result is a StreamToolOutputAvailable, extract the output
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else json.dumps(result.output)
|
||||
),
|
||||
}
|
||||
],
|
||||
"isError": not result.success,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
{
|
||||
"error": str(e),
|
||||
"type": "error",
|
||||
"message": f"Failed to execute {base_tool.name}",
|
||||
}
|
||||
),
|
||||
}
|
||||
],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
return tool_handler
|
||||
|
||||
|
||||
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
"""Build a JSON Schema input schema for a tool."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": base_tool.parameters.get("properties", {}),
|
||||
"required": base_tool.parameters.get("required", []),
|
||||
}
|
||||
|
||||
|
||||
def get_tool_definitions() -> list[dict[str, Any]]:
|
||||
"""Get all tool definitions in MCP format.
|
||||
|
||||
Returns a list of tool definitions that can be used with
|
||||
create_sdk_mcp_server or as raw tool definitions.
|
||||
"""
|
||||
tool_definitions = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
tool_def = {
|
||||
"name": tool_name,
|
||||
"description": base_tool.description,
|
||||
"inputSchema": _build_input_schema(base_tool),
|
||||
}
|
||||
tool_definitions.append(tool_def)
|
||||
|
||||
return tool_definitions
|
||||
|
||||
|
||||
def get_tool_handlers() -> dict[str, Any]:
|
||||
"""Get all tool handlers mapped by name.
|
||||
|
||||
Returns a dictionary mapping tool names to their handler functions.
|
||||
"""
|
||||
handlers = {}
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handlers[tool_name] = create_tool_handler(base_tool)
|
||||
|
||||
return handlers
|
||||
|
||||
|
||||
# Create the MCP server configuration
|
||||
def create_copilot_mcp_server():
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||
|
||||
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||
package being available. This function returns the configuration that
|
||||
can be used with the SDK.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
# Create decorated tool functions
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
# Get the handler
|
||||
handler = create_tool_handler(base_tool)
|
||||
|
||||
# Create the decorated tool
|
||||
# The @tool decorator expects (name, description, schema)
|
||||
# Pass full JSON schema with type, properties, and required
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(handler)
|
||||
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# Create the MCP server
|
||||
server = create_sdk_mcp_server(
|
||||
name="copilot",
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
return server
|
||||
|
||||
except ImportError:
|
||||
# Let ImportError propagate so service.py handles the fallback
|
||||
raise
|
||||
|
||||
|
||||
# List of tool names for allowed_tools configuration
|
||||
COPILOT_TOOL_NAMES = [f"mcp__copilot__{name}" for name in TOOL_REGISTRY.keys()]
|
||||
|
||||
# Also export the raw tool names for flexibility
|
||||
RAW_TOOL_NAMES = list(TOOL_REGISTRY.keys())
|
||||
@@ -555,10 +555,6 @@ async def get_active_task_for_session(
|
||||
if task_user_id and user_id != task_user_id:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
||||
)
|
||||
|
||||
# Get the last message ID from Redis Stream
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
last_id = "0-0"
|
||||
|
||||
@@ -13,32 +13,10 @@ from backend.api.features.chat.tools.models import (
|
||||
NoResultsResponse,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.data.block import BlockType, get_block
|
||||
from backend.data.block import get_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TARGET_RESULTS = 10
|
||||
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
|
||||
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
|
||||
_OVERFETCH_PAGE_SIZE = 40
|
||||
|
||||
# Block types that only work within graphs and cannot run standalone in CoPilot.
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
|
||||
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
|
||||
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
|
||||
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
|
||||
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||
}
|
||||
|
||||
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||
COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
}
|
||||
|
||||
|
||||
class FindBlockTool(BaseTool):
|
||||
"""Tool for searching available blocks."""
|
||||
@@ -110,7 +88,7 @@ class FindBlockTool(BaseTool):
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=_OVERFETCH_PAGE_SIZE,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
if not results:
|
||||
@@ -130,90 +108,60 @@ class FindBlockTool(BaseTool):
|
||||
block = get_block(block_id)
|
||||
|
||||
# Skip disabled blocks
|
||||
if not block or block.disabled:
|
||||
continue
|
||||
if block and not block.disabled:
|
||||
# Get input/output schemas
|
||||
input_schema = {}
|
||||
output_schema = {}
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Skip blocks excluded from CoPilot (graph-only blocks)
|
||||
if (
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
continue
|
||||
# Get categories from block instance
|
||||
categories = []
|
||||
if hasattr(block, "categories") and block.categories:
|
||||
categories = [cat.value for cat in block.categories]
|
||||
|
||||
# Get input/output schemas
|
||||
input_schema = {}
|
||||
output_schema = {}
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to generate input schema for block %s: %s",
|
||||
block_id,
|
||||
e,
|
||||
)
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to generate output schema for block %s: %s",
|
||||
block_id,
|
||||
e,
|
||||
)
|
||||
|
||||
# Get categories from block instance
|
||||
categories = []
|
||||
if hasattr(block, "categories") and block.categories:
|
||||
categories = [cat.value for cat in block.categories]
|
||||
|
||||
# Extract required inputs for easier use
|
||||
required_inputs: list[BlockInputFieldInfo] = []
|
||||
if input_schema:
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = set(input_schema.get("required", []))
|
||||
# Get credential field names to exclude from required inputs
|
||||
credentials_fields = set(
|
||||
block.input_schema.get_credentials_fields().keys()
|
||||
)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields - they're handled separately
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
required_inputs.append(
|
||||
BlockInputFieldInfo(
|
||||
name=field_name,
|
||||
type=field_schema.get("type", "string"),
|
||||
description=field_schema.get("description", ""),
|
||||
required=field_name in required_fields,
|
||||
default=field_schema.get("default"),
|
||||
)
|
||||
# Extract required inputs for easier use
|
||||
required_inputs: list[BlockInputFieldInfo] = []
|
||||
if input_schema:
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = set(input_schema.get("required", []))
|
||||
# Get credential field names to exclude from required inputs
|
||||
credentials_fields = set(
|
||||
block.input_schema.get_credentials_fields().keys()
|
||||
)
|
||||
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=categories,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
required_inputs=required_inputs,
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields - they're handled separately
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
required_inputs.append(
|
||||
BlockInputFieldInfo(
|
||||
name=field_name,
|
||||
type=field_schema.get("type", "string"),
|
||||
description=field_schema.get("description", ""),
|
||||
required=field_name in required_fields,
|
||||
default=field_schema.get("default"),
|
||||
)
|
||||
)
|
||||
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=categories,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
required_inputs=required_inputs,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if len(blocks) >= _TARGET_RESULTS:
|
||||
break
|
||||
|
||||
if blocks and len(blocks) < _TARGET_RESULTS:
|
||||
logger.debug(
|
||||
"find_block returned %d/%d results for query '%s' "
|
||||
"(filtered %d excluded/disabled blocks)",
|
||||
len(blocks),
|
||||
_TARGET_RESULTS,
|
||||
query,
|
||||
len(results) - len(blocks),
|
||||
)
|
||||
|
||||
if not blocks:
|
||||
return NoResultsResponse(
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
"""Tests for block filtering in FindBlockTool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
FindBlockTool,
|
||||
)
|
||||
from backend.api.features.chat.tools.models import BlockListResponse
|
||||
from backend.data.block import BlockType
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-find-block"
|
||||
|
||||
|
||||
def make_mock_block(
|
||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||
):
|
||||
"""Create a mock block for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.description = f"{name} description"
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||
mock.input_schema.get_credentials_fields.return_value = {}
|
||||
mock.output_schema = MagicMock()
|
||||
mock.output_schema.jsonschema.return_value = {}
|
||||
mock.categories = []
|
||||
return mock
|
||||
|
||||
|
||||
class TestFindBlockFiltering:
|
||||
"""Tests for block filtering in FindBlockTool."""
|
||||
|
||||
def test_excluded_block_types_contains_expected_types(self):
|
||||
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
||||
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
|
||||
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_type_filtered_from_results(self):
|
||||
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
||||
search_results = [
|
||||
{"content_id": "input-block-id", "score": 0.9},
|
||||
{"content_id": "standard-block-id", "score": 0.8},
|
||||
]
|
||||
|
||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||
standard_block = make_mock_block(
|
||||
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
||||
)
|
||||
|
||||
def mock_get_block(block_id):
|
||||
return {
|
||||
"input-block-id": input_block,
|
||||
"standard-block-id": standard_block,
|
||||
}.get(block_id)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query="test"
|
||||
)
|
||||
|
||||
# Should only return the standard block, not the INPUT block
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert len(response.blocks) == 1
|
||||
assert response.blocks[0].id == "standard-block-id"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_id_filtered_from_results(self):
|
||||
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
search_results = [
|
||||
{"content_id": smart_decision_id, "score": 0.9},
|
||||
{"content_id": "normal-block-id", "score": 0.8},
|
||||
]
|
||||
|
||||
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||
smart_block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
normal_block = make_mock_block(
|
||||
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||
)
|
||||
|
||||
def mock_get_block(block_id):
|
||||
return {
|
||||
smart_decision_id: smart_block,
|
||||
"normal-block-id": normal_block,
|
||||
}.get(block_id)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(search_results, 2),
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.find_block.get_block",
|
||||
side_effect=mock_get_block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||
)
|
||||
|
||||
# Should only return normal block, not SmartDecisionMakerBlock
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert len(response.blocks) == 1
|
||||
assert response.blocks[0].id == "normal-block-id"
|
||||
@@ -8,10 +8,6 @@ from typing import Any
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
)
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
@@ -216,19 +212,6 @@ class RunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if block is excluded from CoPilot (graph-only blocks)
|
||||
if (
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
||||
"This block is designed for use within graphs only."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Tests for block execution guards in RunBlockTool."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.models import ErrorResponse
|
||||
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||
from backend.data.block import BlockType
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-run-block"
|
||||
|
||||
|
||||
def make_mock_block(
|
||||
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||
):
|
||||
"""Create a mock block for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
mock.name = name
|
||||
mock.block_type = block_type
|
||||
mock.disabled = disabled
|
||||
mock.input_schema = MagicMock()
|
||||
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||
mock.input_schema.get_credentials_fields_info.return_value = []
|
||||
return mock
|
||||
|
||||
|
||||
class TestRunBlockFiltering:
|
||||
"""Tests for block execution guards in RunBlockTool."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_type_returns_error(self):
|
||||
"""Attempting to execute a block with excluded BlockType returns error."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=input_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="input-block-id",
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
assert "designed for use within graphs only" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_excluded_block_id_returns_error(self):
|
||||
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
smart_block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=smart_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id=smart_decision_id,
|
||||
input_data={},
|
||||
)
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert "cannot be run directly in CoPilot" in response.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_excluded_block_passes_guard(self):
|
||||
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
|
||||
standard_block = make_mock_block(
|
||||
"standard-id", "HTTP Request", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.run_block.get_block",
|
||||
return_value=standard_block,
|
||||
):
|
||||
tool = RunBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
block_id="standard-id",
|
||||
input_data={},
|
||||
)
|
||||
|
||||
# Should NOT be an ErrorResponse about CoPilot exclusion
|
||||
# (may be other errors like missing credentials, but not the exclusion guard)
|
||||
if isinstance(response, ErrorResponse):
|
||||
assert "cannot be run directly in CoPilot" not in response.message
|
||||
@@ -117,7 +117,7 @@ def build_missing_credentials_from_graph(
|
||||
preserving all supported credential types for each field.
|
||||
"""
|
||||
matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
|
||||
aggregated_fields = graph.aggregate_credentials_inputs()
|
||||
aggregated_fields = graph.regular_credentials_inputs
|
||||
|
||||
return {
|
||||
field_key: _serialize_missing_credential(field_key, field_info)
|
||||
@@ -244,7 +244,7 @@ async def match_user_credentials_to_graph(
|
||||
missing_creds: list[str] = []
|
||||
|
||||
# Get aggregated credentials requirements from the graph
|
||||
aggregated_creds = graph.aggregate_credentials_inputs()
|
||||
aggregated_creds = graph.regular_credentials_inputs
|
||||
logger.debug(
|
||||
f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests for chat tools utility functions."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
|
||||
def _make_regular_field() -> CredentialsFieldInfo:
|
||||
return CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
|
||||
def test_build_missing_credentials_excludes_auto_creds():
|
||||
"""
|
||||
build_missing_credentials_from_graph() should use regular_credentials_inputs
|
||||
and thus exclude auto_credentials from the "missing" set.
|
||||
"""
|
||||
from backend.api.features.chat.tools.utils import (
|
||||
build_missing_credentials_from_graph,
|
||||
)
|
||||
|
||||
regular_field = _make_regular_field()
|
||||
|
||||
mock_graph = MagicMock()
|
||||
# regular_credentials_inputs should only return the non-auto field
|
||||
mock_graph.regular_credentials_inputs = {
|
||||
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
||||
}
|
||||
|
||||
result = build_missing_credentials_from_graph(mock_graph, matched_credentials=None)
|
||||
|
||||
# Should include the regular credential
|
||||
assert "github_api_key" in result
|
||||
# Should NOT include the auto_credential (not in regular_credentials_inputs)
|
||||
assert "google_oauth2" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_match_user_credentials_excludes_auto_creds():
|
||||
"""
|
||||
match_user_credentials_to_graph() should use regular_credentials_inputs
|
||||
and thus exclude auto_credentials from matching.
|
||||
"""
|
||||
from backend.api.features.chat.tools.utils import match_user_credentials_to_graph
|
||||
|
||||
regular_field = _make_regular_field()
|
||||
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.id = "test-graph"
|
||||
# regular_credentials_inputs returns only non-auto fields
|
||||
mock_graph.regular_credentials_inputs = {
|
||||
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
||||
}
|
||||
|
||||
# Mock the credentials manager to return no credentials
|
||||
with patch(
|
||||
"backend.api.features.chat.tools.utils.IntegrationCredentialsManager"
|
||||
) as MockCredsMgr:
|
||||
mock_store = AsyncMock()
|
||||
mock_store.get_all_creds.return_value = []
|
||||
MockCredsMgr.return_value.store = mock_store
|
||||
|
||||
matched, missing = await match_user_credentials_to_graph(
|
||||
user_id="test-user", graph=mock_graph
|
||||
)
|
||||
|
||||
# No credentials available, so github should be missing
|
||||
assert len(matched) == 0
|
||||
assert len(missing) == 1
|
||||
assert "github_api_key" in missing[0]
|
||||
@@ -1103,7 +1103,7 @@ async def create_preset_from_graph_execution(
|
||||
raise NotFoundError(
|
||||
f"Graph #{graph_execution.graph_id} not found or accessible"
|
||||
)
|
||||
elif len(graph.aggregate_credentials_inputs()) > 0:
|
||||
elif len(graph.regular_credentials_inputs) > 0:
|
||||
raise ValueError(
|
||||
f"Graph execution #{graph_exec_id} can't be turned into a preset "
|
||||
"because it was run before this feature existed "
|
||||
|
||||
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
webset = await aexa.websets.get(id=input_data.external_id)
|
||||
webset = aexa.websets.get(id=input_data.external_id)
|
||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||
|
||||
yield "webset", webset_result
|
||||
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
||||
count=input_data.search_count,
|
||||
)
|
||||
|
||||
webset = await aexa.websets.create(
|
||||
webset = aexa.websets.create(
|
||||
params=CreateWebsetParameters(
|
||||
search=search_params,
|
||||
external_id=input_data.external_id,
|
||||
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
|
||||
if input_data.metadata is not None:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||
|
||||
status_str = (
|
||||
sdk_webset.status.value
|
||||
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = await aexa.websets.list(
|
||||
response = aexa.websets.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
sdk_webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
status_str = (
|
||||
sdk_webset.status.value
|
||||
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
||||
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
|
||||
|
||||
status_str = (
|
||||
deleted_webset.status.value
|
||||
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
||||
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
|
||||
|
||||
status_str = (
|
||||
canceled_webset.status.value
|
||||
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
||||
entity["description"] = input_data.entity_description
|
||||
payload["entity"] = entity
|
||||
|
||||
sdk_preview = await aexa.websets.preview(params=payload)
|
||||
sdk_preview = aexa.websets.preview(params=payload)
|
||||
|
||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||
|
||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
status = (
|
||||
webset.status.value
|
||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
# Extract basic info
|
||||
webset_id = webset.id
|
||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
||||
total_items = 0
|
||||
|
||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||
items_response = await aexa.websets.items.list(
|
||||
items_response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||
)
|
||||
sample_items_data = [
|
||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get webset details
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
status = (
|
||||
webset.status.value
|
||||
|
||||
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = await aexa.websets.enrichments.create(
|
||||
sdk_enrichment = aexa.websets.enrichments.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
items_enriched = 0
|
||||
|
||||
while time.time() - poll_start < input_data.polling_timeout:
|
||||
current_enrich = await aexa.websets.enrichments.get(
|
||||
current_enrich = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=enrichment_id
|
||||
)
|
||||
current_status = (
|
||||
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
|
||||
|
||||
if current_status in ["completed", "failed", "cancelled"]:
|
||||
# Estimate items from webset searches
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.progress:
|
||||
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = await aexa.websets.enrichments.get(
|
||||
sdk_enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_enrichment = await aexa.websets.enrichments.delete(
|
||||
deleted_enrichment = aexa.websets.enrichments.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
||||
canceled_enrichment = aexa.websets.enrichments.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
# Try to estimate how many items were enriched before cancellation
|
||||
items_enriched = 0
|
||||
items_response = await aexa.websets.items.list(
|
||||
items_response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=100
|
||||
)
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
|
||||
def _create_test_mock():
|
||||
"""Create test mocks for the AsyncExa SDK."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create mock SDK import object
|
||||
mock_import = MagicMock()
|
||||
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
imports=MagicMock(create=AsyncMock(return_value=mock_import))
|
||||
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
||||
)
|
||||
)
|
||||
}
|
||||
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_import = await aexa.websets.imports.create(
|
||||
sdk_import = aexa.websets.imports.create(
|
||||
params=payload, csv_data=input_data.csv_data
|
||||
)
|
||||
|
||||
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
||||
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
||||
|
||||
import_obj = ImportModel.from_sdk(sdk_import)
|
||||
|
||||
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = await aexa.websets.imports.list(
|
||||
response = aexa.websets.imports.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_import = await aexa.websets.imports.delete(
|
||||
import_id=input_data.import_id
|
||||
)
|
||||
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
||||
|
||||
yield "import_id", deleted_import.id
|
||||
yield "success", "true"
|
||||
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
|
||||
}
|
||||
)
|
||||
|
||||
# Create async iterator for list_all
|
||||
async def async_item_iterator(*args, **kwargs):
|
||||
for item in [mock_item1, mock_item2]:
|
||||
yield item
|
||||
# Create mock iterator
|
||||
mock_items = [mock_item1, mock_item2]
|
||||
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
|
||||
websets=MagicMock(
|
||||
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
|
||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||
)
|
||||
|
||||
async for sdk_item in item_iterator:
|
||||
for sdk_item in item_iterator:
|
||||
if len(all_items) >= input_data.max_items:
|
||||
break
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_item = await aexa.websets.items.get(
|
||||
sdk_item = aexa.websets.items.get(
|
||||
webset_id=input_data.webset_id, id=input_data.item_id
|
||||
)
|
||||
|
||||
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
|
||||
response = None
|
||||
|
||||
while time.time() - start_time < input_data.wait_timeout:
|
||||
response = await aexa.websets.items.list(
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
|
||||
interval = min(interval * 1.2, 10)
|
||||
|
||||
if not response:
|
||||
response = await aexa.websets.items.list(
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
else:
|
||||
response = await aexa.websets.items.list(
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_item = await aexa.websets.items.delete(
|
||||
deleted_item = aexa.websets.items.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.item_id
|
||||
)
|
||||
|
||||
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
|
||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||
)
|
||||
|
||||
async for sdk_item in item_iterator:
|
||||
for sdk_item in item_iterator:
|
||||
if len(all_items) >= input_data.max_items:
|
||||
break
|
||||
|
||||
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
entity_type = "unknown"
|
||||
if webset.searches:
|
||||
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
||||
# Get sample items if requested
|
||||
sample_items: List[WebsetItemModel] = []
|
||||
if input_data.sample_size > 0:
|
||||
items_response = await aexa.websets.items.list(
|
||||
items_response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||
)
|
||||
# Convert to our stable models
|
||||
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get items starting from cursor
|
||||
response = await aexa.websets.items.list(
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.since_cursor,
|
||||
limit=input_data.max_items,
|
||||
|
||||
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
|
||||
def _create_test_mock():
|
||||
"""Create test mocks for the AsyncExa SDK."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create mock SDK monitor object
|
||||
mock_monitor = MagicMock()
|
||||
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
|
||||
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
||||
)
|
||||
)
|
||||
}
|
||||
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_monitor = await aexa.websets.monitors.create(params=payload)
|
||||
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
||||
|
||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||
|
||||
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||
|
||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||
|
||||
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
|
||||
if input_data.metadata is not None:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_monitor = await aexa.websets.monitors.update(
|
||||
sdk_monitor = aexa.websets.monitors.update(
|
||||
monitor_id=input_data.monitor_id, params=payload
|
||||
)
|
||||
|
||||
@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_monitor = await aexa.websets.monitors.delete(
|
||||
monitor_id=input_data.monitor_id
|
||||
)
|
||||
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
||||
|
||||
yield "monitor_id", deleted_monitor.id
|
||||
yield "success", "true"
|
||||
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = await aexa.websets.monitors.list(
|
||||
response = aexa.websets.monitors.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
webset_id=input_data.webset_id,
|
||||
|
||||
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
|
||||
WebsetTargetStatus.IDLE,
|
||||
WebsetTargetStatus.ANY_COMPLETE,
|
||||
]:
|
||||
final_webset = await aexa.websets.wait_until_idle(
|
||||
final_webset = aexa.websets.wait_until_idle(
|
||||
id=input_data.webset_id,
|
||||
timeout=input_data.timeout,
|
||||
poll_interval=input_data.check_interval,
|
||||
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
|
||||
interval = input_data.check_interval
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current webset status
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
current_status = (
|
||||
webset.status.value
|
||||
if hasattr(webset.status, "value")
|
||||
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
|
||||
|
||||
# Timeout reached
|
||||
elapsed = time.time() - start_time
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
final_status = (
|
||||
webset.status.value
|
||||
if hasattr(webset.status, "value")
|
||||
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
|
||||
try:
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current search status using SDK
|
||||
search = await aexa.websets.searches.get(
|
||||
search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Get last known status
|
||||
search = await aexa.websets.searches.get(
|
||||
search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
final_status = (
|
||||
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
||||
try:
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
# Get current enrichment status using SDK
|
||||
enrichment = await aexa.websets.enrichments.get(
|
||||
enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Get last known status
|
||||
enrichment = await aexa.websets.enrichments.get(
|
||||
enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
final_status = (
|
||||
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||
"""Get sample enriched data and count."""
|
||||
# Get a few items to see enrichment results using SDK
|
||||
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||
|
||||
sample_data: list[SampleEnrichmentModel] = []
|
||||
enriched_count = 0
|
||||
|
||||
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
||||
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_search = await aexa.websets.searches.create(
|
||||
sdk_search = aexa.websets.searches.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
||||
poll_start = time.time()
|
||||
|
||||
while time.time() - poll_start < input_data.polling_timeout:
|
||||
current_search = await aexa.websets.searches.get(
|
||||
current_search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=search_id
|
||||
)
|
||||
current_status = (
|
||||
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_search = await aexa.websets.searches.get(
|
||||
sdk_search = aexa.websets.searches.get(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_search = await aexa.websets.searches.cancel(
|
||||
canceled_search = aexa.websets.searches.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.search_id
|
||||
)
|
||||
|
||||
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get webset to check existing searches
|
||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
# Look for existing search with same query
|
||||
existing_search = None
|
||||
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
||||
if input_data.entity_type != SearchEntityType.AUTO:
|
||||
payload["entity"] = {"type": input_data.entity_type.value}
|
||||
|
||||
sdk_search = await aexa.websets.searches.create(
|
||||
sdk_search = aexa.websets.searches.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
|
||||
@@ -531,12 +531,12 @@ class LLMResponse(BaseModel):
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
|
||||
openai_tools: list[dict] | None = None,
|
||||
) -> Iterable[ToolParam] | anthropic.Omit:
|
||||
) -> Iterable[ToolParam] | anthropic.NotGiven:
|
||||
"""
|
||||
Convert OpenAI tool format to Anthropic tool format.
|
||||
"""
|
||||
if not openai_tools or len(openai_tools) == 0:
|
||||
return anthropic.omit
|
||||
return anthropic.NOT_GIVEN
|
||||
|
||||
anthropic_tools = []
|
||||
for tool in openai_tools:
|
||||
@@ -596,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
|
||||
def get_parallel_tool_calls_param(
|
||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||
) -> bool | openai.Omit:
|
||||
):
|
||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
return openai.omit
|
||||
return openai.NOT_GIVEN
|
||||
return parallel_tool_calls
|
||||
|
||||
|
||||
|
||||
@@ -319,6 +319,8 @@ class BlockSchema(BaseModel):
|
||||
"credentials_provider": [config.get("provider", "google")],
|
||||
"credentials_types": [config.get("type", "oauth2")],
|
||||
"credentials_scopes": config.get("scopes"),
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": info["field_name"],
|
||||
}
|
||||
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
||||
auto_schema, by_alias=True
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from multiprocessing import Manager
|
||||
from queue import Empty
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -1199,16 +1200,12 @@ class NodeExecutionEntry(BaseModel):
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
"""
|
||||
Thread-safe queue for managing node execution within a single graph execution.
|
||||
|
||||
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
|
||||
threads within the same process. If migrating back to ProcessPoolExecutor,
|
||||
replace with multiprocessing.Manager().Queue() for cross-process safety.
|
||||
Queue for managing the execution of agents.
|
||||
This will be shared between different processes
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Thread-safe queue (not multiprocessing) — see class docstring
|
||||
self.queue: queue.Queue[T] = queue.Queue()
|
||||
self.queue = Manager().Queue()
|
||||
|
||||
def add(self, execution: T) -> T:
|
||||
self.queue.put(execution)
|
||||
@@ -1223,7 +1220,7 @@ class ExecutionQueue(Generic[T]):
|
||||
def get_or_none(self) -> T | None:
|
||||
try:
|
||||
return self.queue.get_nowait()
|
||||
except queue.Empty:
|
||||
except Empty:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Tests for ExecutionQueue thread-safety."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
from backend.data.execution import ExecutionQueue
|
||||
|
||||
|
||||
def test_execution_queue_uses_stdlib_queue():
|
||||
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
|
||||
q = ExecutionQueue()
|
||||
assert isinstance(q.queue, queue.Queue)
|
||||
|
||||
|
||||
def test_basic_operations():
|
||||
"""Test add, get, empty, and get_or_none."""
|
||||
q = ExecutionQueue()
|
||||
|
||||
assert q.empty() is True
|
||||
assert q.get_or_none() is None
|
||||
|
||||
result = q.add("item1")
|
||||
assert result == "item1"
|
||||
assert q.empty() is False
|
||||
|
||||
item = q.get()
|
||||
assert item == "item1"
|
||||
assert q.empty() is True
|
||||
|
||||
|
||||
def test_thread_safety():
|
||||
"""Test concurrent access from multiple threads."""
|
||||
q = ExecutionQueue()
|
||||
results = []
|
||||
num_items = 100
|
||||
|
||||
def producer():
|
||||
for i in range(num_items):
|
||||
q.add(f"item_{i}")
|
||||
|
||||
def consumer():
|
||||
count = 0
|
||||
while count < num_items:
|
||||
item = q.get_or_none()
|
||||
if item is not None:
|
||||
results.append(item)
|
||||
count += 1
|
||||
|
||||
producer_thread = threading.Thread(target=producer)
|
||||
consumer_thread = threading.Thread(target=consumer)
|
||||
|
||||
producer_thread.start()
|
||||
consumer_thread.start()
|
||||
|
||||
producer_thread.join(timeout=5)
|
||||
consumer_thread.join(timeout=5)
|
||||
|
||||
assert len(results) == num_items
|
||||
@@ -447,8 +447,7 @@ class GraphModel(Graph, GraphMeta):
|
||||
@computed_field
|
||||
@property
|
||||
def credentials_input_schema(self) -> dict[str, Any]:
|
||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||
|
||||
graph_credentials_inputs = self.regular_credentials_inputs
|
||||
logger.debug(
|
||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||
f"{graph_credentials_inputs}"
|
||||
@@ -604,6 +603,28 @@ class GraphModel(Graph, GraphMeta):
|
||||
for key, (field_info, node_field_pairs) in combined.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def regular_credentials_inputs(
|
||||
self,
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
||||
"""Credentials that need explicit user mapping (CredentialsMetaInput fields)."""
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.aggregate_credentials_inputs().items()
|
||||
if not v[0].is_auto_credential
|
||||
}
|
||||
|
||||
@property
|
||||
def auto_credentials_inputs(
|
||||
self,
|
||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
||||
"""Credentials embedded in file fields (_credentials_id), resolved at execution time."""
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.aggregate_credentials_inputs().items()
|
||||
if v[0].is_auto_credential
|
||||
}
|
||||
|
||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||
"""
|
||||
Reassigns all IDs in the graph to new UUIDs.
|
||||
@@ -654,6 +675,16 @@ class GraphModel(Graph, GraphMeta):
|
||||
) and graph_id in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
# Clear auto-credentials references (e.g., _credentials_id in
|
||||
# GoogleDriveFile fields) so the new user must re-authenticate
|
||||
# with their own account
|
||||
for node in graph.nodes:
|
||||
if not node.input_default:
|
||||
continue
|
||||
for key, value in node.input_default.items():
|
||||
if isinstance(value, dict) and "_credentials_id" in value:
|
||||
del value["_credentials_id"]
|
||||
|
||||
def validate_graph(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
|
||||
@@ -463,3 +463,329 @@ def test_node_credentials_optional_with_other_metadata():
|
||||
assert node.credentials_optional is True
|
||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||
assert node.metadata["customized_name"] == "My Custom Node"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for CredentialsFieldInfo.combine() field propagation
|
||||
def test_combine_preserves_is_auto_credential_flag():
|
||||
"""
|
||||
CredentialsFieldInfo.combine() must propagate is_auto_credential and
|
||||
input_field_name to the combined result. Regression test for reviewer
|
||||
finding that combine() dropped these fields.
|
||||
"""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
auto_field = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["google"],
|
||||
"credentials_types": ["oauth2"],
|
||||
"credentials_scopes": ["drive.readonly"],
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": "spreadsheet",
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
# combine() takes *args of (field_info, key) tuples
|
||||
combined = CredentialsFieldInfo.combine(
|
||||
(auto_field, ("node-1", "credentials")),
|
||||
(auto_field, ("node-2", "credentials")),
|
||||
)
|
||||
|
||||
assert len(combined) == 1
|
||||
group_key = next(iter(combined))
|
||||
combined_info, combined_keys = combined[group_key]
|
||||
|
||||
assert combined_info.is_auto_credential is True
|
||||
assert combined_info.input_field_name == "spreadsheet"
|
||||
assert combined_keys == {("node-1", "credentials"), ("node-2", "credentials")}
|
||||
|
||||
|
||||
def test_combine_preserves_regular_credential_defaults():
|
||||
"""Regular credentials should have is_auto_credential=False after combine()."""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
regular_field = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
combined = CredentialsFieldInfo.combine(
|
||||
(regular_field, ("node-1", "credentials")),
|
||||
)
|
||||
|
||||
group_key = next(iter(combined))
|
||||
combined_info, _ = combined[group_key]
|
||||
|
||||
assert combined_info.is_auto_credential is False
|
||||
assert combined_info.input_field_name is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for _reassign_ids credential clearing (Fix 3: SECRT-1772)
|
||||
|
||||
|
||||
def test_reassign_ids_clears_credentials_id():
|
||||
"""
|
||||
[SECRT-1772] _reassign_ids should clear _credentials_id from
|
||||
GoogleDriveFile-style input_default fields so forked agents
|
||||
don't retain the original creator's credential references.
|
||||
"""
|
||||
from backend.data.graph import GraphModel
|
||||
|
||||
node = Node(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "original-cred-id",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
"url": "https://docs.google.com/spreadsheets/d/file-123",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
graph = Graph(
|
||||
id="test-graph",
|
||||
name="Test",
|
||||
description="Test",
|
||||
nodes=[node],
|
||||
links=[],
|
||||
)
|
||||
|
||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||
|
||||
# _credentials_id key should be removed (not set to None) so that
|
||||
# _acquire_auto_credentials correctly errors instead of treating it as chained data
|
||||
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
|
||||
|
||||
|
||||
def test_reassign_ids_preserves_non_credential_fields():
|
||||
"""
|
||||
Regression guard: _reassign_ids should NOT modify non-credential fields
|
||||
like name, mimeType, id, url.
|
||||
"""
|
||||
from backend.data.graph import GraphModel
|
||||
|
||||
node = Node(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "cred-abc",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
"url": "https://docs.google.com/spreadsheets/d/file-123",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
graph = Graph(
|
||||
id="test-graph",
|
||||
name="Test",
|
||||
description="Test",
|
||||
nodes=[node],
|
||||
links=[],
|
||||
)
|
||||
|
||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||
|
||||
field = graph.nodes[0].input_default["spreadsheet"]
|
||||
assert field["id"] == "file-123"
|
||||
assert field["name"] == "test.xlsx"
|
||||
assert field["mimeType"] == "application/vnd.google-apps.spreadsheet"
|
||||
assert field["url"] == "https://docs.google.com/spreadsheets/d/file-123"
|
||||
|
||||
|
||||
def test_reassign_ids_handles_no_credentials():
|
||||
"""
|
||||
Regression guard: _reassign_ids should not error when input_default
|
||||
has no dict fields with _credentials_id.
|
||||
"""
|
||||
from backend.data.graph import GraphModel
|
||||
|
||||
node = Node(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={
|
||||
"input": "some value",
|
||||
"another_input": 42,
|
||||
},
|
||||
)
|
||||
|
||||
graph = Graph(
|
||||
id="test-graph",
|
||||
name="Test",
|
||||
description="Test",
|
||||
nodes=[node],
|
||||
links=[],
|
||||
)
|
||||
|
||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||
|
||||
# Should not error, fields unchanged
|
||||
assert graph.nodes[0].input_default["input"] == "some value"
|
||||
assert graph.nodes[0].input_default["another_input"] == 42
|
||||
|
||||
|
||||
def test_reassign_ids_handles_multiple_credential_fields():
|
||||
"""
|
||||
[SECRT-1772] When a node has multiple dict fields with _credentials_id,
|
||||
ALL of them should be cleared.
|
||||
"""
|
||||
from backend.data.graph import GraphModel
|
||||
|
||||
node = Node(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "cred-1",
|
||||
"id": "file-1",
|
||||
"name": "file1.xlsx",
|
||||
},
|
||||
"doc_file": {
|
||||
"_credentials_id": "cred-2",
|
||||
"id": "file-2",
|
||||
"name": "file2.docx",
|
||||
},
|
||||
"plain_input": "not a dict",
|
||||
},
|
||||
)
|
||||
|
||||
graph = Graph(
|
||||
id="test-graph",
|
||||
name="Test",
|
||||
description="Test",
|
||||
nodes=[node],
|
||||
links=[],
|
||||
)
|
||||
|
||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
||||
|
||||
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
|
||||
assert "_credentials_id" not in graph.nodes[0].input_default["doc_file"]
|
||||
assert graph.nodes[0].input_default["plain_input"] == "not a dict"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for discriminate() field propagation
|
||||
def test_discriminate_preserves_is_auto_credential_flag():
|
||||
"""
|
||||
CredentialsFieldInfo.discriminate() must propagate is_auto_credential and
|
||||
input_field_name to the discriminated result. Regression test for
|
||||
discriminate() dropping these fields (same class of bug as combine()).
|
||||
"""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
auto_field = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["google", "openai"],
|
||||
"credentials_types": ["oauth2"],
|
||||
"credentials_scopes": ["drive.readonly"],
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": "spreadsheet",
|
||||
"discriminator": "model",
|
||||
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
discriminated = auto_field.discriminate("gemini")
|
||||
|
||||
assert discriminated.is_auto_credential is True
|
||||
assert discriminated.input_field_name == "spreadsheet"
|
||||
assert discriminated.provider == frozenset(["google"])
|
||||
|
||||
|
||||
def test_discriminate_preserves_regular_credential_defaults():
|
||||
"""Regular credentials should have is_auto_credential=False after discriminate()."""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
regular_field = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["google", "openai"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
"discriminator": "model",
|
||||
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
discriminated = regular_field.discriminate("gpt-4")
|
||||
|
||||
assert discriminated.is_auto_credential is False
|
||||
assert discriminated.input_field_name is None
|
||||
assert discriminated.provider == frozenset(["openai"])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for credentials_input_schema excluding auto_credentials
|
||||
def test_credentials_input_schema_excludes_auto_creds():
|
||||
"""
|
||||
GraphModel.credentials_input_schema should exclude auto_credentials
|
||||
(is_auto_credential=True) from the schema. Auto_credentials are
|
||||
transparently resolved at execution time via file picker data.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from backend.data.graph import GraphModel, NodeModel
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
regular_field_info = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
graph = GraphModel(
|
||||
id="test-graph",
|
||||
version=1,
|
||||
name="Test",
|
||||
description="Test",
|
||||
user_id="test-user",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
nodes=[
|
||||
NodeModel(
|
||||
id="node-1",
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={},
|
||||
graph_id="test-graph",
|
||||
graph_version=1,
|
||||
),
|
||||
],
|
||||
links=[],
|
||||
)
|
||||
|
||||
# Mock regular_credentials_inputs to return only the non-auto field (3-tuple)
|
||||
regular_only = {
|
||||
"github_credentials": (
|
||||
regular_field_info,
|
||||
{("node-1", "credentials")},
|
||||
True,
|
||||
),
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(graph),
|
||||
"regular_credentials_inputs",
|
||||
new_callable=PropertyMock,
|
||||
return_value=regular_only,
|
||||
):
|
||||
schema = graph.credentials_input_schema
|
||||
field_names = set(schema.get("properties", {}).keys())
|
||||
# Should include regular credential but NOT auto_credential
|
||||
assert "github_credentials" in field_names
|
||||
assert "google_credentials" not in field_names
|
||||
|
||||
@@ -571,6 +571,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
discriminator_values: set[Any] = Field(default_factory=set)
|
||||
is_auto_credential: bool = False
|
||||
input_field_name: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
@@ -651,6 +653,9 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
+ "_credentials"
|
||||
)
|
||||
|
||||
# Propagate is_auto_credential from the combined field.
|
||||
# All fields in a group should share the same is_auto_credential
|
||||
# value since auto and regular credentials serve different purposes.
|
||||
result[group_key] = (
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
@@ -659,6 +664,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
discriminator_values=set(all_discriminator_values),
|
||||
is_auto_credential=combined.is_auto_credential,
|
||||
input_field_name=combined.input_field_name,
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
@@ -684,6 +691,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
discriminator=self.discriminator,
|
||||
discriminator_mapping=self.discriminator_mapping,
|
||||
discriminator_values=self.discriminator_values,
|
||||
is_auto_credential=self.is_auto_credential,
|
||||
input_field_name=self.input_field_name,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -92,6 +92,7 @@ from .utils import (
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
parse_auto_credential_field,
|
||||
validate_exec,
|
||||
)
|
||||
|
||||
@@ -172,6 +173,61 @@ def execute_graph(
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def _acquire_auto_credentials(
|
||||
input_model: type[BlockSchema],
|
||||
input_data: dict[str, Any],
|
||||
creds_manager: "IntegrationCredentialsManager",
|
||||
user_id: str,
|
||||
) -> tuple[dict[str, Any], list[AsyncRedisLock]]:
|
||||
"""
|
||||
Resolve auto_credentials from GoogleDriveFileField-style inputs.
|
||||
|
||||
Returns:
|
||||
(extra_exec_kwargs, locks): kwargs to inject into block execution, and
|
||||
credential locks to release after execution completes.
|
||||
"""
|
||||
extra_exec_kwargs: dict[str, Any] = {}
|
||||
locks: list[AsyncRedisLock] = []
|
||||
|
||||
# NOTE: If a block ever has multiple auto-credential fields, a ValueError
|
||||
# on a later field will strand locks acquired for earlier fields. They'll
|
||||
# auto-expire via Redis TTL, but add a try/except to release partial locks
|
||||
# if that becomes a real scenario.
|
||||
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
|
||||
field_name = info["field_name"]
|
||||
field_data = input_data.get(field_name)
|
||||
|
||||
# Use shared helper to parse the field
|
||||
parsed = parse_auto_credential_field(
|
||||
field_name=field_name,
|
||||
info=info,
|
||||
field_data=field_data,
|
||||
field_present_in_input=field_name in input_data,
|
||||
)
|
||||
|
||||
if parsed.error:
|
||||
raise ValueError(parsed.error)
|
||||
|
||||
if parsed.cred_id:
|
||||
# Credential ID provided - acquire credentials
|
||||
try:
|
||||
credentials, lock = await creds_manager.acquire(user_id, parsed.cred_id)
|
||||
locks.append(lock)
|
||||
extra_exec_kwargs[kwarg_name] = credentials
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"{parsed.provider.capitalize()} credentials for "
|
||||
f"'{parsed.file_name}' in field '{parsed.field_name}' are not "
|
||||
f"available in your account. "
|
||||
f"This can happen if the agent was created by another "
|
||||
f"user or the credentials were deleted. "
|
||||
f"Please open the agent in the builder and re-select "
|
||||
f"the file to authenticate with your own account."
|
||||
)
|
||||
|
||||
return extra_exec_kwargs, locks
|
||||
|
||||
|
||||
async def execute_node(
|
||||
node: Node,
|
||||
data: NodeExecutionEntry,
|
||||
@@ -271,41 +327,14 @@ async def execute_node(
|
||||
extra_exec_kwargs[field_name] = credentials
|
||||
|
||||
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
|
||||
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
|
||||
field_name = info["field_name"]
|
||||
field_data = input_data.get(field_name)
|
||||
if field_data and isinstance(field_data, dict):
|
||||
# Check if _credentials_id key exists in the field data
|
||||
if "_credentials_id" in field_data:
|
||||
cred_id = field_data["_credentials_id"]
|
||||
if cred_id:
|
||||
# Credential ID provided - acquire credentials
|
||||
provider = info.get("config", {}).get(
|
||||
"provider", "external service"
|
||||
)
|
||||
file_name = field_data.get("name", "selected file")
|
||||
try:
|
||||
credentials, lock = await creds_manager.acquire(
|
||||
user_id, cred_id
|
||||
)
|
||||
creds_locks.append(lock)
|
||||
extra_exec_kwargs[kwarg_name] = credentials
|
||||
except ValueError:
|
||||
# Credential was deleted or doesn't exist
|
||||
raise ValueError(
|
||||
f"Authentication expired for '{file_name}' in field '{field_name}'. "
|
||||
f"The saved {provider.capitalize()} credentials no longer exist. "
|
||||
f"Please re-select the file to re-authenticate."
|
||||
)
|
||||
# else: _credentials_id is explicitly None, skip credentials (for chained data)
|
||||
else:
|
||||
# _credentials_id key missing entirely - this is an error
|
||||
provider = info.get("config", {}).get("provider", "external service")
|
||||
file_name = field_data.get("name", "selected file")
|
||||
raise ValueError(
|
||||
f"Authentication missing for '{file_name}' in field '{field_name}'. "
|
||||
f"Please re-select the file to authenticate with {provider.capitalize()}."
|
||||
)
|
||||
auto_extra_kwargs, auto_locks = await _acquire_auto_credentials(
|
||||
input_model=input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=creds_manager,
|
||||
user_id=user_id,
|
||||
)
|
||||
extra_exec_kwargs.update(auto_extra_kwargs)
|
||||
creds_locks.extend(auto_locks)
|
||||
|
||||
output_size = 0
|
||||
|
||||
|
||||
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Tests for auto_credentials handling in execute_node().
|
||||
|
||||
These test the _acquire_auto_credentials() helper function extracted from
|
||||
execute_node() (manager.py lines 273-308).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def google_drive_file_data():
|
||||
return {
|
||||
"valid": {
|
||||
"_credentials_id": "cred-id-123",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
},
|
||||
"chained": {
|
||||
"_credentials_id": None,
|
||||
"id": "file-456",
|
||||
"name": "chained.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
},
|
||||
"missing_key": {
|
||||
"id": "file-789",
|
||||
"name": "bad.xlsx",
|
||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_input_model(mocker: MockerFixture):
|
||||
"""Create a mock input model with get_auto_credentials_fields() returning one field."""
|
||||
input_model = mocker.MagicMock()
|
||||
input_model.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {
|
||||
"provider": "google",
|
||||
"type": "oauth2",
|
||||
"scopes": ["https://www.googleapis.com/auth/drive.readonly"],
|
||||
},
|
||||
}
|
||||
}
|
||||
return input_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_creds_manager(mocker: MockerFixture):
|
||||
manager = mocker.AsyncMock()
|
||||
mock_lock = mocker.AsyncMock()
|
||||
mock_creds = mocker.MagicMock()
|
||||
mock_creds.id = "cred-id-123"
|
||||
mock_creds.provider = "google"
|
||||
manager.acquire.return_value = (mock_creds, mock_lock)
|
||||
return manager, mock_creds, mock_lock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_happy_path(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""When field_data has a valid _credentials_id, credentials should be acquired."""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, mock_creds, mock_lock = mock_creds_manager
|
||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
|
||||
assert extra_kwargs["credentials"] == mock_creds
|
||||
assert mock_lock in locks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_field_none_static_raises(
|
||||
mocker: MockerFixture,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
[THE BUG FIX TEST — OPEN-2895]
|
||||
When field_data is None and the key IS in input_data (user didn't select a file),
|
||||
should raise ValueError instead of silently skipping.
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
# Key is present but value is None = user didn't select a file
|
||||
input_data = {"spreadsheet": None}
|
||||
|
||||
with pytest.raises(ValueError, match="No file selected"):
|
||||
await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_field_absent_skips(
|
||||
mocker: MockerFixture,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
When the field key is NOT in input_data at all (upstream connection),
|
||||
should skip without error.
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
# Key not present = connected from upstream block
|
||||
input_data = {}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
manager.acquire.assert_not_called()
|
||||
assert "credentials" not in extra_kwargs
|
||||
assert locks == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_chained_cred_id_none(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
When _credentials_id is explicitly None (chained data from upstream),
|
||||
should skip credential acquisition.
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
input_data = {"spreadsheet": google_drive_file_data["chained"]}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
manager.acquire.assert_not_called()
|
||||
assert "credentials" not in extra_kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_missing_cred_id_key_raises(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
When _credentials_id key is missing entirely from field_data dict,
|
||||
should raise ValueError.
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
input_data = {"spreadsheet": google_drive_file_data["missing_key"]}
|
||||
|
||||
with pytest.raises(ValueError, match="Authentication missing"):
|
||||
await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_ownership_mismatch_error(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] When acquire() raises ValueError (credential belongs to another user),
|
||||
the error message should mention 'not available' (not 'expired').
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
manager.acquire.side_effect = ValueError(
|
||||
"Credentials #cred-id-123 for user #user-2 not found"
|
||||
)
|
||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||
|
||||
with pytest.raises(ValueError, match="not available in your account"):
|
||||
await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-2",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_deleted_credential_error(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] When acquire() raises ValueError (credential was deleted),
|
||||
the error message should mention 'not available' (not 'expired').
|
||||
"""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, _ = mock_creds_manager
|
||||
manager.acquire.side_effect = ValueError(
|
||||
"Credentials #cred-id-123 for user #user-1 not found"
|
||||
)
|
||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||
|
||||
with pytest.raises(ValueError, match="not available in your account"):
|
||||
await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_lock_appended(
|
||||
mocker: MockerFixture,
|
||||
google_drive_file_data,
|
||||
mock_input_model,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""Lock from acquire() should be included in returned locks list."""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, _, mock_lock = mock_creds_manager
|
||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=mock_input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert len(locks) == 1
|
||||
assert locks[0] is mock_lock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_credentials_multiple_fields(
|
||||
mocker: MockerFixture,
|
||||
mock_creds_manager,
|
||||
):
|
||||
"""When there are multiple auto_credentials fields, only valid ones should acquire."""
|
||||
from backend.executor.manager import _acquire_auto_credentials
|
||||
|
||||
manager, mock_creds, mock_lock = mock_creds_manager
|
||||
|
||||
input_model = mocker.MagicMock()
|
||||
input_model.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
},
|
||||
"credentials2": {
|
||||
"field_name": "doc_file",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
},
|
||||
}
|
||||
|
||||
input_data = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "cred-id-123",
|
||||
"id": "file-1",
|
||||
"name": "file1.xlsx",
|
||||
},
|
||||
"doc_file": {
|
||||
"_credentials_id": None,
|
||||
"id": "file-2",
|
||||
"name": "chained.doc",
|
||||
},
|
||||
}
|
||||
|
||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
||||
input_model=input_model,
|
||||
input_data=input_data,
|
||||
creds_manager=manager,
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
# Only the first field should have acquired credentials
|
||||
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
|
||||
assert "credentials" in extra_kwargs
|
||||
assert "credentials2" not in extra_kwargs
|
||||
assert len(locks) == 1
|
||||
@@ -4,7 +4,7 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import Mapping, Optional, cast
|
||||
from typing import Any, Mapping, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
@@ -55,6 +55,87 @@ from backend.util.type import convert
|
||||
config = Config()
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
||||
|
||||
# ============ Auto-Credentials Helpers ============ #
|
||||
|
||||
|
||||
class AutoCredentialFieldInfo(BaseModel):
|
||||
"""Parsed info from an auto-credential field (e.g., GoogleDriveFileField)."""
|
||||
|
||||
cred_id: str | None
|
||||
"""The credential ID to use, or None if not provided."""
|
||||
provider: str
|
||||
"""The provider name (e.g., 'google')."""
|
||||
file_name: str
|
||||
"""The display name for error messages."""
|
||||
field_name: str
|
||||
"""The original field name in the schema."""
|
||||
error: str | None = None
|
||||
"""Validation error message, if any."""
|
||||
|
||||
|
||||
def parse_auto_credential_field(
|
||||
field_name: str,
|
||||
info: dict,
|
||||
field_data: Any,
|
||||
*,
|
||||
field_present_in_input: bool = True,
|
||||
) -> AutoCredentialFieldInfo:
|
||||
"""
|
||||
Parse auto-credential field data and extract credential info.
|
||||
|
||||
This is shared logic used by both credential acquisition (manager.py)
|
||||
and credential validation (utils.py).
|
||||
|
||||
Args:
|
||||
field_name: The name of the field in the schema
|
||||
info: The auto_credentials field info from get_auto_credentials_fields()
|
||||
field_data: The actual field data from input
|
||||
field_present_in_input: Whether the field key exists in input_data
|
||||
|
||||
Returns:
|
||||
AutoCredentialFieldInfo with parsed data and any validation errors
|
||||
"""
|
||||
provider = info.get("config", {}).get("provider", "external service")
|
||||
file_name = (
|
||||
field_data.get("name", "selected file")
|
||||
if isinstance(field_data, dict)
|
||||
else "selected file"
|
||||
)
|
||||
|
||||
result = AutoCredentialFieldInfo(
|
||||
cred_id=None,
|
||||
provider=provider,
|
||||
file_name=file_name,
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
if field_data and isinstance(field_data, dict):
|
||||
if "_credentials_id" not in field_data:
|
||||
# Key removed (e.g., on fork) — needs re-auth
|
||||
result.error = (
|
||||
f"Authentication missing for '{file_name}' in field "
|
||||
f"'{field_name}'. Please re-select the file to authenticate "
|
||||
f"with {provider.capitalize()}."
|
||||
)
|
||||
else:
|
||||
cred_id = field_data.get("_credentials_id")
|
||||
if cred_id:
|
||||
result.cred_id = cred_id
|
||||
# else: _credentials_id is explicitly None, skip (chained data)
|
||||
elif field_data is None and not field_present_in_input:
|
||||
# Field not in input_data at all = connected from upstream block, skip
|
||||
pass
|
||||
elif field_present_in_input:
|
||||
# field_data is None/empty but key IS in input_data = user didn't select
|
||||
result.error = (
|
||||
f"No file selected for '{field_name}'. "
|
||||
f"Please select a file to provide "
|
||||
f"{provider.capitalize()} authentication."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ============ Resource Helpers ============ #
|
||||
|
||||
|
||||
@@ -259,7 +340,8 @@ async def _validate_node_input_credentials(
|
||||
|
||||
# Find any fields of type CredentialsMetaInput
|
||||
credentials_fields = block.input_schema.get_credentials_fields()
|
||||
if not credentials_fields:
|
||||
auto_credentials_fields = block.input_schema.get_auto_credentials_fields()
|
||||
if not credentials_fields and not auto_credentials_fields:
|
||||
continue
|
||||
|
||||
# Track if any credential field is missing for this node
|
||||
@@ -339,6 +421,52 @@ async def _validate_node_input_credentials(
|
||||
] = "Invalid credentials: type/provider mismatch"
|
||||
continue
|
||||
|
||||
# Validate auto-credentials (GoogleDriveFileField-based)
|
||||
# These have _credentials_id embedded in the file field data
|
||||
if auto_credentials_fields:
|
||||
for _kwarg_name, info in auto_credentials_fields.items():
|
||||
field_name = info["field_name"]
|
||||
# Check input_default and nodes_input_masks for the field value
|
||||
field_value = node.input_default.get(field_name)
|
||||
if nodes_input_masks and node.id in nodes_input_masks:
|
||||
field_value = nodes_input_masks[node.id].get(
|
||||
field_name, field_value
|
||||
)
|
||||
|
||||
# Use shared helper to parse the field
|
||||
parsed = parse_auto_credential_field(
|
||||
field_name=field_name,
|
||||
info=info,
|
||||
field_data=field_value,
|
||||
field_present_in_input=True, # For validation, assume present
|
||||
)
|
||||
|
||||
if parsed.error:
|
||||
has_missing_credentials = True
|
||||
credential_errors[node.id][field_name] = parsed.error
|
||||
continue
|
||||
|
||||
if parsed.cred_id:
|
||||
# Validate that credentials exist and are accessible
|
||||
try:
|
||||
creds_store = get_integration_credentials_store()
|
||||
creds = await creds_store.get_creds_by_id(
|
||||
user_id, parsed.cred_id
|
||||
)
|
||||
except Exception as e:
|
||||
has_missing_credentials = True
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Credentials not available: {e}"
|
||||
continue
|
||||
if not creds:
|
||||
has_missing_credentials = True
|
||||
credential_errors[node.id][field_name] = (
|
||||
"The saved credentials are not available "
|
||||
"for your account. Please re-select the file to "
|
||||
"authenticate with your own account."
|
||||
)
|
||||
|
||||
# If node has optional credentials and any are missing, mark for skipping
|
||||
# But only if there are no other errors for this node
|
||||
if (
|
||||
@@ -370,8 +498,9 @@ def make_node_credentials_input_map(
|
||||
"""
|
||||
result: dict[str, dict[str, JsonValue]] = {}
|
||||
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
# Only map regular credentials (not auto_credentials, which are resolved
|
||||
# at execution time from _credentials_id in file field data)
|
||||
graph_cred_inputs = graph.regular_credentials_inputs
|
||||
|
||||
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
||||
# Best-effort map: skip missing items
|
||||
|
||||
@@ -907,3 +907,335 @@ async def test_stop_graph_execution_cascades_to_child_with_reviews(
|
||||
|
||||
# Verify both parent and child status updates
|
||||
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for auto_credentials validation in _validate_node_input_credentials
|
||||
# (Fix 3: SECRT-1772 + Fix 4: Path 4)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_auto_creds_valid(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] When a node has auto_credentials with a valid _credentials_id
|
||||
that exists in the store, validation should pass without errors.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-auto-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "valid-cred-id",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
}
|
||||
}
|
||||
|
||||
mock_block = mocker.MagicMock()
|
||||
# No regular credentials fields
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||
# Has auto_credentials fields
|
||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
}
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Mock the credentials store to return valid credentials
|
||||
mock_store = mocker.MagicMock()
|
||||
mock_creds = mocker.MagicMock()
|
||||
mock_creds.id = "valid-cred-id"
|
||||
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=mock_creds)
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_integration_credentials_store",
|
||||
return_value=mock_store,
|
||||
)
|
||||
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
assert mock_node.id not in errors
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_auto_creds_missing(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] When a node has auto_credentials with a _credentials_id
|
||||
that doesn't exist for the current user, validation should report an error.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-bad-auto-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "other-users-cred-id",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
}
|
||||
}
|
||||
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
}
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Mock the credentials store to return None (cred not found for this user)
|
||||
mock_store = mocker.MagicMock()
|
||||
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=None)
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_integration_credentials_store",
|
||||
return_value=mock_store,
|
||||
)
|
||||
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="different-user",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
assert mock_node.id in errors
|
||||
assert "spreadsheet" in errors[mock_node.id]
|
||||
assert "not available" in errors[mock_node.id]["spreadsheet"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_both_regular_and_auto(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
[SECRT-1772] A node that has BOTH regular credentials AND auto_credentials
|
||||
should have both validated.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-both-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {
|
||||
"credentials": {
|
||||
"id": "regular-cred-id",
|
||||
"provider": "github",
|
||||
"type": "api_key",
|
||||
},
|
||||
"spreadsheet": {
|
||||
"_credentials_id": "auto-cred-id",
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
},
|
||||
}
|
||||
|
||||
mock_credentials_field_type = mocker.MagicMock()
|
||||
mock_credentials_meta = mocker.MagicMock()
|
||||
mock_credentials_meta.id = "regular-cred-id"
|
||||
mock_credentials_meta.provider = "github"
|
||||
mock_credentials_meta.type = "api_key"
|
||||
mock_credentials_field_type.model_validate.return_value = mock_credentials_meta
|
||||
|
||||
mock_block = mocker.MagicMock()
|
||||
# Regular credentials field
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||
"credentials": mock_credentials_field_type,
|
||||
}
|
||||
# Auto-credentials field
|
||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||
"auto_credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
}
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
# Mock the credentials store to return valid credentials for both
|
||||
mock_store = mocker.MagicMock()
|
||||
mock_regular_creds = mocker.MagicMock()
|
||||
mock_regular_creds.id = "regular-cred-id"
|
||||
mock_regular_creds.provider = "github"
|
||||
mock_regular_creds.type = "api_key"
|
||||
|
||||
mock_auto_creds = mocker.MagicMock()
|
||||
mock_auto_creds.id = "auto-cred-id"
|
||||
|
||||
def get_creds_side_effect(user_id, cred_id):
|
||||
if cred_id == "regular-cred-id":
|
||||
return mock_regular_creds
|
||||
elif cred_id == "auto-cred-id":
|
||||
return mock_auto_creds
|
||||
return None
|
||||
|
||||
mock_store.get_creds_by_id = mocker.AsyncMock(side_effect=get_creds_side_effect)
|
||||
mocker.patch(
|
||||
"backend.executor.utils.get_integration_credentials_store",
|
||||
return_value=mock_store,
|
||||
)
|
||||
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# Both should validate successfully - no errors
|
||||
assert mock_node.id not in errors
|
||||
assert mock_node.id not in nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_node_input_credentials_auto_creds_skipped_when_none(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
When a node has auto_credentials but the field value has _credentials_id=None
|
||||
(e.g., from upstream connection), validation should skip it without error.
|
||||
"""
|
||||
from backend.executor.utils import _validate_node_input_credentials
|
||||
|
||||
mock_node = mocker.MagicMock()
|
||||
mock_node.id = "node-with-chained-auto-creds"
|
||||
mock_node.credentials_optional = False
|
||||
mock_node.input_default = {
|
||||
"spreadsheet": {
|
||||
"_credentials_id": None,
|
||||
"id": "file-123",
|
||||
"name": "test.xlsx",
|
||||
}
|
||||
}
|
||||
|
||||
mock_block = mocker.MagicMock()
|
||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
||||
"credentials": {
|
||||
"field_name": "spreadsheet",
|
||||
"config": {"provider": "google", "type": "oauth2"},
|
||||
}
|
||||
}
|
||||
mock_node.block = mock_block
|
||||
|
||||
mock_graph = mocker.MagicMock()
|
||||
mock_graph.nodes = [mock_node]
|
||||
|
||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
nodes_input_masks=None,
|
||||
)
|
||||
|
||||
# No error - chained data with None cred_id is valid
|
||||
assert mock_node.id not in errors
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for CredentialsFieldInfo auto_credential tag (Fix 4: Path 4)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_credentials_field_info_auto_credential_tag():
|
||||
"""
|
||||
[Path 4] CredentialsFieldInfo should support is_auto_credential and
|
||||
input_field_name fields for distinguishing auto from regular credentials.
|
||||
"""
|
||||
from backend.data.model import CredentialsFieldInfo
|
||||
|
||||
# Regular credential should have is_auto_credential=False by default
|
||||
regular = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
assert regular.is_auto_credential is False
|
||||
assert regular.input_field_name is None
|
||||
|
||||
# Auto credential should have is_auto_credential=True
|
||||
auto = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["google"],
|
||||
"credentials_types": ["oauth2"],
|
||||
"is_auto_credential": True,
|
||||
"input_field_name": "spreadsheet",
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
assert auto.is_auto_credential is True
|
||||
assert auto.input_field_name == "spreadsheet"
|
||||
|
||||
|
||||
def test_make_node_credentials_input_map_excludes_auto_creds(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""
|
||||
[Path 4] make_node_credentials_input_map should only include regular credentials,
|
||||
not auto_credentials (which are resolved at execution time).
|
||||
"""
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.executor.utils import make_node_credentials_input_map
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Create a mock graph with aggregate_credentials_inputs that returns
|
||||
# both regular and auto credentials
|
||||
mock_graph = mocker.MagicMock()
|
||||
|
||||
regular_field_info = CredentialsFieldInfo.model_validate(
|
||||
{
|
||||
"credentials_provider": ["github"],
|
||||
"credentials_types": ["api_key"],
|
||||
"is_auto_credential": False,
|
||||
},
|
||||
by_alias=True,
|
||||
)
|
||||
|
||||
# Mock regular_credentials_inputs property (auto_credentials are excluded)
|
||||
mock_graph.regular_credentials_inputs = {
|
||||
"github_creds": (regular_field_info, {("node-1", "credentials")}, True),
|
||||
}
|
||||
|
||||
graph_credentials_input = {
|
||||
"github_creds": CredentialsMetaInput(
|
||||
id="cred-123",
|
||||
provider=ProviderName("github"),
|
||||
type="api_key",
|
||||
),
|
||||
}
|
||||
|
||||
result = make_node_credentials_input_map(mock_graph, graph_credentials_input)
|
||||
|
||||
# Regular credentials should be mapped
|
||||
assert "node-1" in result
|
||||
assert "credentials" in result["node-1"]
|
||||
|
||||
# Auto credentials should NOT appear in the result
|
||||
# (they would have been mapped to the kwarg_name "credentials" not "spreadsheet")
|
||||
for node_id, fields in result.items():
|
||||
for field_name, value in fields.items():
|
||||
# Verify no auto-credential phantom entries
|
||||
if isinstance(value, dict):
|
||||
assert "_credentials_id" not in value
|
||||
|
||||
@@ -342,14 +342,6 @@ async def store_media_file(
|
||||
if not target_path.is_file():
|
||||
raise ValueError(f"Local file does not exist: {target_path}")
|
||||
|
||||
# Virus scan the local file before any further processing
|
||||
local_content = target_path.read_bytes()
|
||||
if len(local_content) > MAX_FILE_SIZE_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||
)
|
||||
await scan_content_safe(local_content, filename=sanitized_file)
|
||||
|
||||
# Return based on requested format
|
||||
if return_format == "for_local_processing":
|
||||
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
||||
|
||||
@@ -247,100 +247,3 @@ class TestFileCloudIntegration:
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_media_file_local_path_scanned(self):
|
||||
"""Test that local file paths are scanned for viruses."""
|
||||
graph_exec_id = "test-exec-123"
|
||||
local_file = "test_video.mp4"
|
||||
file_content = b"fake video content"
|
||||
|
||||
with patch(
|
||||
"backend.util.file.get_cloud_storage_handler"
|
||||
) as mock_handler_getter, patch(
|
||||
"backend.util.file.scan_content_safe"
|
||||
) as mock_scan, patch(
|
||||
"backend.util.file.Path"
|
||||
) as mock_path_class:
|
||||
|
||||
# Mock cloud storage handler - not a cloud path
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.is_cloud_path.return_value = False
|
||||
mock_handler_getter.return_value = mock_handler
|
||||
|
||||
# Mock virus scanner
|
||||
mock_scan.return_value = None
|
||||
|
||||
# Mock file system operations
|
||||
mock_base_path = MagicMock()
|
||||
mock_target_path = MagicMock()
|
||||
mock_resolved_path = MagicMock()
|
||||
|
||||
mock_path_class.return_value = mock_base_path
|
||||
mock_base_path.mkdir = MagicMock()
|
||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
||||
mock_target_path.resolve.return_value = mock_resolved_path
|
||||
mock_resolved_path.is_relative_to.return_value = True
|
||||
mock_resolved_path.is_file.return_value = True
|
||||
mock_resolved_path.read_bytes.return_value = file_content
|
||||
mock_resolved_path.relative_to.return_value = Path(local_file)
|
||||
mock_resolved_path.name = local_file
|
||||
|
||||
result = await store_media_file(
|
||||
file=MediaFileType(local_file),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
# Verify virus scan was called for local file
|
||||
mock_scan.assert_called_once_with(file_content, filename=local_file)
|
||||
|
||||
# Result should be the relative path
|
||||
assert str(result) == local_file
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_media_file_local_path_virus_detected(self):
|
||||
"""Test that infected local files raise VirusDetectedError."""
|
||||
from backend.api.features.store.exceptions import VirusDetectedError
|
||||
|
||||
graph_exec_id = "test-exec-123"
|
||||
local_file = "infected.exe"
|
||||
file_content = b"malicious content"
|
||||
|
||||
with patch(
|
||||
"backend.util.file.get_cloud_storage_handler"
|
||||
) as mock_handler_getter, patch(
|
||||
"backend.util.file.scan_content_safe"
|
||||
) as mock_scan, patch(
|
||||
"backend.util.file.Path"
|
||||
) as mock_path_class:
|
||||
|
||||
# Mock cloud storage handler - not a cloud path
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.is_cloud_path.return_value = False
|
||||
mock_handler_getter.return_value = mock_handler
|
||||
|
||||
# Mock virus scanner to detect virus
|
||||
mock_scan.side_effect = VirusDetectedError(
|
||||
"EICAR-Test-File", "File rejected due to virus detection"
|
||||
)
|
||||
|
||||
# Mock file system operations
|
||||
mock_base_path = MagicMock()
|
||||
mock_target_path = MagicMock()
|
||||
mock_resolved_path = MagicMock()
|
||||
|
||||
mock_path_class.return_value = mock_base_path
|
||||
mock_base_path.mkdir = MagicMock()
|
||||
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
||||
mock_target_path.resolve.return_value = mock_resolved_path
|
||||
mock_resolved_path.is_relative_to.return_value = True
|
||||
mock_resolved_path.is_file.return_value = True
|
||||
mock_resolved_path.read_bytes.return_value = file_content
|
||||
|
||||
with pytest.raises(VirusDetectedError):
|
||||
await store_media_file(
|
||||
file=MediaFileType(local_file),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
7101
autogpt_platform/backend/poetry.lock
generated
7101
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -12,17 +12,16 @@ python = ">=3.10,<3.14"
|
||||
aio-pika = "^9.5.5"
|
||||
aiohttp = "^3.10.0"
|
||||
aiodns = "^3.5.0"
|
||||
anthropic = "^0.79.0"
|
||||
anthropic = "^0.59.0"
|
||||
apscheduler = "^3.11.1"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
claude-agent-sdk = "^0.1.0"
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
cryptography = "^45.0"
|
||||
discord-py = "^2.5.2"
|
||||
e2b-code-interpreter = "^1.5.2"
|
||||
elevenlabs = "^1.50.0"
|
||||
fastapi = "^0.128.5"
|
||||
fastapi = "^0.116.1"
|
||||
feedparser = "^6.0.11"
|
||||
flake8 = "^7.3.0"
|
||||
google-api-python-client = "^2.177.0"
|
||||
@@ -36,10 +35,10 @@ jinja2 = "^3.1.6"
|
||||
jsonref = "^1.1.0"
|
||||
jsonschema = "^4.25.0"
|
||||
langfuse = "^3.11.0"
|
||||
launchdarkly-server-sdk = "^9.14.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
mem0ai = "^0.1.115"
|
||||
moviepy = "^2.1.2"
|
||||
ollama = "^0.6.1"
|
||||
ollama = "^0.5.1"
|
||||
openai = "^1.97.1"
|
||||
orjson = "^3.10.0"
|
||||
pika = "^1.3.2"
|
||||
@@ -49,16 +48,16 @@ postmarker = "^1.0"
|
||||
praw = "~7.8.1"
|
||||
prisma = "^0.15.0"
|
||||
rank-bm25 = "^0.2.2"
|
||||
prometheus-client = "^0.24.1"
|
||||
prometheus-client = "^0.22.1"
|
||||
prometheus-fastapi-instrumentator = "^7.0.0"
|
||||
psutil = "^7.0.0"
|
||||
psycopg2-binary = "^2.9.10"
|
||||
pydantic = { extras = ["email"], version = "^2.12.5" }
|
||||
pydantic-settings = "^2.12.0"
|
||||
pydantic = { extras = ["email"], version = "^2.11.7" }
|
||||
pydantic-settings = "^2.10.1"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
python-dotenv = "^1.1.1"
|
||||
python-multipart = "^0.0.22"
|
||||
python-multipart = "^0.0.20"
|
||||
redis = "^6.2.0"
|
||||
regex = "^2025.9.18"
|
||||
replicate = "^1.0.6"
|
||||
@@ -66,11 +65,11 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
|
||||
sqlalchemy = "^2.0.40"
|
||||
strenum = "^0.4.9"
|
||||
stripe = "^11.5.0"
|
||||
supabase = "2.27.3"
|
||||
tenacity = "^9.1.4"
|
||||
supabase = "2.17.0"
|
||||
tenacity = "^9.1.2"
|
||||
todoist-api-python = "^2.1.7"
|
||||
tweepy = "^4.16.0"
|
||||
uvicorn = { extras = ["standard"], version = "^0.40.0" }
|
||||
uvicorn = { extras = ["standard"], version = "^0.35.0" }
|
||||
websockets = "^15.0"
|
||||
youtube-transcript-api = "^1.2.1"
|
||||
yt-dlp = "2025.12.08"
|
||||
@@ -78,7 +77,7 @@ zerobouncesdk = "^1.1.2"
|
||||
# NOTE: please insert new dependencies in their alphabetical location
|
||||
pytest-snapshot = "^0.9.0"
|
||||
aiofiles = "^24.1.0"
|
||||
tiktoken = "^0.12.0"
|
||||
tiktoken = "^0.9.0"
|
||||
aioclamd = "^1.0.0"
|
||||
setuptools = "^80.9.0"
|
||||
gcloud-aio-storage = "^9.5.0"
|
||||
@@ -96,13 +95,13 @@ black = "^24.10.0"
|
||||
faker = "^38.2.0"
|
||||
httpx = "^0.28.1"
|
||||
isort = "^5.13.2"
|
||||
poethepoet = "^0.41.0"
|
||||
poethepoet = "^0.37.0"
|
||||
pre-commit = "^4.4.0"
|
||||
pyright = "^1.1.407"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-watcher = "^0.6.3"
|
||||
pytest-watcher = "^0.4.2"
|
||||
requests = "^2.32.5"
|
||||
ruff = "^0.15.0"
|
||||
ruff = "^0.14.5"
|
||||
# NOTE: please insert new dependencies in their alphabetical location
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -102,7 +102,7 @@
|
||||
"react-markdown": "9.0.3",
|
||||
"react-modal": "3.16.3",
|
||||
"react-shepherd": "6.1.9",
|
||||
"react-window": "2.2.0",
|
||||
"react-window": "1.8.11",
|
||||
"recharts": "3.3.0",
|
||||
"rehype-autolink-headings": "7.1.0",
|
||||
"rehype-highlight": "7.0.2",
|
||||
@@ -140,7 +140,7 @@
|
||||
"@types/react": "18.3.17",
|
||||
"@types/react-dom": "18.3.5",
|
||||
"@types/react-modal": "3.16.3",
|
||||
"@types/react-window": "2.0.0",
|
||||
"@types/react-window": "1.8.8",
|
||||
"@vitejs/plugin-react": "5.1.2",
|
||||
"axe-playwright": "2.2.2",
|
||||
"chromatic": "13.3.3",
|
||||
|
||||
38
autogpt_platform/frontend/pnpm-lock.yaml
generated
38
autogpt_platform/frontend/pnpm-lock.yaml
generated
@@ -228,8 +228,8 @@ importers:
|
||||
specifier: 6.1.9
|
||||
version: 6.1.9(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(typescript@5.9.3)
|
||||
react-window:
|
||||
specifier: 2.2.0
|
||||
version: 2.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
specifier: 1.8.11
|
||||
version: 1.8.11(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
recharts:
|
||||
specifier: 3.3.0
|
||||
version: 3.3.0(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react-is@18.3.1)(react@18.3.1)(redux@5.0.1)
|
||||
@@ -337,8 +337,8 @@ importers:
|
||||
specifier: 3.16.3
|
||||
version: 3.16.3
|
||||
'@types/react-window':
|
||||
specifier: 2.0.0
|
||||
version: 2.0.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
specifier: 1.8.8
|
||||
version: 1.8.8
|
||||
'@vitejs/plugin-react':
|
||||
specifier: 5.1.2
|
||||
version: 5.1.2(vite@7.3.1(@types/node@24.10.0)(jiti@2.6.1)(terser@5.44.1)(yaml@2.8.2))
|
||||
@@ -3469,9 +3469,8 @@ packages:
|
||||
'@types/react-modal@3.16.3':
|
||||
resolution: {integrity: sha512-xXuGavyEGaFQDgBv4UVm8/ZsG+qxeQ7f77yNrW3n+1J6XAstUy5rYHeIHPh1KzsGc6IkCIdu6lQ2xWzu1jBTLg==}
|
||||
|
||||
'@types/react-window@2.0.0':
|
||||
resolution: {integrity: sha512-E8hMDtImEpMk1SjswSvqoSmYvk7GEtyVaTa/GJV++FdDNuMVVEzpAClyJ0nqeKYBrMkGiyH6M1+rPLM0Nu1exQ==}
|
||||
deprecated: This is a stub types definition. react-window provides its own type definitions, so you do not need this installed.
|
||||
'@types/react-window@1.8.8':
|
||||
resolution: {integrity: sha512-8Ls660bHR1AUA2kuRvVG9D/4XpRC6wjAaPT9dil7Ckc76eP9TKWZwwmgfq8Q1LANX3QNDnoU4Zp48A3w+zK69Q==}
|
||||
|
||||
'@types/react@18.3.17':
|
||||
resolution: {integrity: sha512-opAQ5no6LqJNo9TqnxBKsgnkIYHozW9KSTlFVoSUJYh1Fl/sswkEoqIugRSm7tbh6pABtYjGAjW+GOS23j8qbw==}
|
||||
@@ -5977,6 +5976,9 @@ packages:
|
||||
resolution: {integrity: sha512-UERzLsxzllchadvbPs5aolHh65ISpKpM+ccLbOJ8/vvpBKmAWf+la7dXFy7Mr0ySHbdHrFv5kGFCUHHe6GFEmw==}
|
||||
engines: {node: '>= 4.0.0'}
|
||||
|
||||
memoize-one@5.2.1:
|
||||
resolution: {integrity: sha512-zYiwtZUcYyXKo/np96AGZAckk+FWWsUdJ3cHGGmld7+AhvcWmQyGCYUh1hc4Q/pkOhb65dQR/pqCyK0cOaHz4Q==}
|
||||
|
||||
merge-stream@2.0.0:
|
||||
resolution: {integrity: sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==}
|
||||
|
||||
@@ -6889,11 +6891,12 @@ packages:
|
||||
'@types/react':
|
||||
optional: true
|
||||
|
||||
react-window@2.2.0:
|
||||
resolution: {integrity: sha512-Y2L7yonHq6K1pQA2P98wT5QdIsEcjBTB7T8o6Mub12hH9eYppXoYu6vgClmcjlh3zfNcW2UrXiJJJqDxUY7GVw==}
|
||||
react-window@1.8.11:
|
||||
resolution: {integrity: sha512-+SRbUVT2scadgFSWx+R1P754xHPEqvcfSfVX10QYg6POOz+WNgkN48pS+BtZNIMGiL1HYrSEiCkwsMS15QogEQ==}
|
||||
engines: {node: '>8.0.0'}
|
||||
peerDependencies:
|
||||
react: ^18.0.0 || ^19.0.0
|
||||
react-dom: ^18.0.0 || ^19.0.0
|
||||
react: ^15.0.0 || ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
react-dom: ^15.0.0 || ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
|
||||
react@18.3.1:
|
||||
resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==}
|
||||
@@ -11600,12 +11603,9 @@ snapshots:
|
||||
dependencies:
|
||||
'@types/react': 18.3.17
|
||||
|
||||
'@types/react-window@2.0.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
|
||||
'@types/react-window@1.8.8':
|
||||
dependencies:
|
||||
react-window: 2.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- react
|
||||
- react-dom
|
||||
'@types/react': 18.3.17
|
||||
|
||||
'@types/react@18.3.17':
|
||||
dependencies:
|
||||
@@ -14545,6 +14545,8 @@ snapshots:
|
||||
dependencies:
|
||||
fs-monkey: 1.1.0
|
||||
|
||||
memoize-one@5.2.1: {}
|
||||
|
||||
merge-stream@2.0.0: {}
|
||||
|
||||
merge2@1.4.1: {}
|
||||
@@ -15590,8 +15592,10 @@ snapshots:
|
||||
optionalDependencies:
|
||||
'@types/react': 18.3.17
|
||||
|
||||
react-window@2.2.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
react-window@1.8.11(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||
dependencies:
|
||||
'@babel/runtime': 7.28.4
|
||||
memoize-one: 5.2.1
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
|
||||
|
||||
@@ -12307,9 +12307,7 @@
|
||||
"title": "Location"
|
||||
},
|
||||
"msg": { "type": "string", "title": "Message" },
|
||||
"type": { "type": "string", "title": "Error Type" },
|
||||
"input": { "title": "Input" },
|
||||
"ctx": { "type": "object", "title": "Context" }
|
||||
"type": { "type": "string", "title": "Error Type" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["loc", "msg", "type"],
|
||||
|
||||
@@ -4,7 +4,9 @@ import { loadScript } from "@/services/scripts/scripts";
|
||||
export async function loadGoogleAPIPicker(): Promise<void> {
|
||||
validateWindow();
|
||||
|
||||
await loadScript("https://apis.google.com/js/api.js");
|
||||
await loadScript("https://apis.google.com/js/api.js", {
|
||||
referrerPolicy: "no-referrer-when-downgrade",
|
||||
});
|
||||
|
||||
const googleAPI = window.gapi;
|
||||
if (!googleAPI) {
|
||||
@@ -27,7 +29,9 @@ export async function loadGoogleIdentityServices(): Promise<void> {
|
||||
throw new Error("Google Identity Services cannot load on server");
|
||||
}
|
||||
|
||||
await loadScript("https://accounts.google.com/gsi/client");
|
||||
await loadScript("https://accounts.google.com/gsi/client", {
|
||||
referrerPolicy: "no-referrer-when-downgrade",
|
||||
});
|
||||
|
||||
const google = window.google;
|
||||
if (!google?.accounts?.oauth2) {
|
||||
|
||||
@@ -4,7 +4,7 @@ import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Bell, MagnifyingGlass, X } from "@phosphor-icons/react";
|
||||
import { List, type RowComponentProps } from "react-window";
|
||||
import { FixedSizeList as List } from "react-window";
|
||||
import { AgentExecutionWithInfo } from "../../helpers";
|
||||
import { ActivityItem } from "../ActivityItem";
|
||||
import styles from "./styles.module.css";
|
||||
@@ -19,16 +19,14 @@ interface Props {
|
||||
recentFailures: AgentExecutionWithInfo[];
|
||||
}
|
||||
|
||||
interface ActivityRowProps {
|
||||
executions: AgentExecutionWithInfo[];
|
||||
interface VirtualizedItemProps {
|
||||
index: number;
|
||||
style: React.CSSProperties;
|
||||
data: AgentExecutionWithInfo[];
|
||||
}
|
||||
|
||||
function VirtualizedActivityItem({
|
||||
index,
|
||||
style,
|
||||
executions,
|
||||
}: RowComponentProps<ActivityRowProps>) {
|
||||
const execution = executions[index];
|
||||
function VirtualizedActivityItem({ index, style, data }: VirtualizedItemProps) {
|
||||
const execution = data[index];
|
||||
return (
|
||||
<div style={style}>
|
||||
<ActivityItem execution={execution} />
|
||||
@@ -131,13 +129,14 @@ export function ActivityDropdown({
|
||||
>
|
||||
{filteredExecutions.length > 0 ? (
|
||||
<List
|
||||
defaultHeight={listHeight}
|
||||
rowCount={filteredExecutions.length}
|
||||
rowHeight={itemHeight}
|
||||
rowProps={{ executions: filteredExecutions }}
|
||||
rowComponent={VirtualizedActivityItem}
|
||||
style={{ width: 320, height: listHeight }}
|
||||
/>
|
||||
height={listHeight}
|
||||
width={320} // Match dropdown width (w-80 = 20rem = 320px)
|
||||
itemCount={filteredExecutions.length}
|
||||
itemSize={itemHeight}
|
||||
itemData={filteredExecutions}
|
||||
>
|
||||
{VirtualizedActivityItem}
|
||||
</List>
|
||||
) : (
|
||||
<div className="flex h-full flex-col items-center justify-center gap-5 pb-8 pt-6">
|
||||
<div className="mx-auto inline-flex flex-col items-center justify-center rounded-full bg-bgLightGrey p-6">
|
||||
|
||||
Reference in New Issue
Block a user