Compare commits

..

6 Commits

Author SHA1 Message Date
Otto
562cf04ab6 refactor(backend): extract shared auto-credential parsing to utils.py
Addresses review feedback on #12004:
- Added AutoCredentialFieldInfo dataclass and parse_auto_credential_field()
  helper to executor/utils.py
- Updated _acquire_auto_credentials in manager.py to use shared helper
- Updated _validate_node_input_credentials in utils.py to use shared helper

This consolidates the duplicate logic for parsing GoogleDriveFileField-style
auto-credential fields, making manager.py less cluttered while ensuring
consistent validation/acquisition behavior.
2026-02-09 07:57:36 +00:00
Nicholas Tindle
90b3b5ba16 fix(backend): Fix misplaced section header in graph_test.py
Move the _reassign_ids section comment to above the actual _reassign_ids
tests, and label the combine() tests correctly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 16:11:47 -06:00
Nicholas Tindle
f4f81bc4fc fix(backend): Remove _credentials_id key on fork instead of setting to None
Setting _credentials_id to None on fork was ambiguous — both "forked,
needs re-auth" and "chained data from upstream" were represented as None.
This caused _acquire_auto_credentials to silently skip credential
acquisition for forked agents, leading to confusing TypeErrors at runtime.

Now the key is deleted entirely, making the three states unambiguous:
- Present with value: user-selected credentials
- Present as None: chained data from upstream block
- Absent: forked/needs re-authentication

Also adds pre-run validation for the missing key case and makes error
messages provider-agnostic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 17:34:16 -06:00
Nicholas Tindle
c5abc01f25 fix(backend): Add error handling for auto-credentials store lookup
Wrap get_creds_by_id call in try/except in the auto-credentials
validation path to match the error handling pattern used for regular
credentials.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 16:53:29 -06:00
Nicholas Tindle
8b7053c1de merge: Resolve conflicts with dev (PR #11986 graph model refactor)
Adapt auto-credentials filtering to dev's refactored graph model:
- aggregate_credentials_inputs() now returns 3-tuples (field_info, node_pairs, is_required)
- credentials_input_schema moved to GraphModel, builds JSON schema directly
- Update regular/auto_credentials_inputs properties for 3-tuple format
- Update test mocks and assertions for new tuple format and class hierarchy

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 16:39:57 -06:00
Nicholas Tindle
e00c1202ad fix(platform): Fix Google Drive auto-credentials handling across the platform
- Tag auto-credentials with `is_auto_credential` and `input_field_name` on `CredentialsFieldInfo` to distinguish them from regular user-provided credentials
- Add `regular_credentials_inputs` and `auto_credentials_inputs` properties to `Graph` so UI schemas, CoPilot, and library presets only surface regular credentials
- Extract `_acquire_auto_credentials()` helper in executor to resolve embedded `_credentials_id` at execution time with proper lock management
- Validate auto-credentials ownership in `_validate_node_input_credentials()` to catch stale/missing credentials before execution
- Clear `_credentials_id` in `_reassign_ids()` on graph fork so cloned agents require re-authentication
- Propagate `is_auto_credential` through `combine()` and `discriminate()` on `CredentialsFieldInfo`
- Add `referrerPolicy: "no-referrer-when-downgrade"` to Google API script loading to fix Firefox API key validation
- Comprehensive test coverage for all new behavior

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 16:08:53 -06:00
200 changed files with 13795 additions and 17904 deletions

View File

@@ -49,7 +49,7 @@ jobs:
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
if: github.event_name == 'push'
uses: peter-evans/create-pull-request@v8
uses: peter-evans/create-pull-request@v7
with:
add-paths: classic/frontend/build/web
base: ${{ github.ref_name }}

View File

@@ -42,7 +42,7 @@ jobs:
- name: Get CI failure details
id: failure_details
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const run = await github.rest.actions.getWorkflowRun({

View File

@@ -41,7 +41,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -78,7 +78,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"
@@ -91,7 +91,7 @@ jobs:
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -124,7 +124,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes
@@ -309,7 +309,6 @@ jobs:
uses: anthropics/claude-code-action@v1
with:
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
allowed_bots: "dependabot[bot]"
claude_args: |
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
prompt: |

View File

@@ -57,7 +57,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -94,7 +94,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"
@@ -107,7 +107,7 @@ jobs:
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -140,7 +140,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes

View File

@@ -39,7 +39,7 @@ jobs:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
@@ -76,7 +76,7 @@ jobs:
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22"
@@ -89,7 +89,7 @@ jobs:
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
@@ -132,7 +132,7 @@ jobs:
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes

View File

@@ -33,7 +33,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -33,7 +33,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -38,7 +38,7 @@ jobs:
python-version: "3.11"
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -88,7 +88,7 @@ jobs:
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

View File

@@ -17,7 +17,7 @@ jobs:
- name: Check comment permissions and deployment status
id: check_status
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const commentBody = context.payload.comment.body.trim();
@@ -55,7 +55,7 @@ jobs:
- name: Post permission denied comment
if: steps.check_status.outputs.permission_denied == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -68,7 +68,7 @@ jobs:
- name: Get PR details for deployment
id: pr_details
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const pr = await github.rest.pulls.get({
@@ -98,7 +98,7 @@ jobs:
- name: Post deploy success comment
if: steps.check_status.outputs.should_deploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -126,7 +126,7 @@ jobs:
- name: Post undeploy success comment
if: steps.check_status.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({
@@ -139,7 +139,7 @@ jobs:
- name: Check deployment status on PR close
id: check_pr_close
if: github.event_name == 'pull_request' && github.event.action == 'closed'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
const comments = await github.rest.issues.listComments({
@@ -187,7 +187,7 @@ jobs:
github.event_name == 'pull_request' &&
github.event.action == 'closed' &&
steps.check_pr_close.outputs.should_undeploy == 'true'
uses: actions/github-script@v8
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.createComment({

View File

@@ -42,7 +42,7 @@ jobs:
- 'autogpt_platform/frontend/src/components/**'
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -54,7 +54,7 @@ jobs:
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
@@ -74,7 +74,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -82,7 +82,7 @@ jobs:
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
@@ -112,7 +112,7 @@ jobs:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -120,7 +120,7 @@ jobs:
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
@@ -153,7 +153,7 @@ jobs:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -176,7 +176,7 @@ jobs:
uses: docker/setup-buildx-action@v3
- name: Cache Docker layers
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
@@ -231,7 +231,7 @@ jobs:
fi
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
@@ -282,7 +282,7 @@ jobs:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -290,7 +290,7 @@ jobs:
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}

View File

@@ -32,7 +32,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -44,7 +44,7 @@ jobs:
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
@@ -68,7 +68,7 @@ jobs:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: "22.18.0"
@@ -88,7 +88,7 @@ jobs:
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v5
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}

File diff suppressed because it is too large Load Diff

View File

@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^46.0"
cryptography = "^45.0"
expiringdict = "^1.2.2"
fastapi = "^0.128.0"
google-cloud-logging = "^3.13.0"
launchdarkly-server-sdk = "^9.14.1"
pydantic = "^2.12.5"
pydantic-settings = "^2.12.0"
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
fastapi = "^0.116.1"
google-cloud-logging = "^3.12.1"
launchdarkly-server-sdk = "^9.12.0"
pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.27.2"
uvicorn = "^0.40.0"
supabase = "^2.16.0"
uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies]
pyright = "^1.1.408"
pyright = "^1.1.404"
pytest = "^8.4.1"
pytest-asyncio = "^1.3.0"
pytest-mock = "^3.15.1"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
pytest-cov = "^6.2.1"
ruff = "^0.15.0"
ruff = "^0.12.11"
[build-system]
requires = ["poetry-core"]

View File

@@ -45,7 +45,10 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
return await PrismaChatSession.prisma().create(data=data)
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
async def update_chat_session(

View File

@@ -18,10 +18,6 @@ class ResponseType(str, Enum):
START = "start"
FINISH = "finish"
# Step lifecycle (one LLM API call within a message)
START_STEP = "start-step"
FINISH_STEP = "finish-step"
# Text streaming
TEXT_START = "text-start"
TEXT_DELTA = "text-delta"
@@ -61,16 +57,6 @@ class StreamStart(StreamBaseResponse):
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
}
return f"data: {json.dumps(data)}\n\n"
class StreamFinish(StreamBaseResponse):
"""End of message/stream."""
@@ -78,26 +64,6 @@ class StreamFinish(StreamBaseResponse):
type: ResponseType = ResponseType.FINISH
class StreamStartStep(StreamBaseResponse):
"""Start of a step (one LLM API call within a message).
The AI SDK uses this to add a step-start boundary to message.parts,
enabling visual separation between multiple LLM calls in a single message.
"""
type: ResponseType = ResponseType.START_STEP
class StreamFinishStep(StreamBaseResponse):
"""End of a step (one LLM API call within a message).
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
so the next LLM call in a tool-call continuation starts with clean state.
"""
type: ResponseType = ResponseType.FINISH_STEP
# ========== Text Streaming ==========
@@ -151,7 +117,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
toolCallId: str = Field(..., description="Tool call ID this responds to")
output: str | dict[str, Any] = Field(..., description="Tool execution output")
# Keep these for internal backend use
# Additional fields for internal use (not part of AI SDK spec but useful)
toolName: str | None = Field(
default=None, description="Name of the tool that was executed"
)
@@ -159,17 +125,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded"
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,
"output": self.output,
}
return f"data: {json.dumps(data)}\n\n"
# ========== Other ==========

View File

@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
@@ -17,29 +17,7 @@ from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
AgentSavedResponse,
AgentsFoundResponse,
BlockListResponse,
BlockOutputResponse,
ClarificationNeededResponse,
DocPageResponse,
DocSearchResultsResponse,
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig()
@@ -306,6 +284,10 @@ async def stream_chat_post(
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
@@ -313,7 +295,6 @@ async def stream_chat_post(
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
):
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
@@ -393,69 +374,63 @@ async def stream_chat_post(
@router.get(
"/sessions/{session_id}/stream",
)
async def resume_session_stream(
async def stream_chat_get(
session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
):
"""
Resume an active stream for a session.
Stream chat responses for a session (GET - legacy endpoint).
Called by the AI SDK's ``useChat(resume: true)`` on page load.
Checks for an active (in-progress) task on the session and either replays
the full SSE stream or returns 204 No Content if nothing is running.
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
Args:
session_id: The chat session identifier.
session_id: The chat session identifier to associate with the streamed messages.
message: The user's new message to process.
user_id: Optional authenticated user ID.
is_user_message: Whether the message is a user message.
Returns:
StreamingResponse (SSE) when an active stream exists,
or 204 No Content when there is nothing to resume.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
active_task, _last_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return Response(status_code=204)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
user_id=user_id,
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
if subscriber_queue is None:
return Response(status_code=204)
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
try:
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
yield chunk.to_sse()
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass
except Exception as e:
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
message,
is_user_message=is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
exc_info=True,
)
yield "data: [DONE]\n\n"
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -463,8 +438,8 @@ async def resume_session_stream(
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
)
@@ -776,42 +751,3 @@ async def health_check() -> dict:
"service": "chat",
"version": "0.1.0",
}
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
ToolResponseUnion = (
AgentsFoundResponse
| NoResultsResponse
| AgentDetailsResponse
| SetupRequirementsResponse
| ExecutionStartedResponse
| NeedLoginResponse
| ErrorResponse
| InputValidationErrorResponse
| AgentOutputResponse
| UnderstandingUpdatedResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| BlockListResponse
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)
@router.get(
"/schema/tool-responses",
response_model=ToolResponseUnion,
include_in_schema=True,
summary="[Dummy] Tool response type export for codegen",
description="This endpoint is not meant to be called. It exists solely to "
"expose tool response models in the OpenAPI schema for frontend codegen.",
)
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
"""Never called at runtime. Exists only so Orval generates TS types."""
raise HTTPException(status_code=501, detail="Schema-only endpoint")

View File

@@ -52,10 +52,8 @@ from .response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -353,10 +351,6 @@ async def stream_chat_completion(
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None, # {url: str, content: str}
_continuation_message_id: (
str | None
) = None, # Internal: reuse message ID for tool call continuations
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Main entry point for streaming chat completions with database handling.
@@ -377,45 +371,21 @@ async def stream_chat_completion(
ValueError: If max_context_messages is exceeded
"""
completion_start = time.monotonic()
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
extra={
"json_fields": {
**log_meta,
"message_len": len(message) if message else 0,
"is_user_message": is_user_message,
}
},
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
)
# Only fetch from Redis if session not provided (initial call)
if session is None:
fetch_start = time.monotonic()
session = await get_chat_session(session_id, user_id)
fetch_time = (time.monotonic() - fetch_start) * 1000
logger.info(
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
f"n_messages={len(session.messages) if session else 0}",
extra={
"json_fields": {
**log_meta,
"duration_ms": fetch_time,
"n_messages": len(session.messages) if session else 0,
}
},
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
f"message_count={len(session.messages) if session else 0}"
)
else:
logger.info(
f"[TIMING] Using provided session, messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
f"Using provided session object: {session.session_id}, "
f"message_count={len(session.messages)}"
)
if not session:
@@ -436,25 +406,17 @@ async def stream_chat_completion(
# Track user message in PostHog
if is_user_message:
posthog_start = time.monotonic()
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(message),
)
posthog_time = (time.monotonic() - posthog_start) * 1000
logger.info(
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
)
upsert_start = time.monotonic()
session = await upsert_chat_session(session)
upsert_time = (time.monotonic() - upsert_start) * 1000
logger.info(
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
f"Upserting session: {session.session_id} with user id {session.user_id}, "
f"message_count={len(session.messages)}"
)
session = await upsert_chat_session(session)
assert session, "Session not found"
# Generate title for new sessions on first user message (non-blocking)
@@ -492,13 +454,7 @@ async def stream_chat_completion(
asyncio.create_task(_update_title())
# Build system prompt with business understanding
prompt_start = time.monotonic()
system_prompt, understanding = await _build_system_prompt(user_id)
prompt_time = (time.monotonic() - prompt_start) * 1000
logger.info(
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
)
# Initialize variables for streaming
assistant_response = ChatMessage(
@@ -523,27 +479,13 @@ async def stream_chat_completion(
# Generate unique IDs for AI SDK protocol
import uuid as uuid_module
is_continuation = _continuation_message_id is not None
message_id = _continuation_message_id or str(uuid_module.uuid4())
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Only yield message start for the initial call, not for continuations.
setup_time = (time.monotonic() - completion_start) * 1000
logger.info(
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
if not is_continuation:
yield StreamStart(messageId=message_id, taskId=_task_id)
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
yield StreamStartStep()
# Yield message start
yield StreamStart(messageId=message_id)
try:
logger.info(
"[TIMING] Calling _stream_chat_chunks",
extra={"json_fields": log_meta},
)
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
@@ -643,10 +585,6 @@ async def stream_chat_completion(
)
yield chunk
elif isinstance(chunk, StreamFinish):
if has_done_tool_call:
# Tool calls happened — close the step but don't send message-level finish.
# The continuation will open a new step, and finish will come at the end.
yield StreamFinishStep()
if not has_done_tool_call:
# Emit text-end before finish if we received text but haven't closed it
if has_received_text and not text_streaming_ended:
@@ -678,8 +616,6 @@ async def stream_chat_completion(
has_saved_assistant_message = True
has_yielded_end = True
# Emit finish-step before finish (resets AI SDK text/reasoning state)
yield StreamFinishStep()
yield chunk
elif isinstance(chunk, StreamError):
has_yielded_error = True
@@ -764,7 +700,6 @@ async def stream_chat_completion(
error_response = StreamError(errorText=error_message)
yield error_response
if not has_yielded_end:
yield StreamFinishStep()
yield StreamFinish()
return
@@ -779,8 +714,6 @@ async def stream_chat_completion(
retry_count=retry_count + 1,
session=session,
context=context,
_continuation_message_id=message_id, # Reuse message ID since start was already sent
_task_id=_task_id,
):
yield chunk
return # Exit after retry to avoid double-saving in finally block
@@ -850,8 +783,6 @@ async def stream_chat_completion(
session=session, # Pass session object to avoid Redis refetch
context=context,
tool_call_response=str(tool_response_messages),
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
_task_id=_task_id,
):
yield chunk
@@ -962,21 +893,9 @@ async def _stream_chat_chunks(
SSE formatted JSON response objects
"""
import time as time_module
stream_chunks_start = time_module.perf_counter()
model = config.model
# Build log metadata for structured logging
log_meta = {"component": "ChatService", "session_id": session.session_id}
if session.user_id:
log_meta["user_id"] = session.user_id
logger.info(
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
f"user={session.user_id}, n_messages={len(session.messages)}",
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
)
logger.info("Starting pure chat stream")
messages = session.to_openai_messages()
if system_prompt:
@@ -987,18 +906,12 @@ async def _stream_chat_chunks(
messages = [system_message] + messages
# Apply context window management
context_start = time_module.perf_counter()
context_result = await _manage_context_window(
messages=messages,
model=model,
api_key=config.api_key,
base_url=config.base_url,
)
context_time = (time_module.perf_counter() - context_start) * 1000
logger.info(
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
)
if context_result.error:
if "System prompt dropped" in context_result.error:
@@ -1033,19 +946,9 @@ async def _stream_chat_chunks(
while retry_count <= MAX_RETRIES:
try:
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
retry_info = (
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
)
logger.info(
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"retry_count": retry_count,
}
},
f"Creating OpenAI chat completion stream..."
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
)
# Build extra_body for OpenRouter tracing and PostHog analytics
@@ -1062,7 +965,6 @@ async def _stream_chat_chunks(
:128
] # OpenRouter limit
api_call_start = time_module.perf_counter()
stream = await client.chat.completions.create(
model=model,
messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1072,11 +974,6 @@ async def _stream_chat_chunks(
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
extra_body=extra_body,
)
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
logger.info(
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
)
# Variables to accumulate tool calls
tool_calls: list[dict[str, Any]] = []
@@ -1087,13 +984,10 @@ async def _stream_chat_chunks(
# Track if we've started the text block
text_started = False
first_content_chunk = True
chunk_count = 0
# Process the stream
chunk: ChatCompletionChunk
async for chunk in stream:
chunk_count += 1
if chunk.usage:
yield StreamUsage(
promptTokens=chunk.usage.prompt_tokens,
@@ -1116,23 +1010,6 @@ async def _stream_chat_chunks(
if not text_started and text_block_id:
yield StreamTextStart(id=text_block_id)
text_started = True
# Log timing for first content chunk
if first_content_chunk:
first_content_chunk = False
ttfc = (
time_module.perf_counter() - api_call_start
) * 1000
logger.info(
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
f"(since API call), n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"time_to_first_chunk_ms": ttfc,
"n_chunks": chunk_count,
}
},
)
# Stream the text delta
text_response = StreamTextDelta(
id=text_block_id or "",
@@ -1189,21 +1066,7 @@ async def _stream_chat_chunks(
toolName=tool_calls[idx]["function"]["name"],
)
emitted_start_for_idx.add(idx)
stream_duration = time_module.perf_counter() - api_call_start
logger.info(
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
f"duration={stream_duration:.2f}s, "
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
extra={
"json_fields": {
**log_meta,
"stream_duration_ms": stream_duration * 1000,
"finish_reason": finish_reason,
"n_chunks": chunk_count,
"n_tool_calls": len(tool_calls),
}
},
)
logger.info(f"Stream complete. Finish reason: {finish_reason}")
# Yield all accumulated tool calls after the stream is complete
# This ensures all tool call arguments have been fully received
@@ -1223,12 +1086,6 @@ async def _stream_chat_chunks(
# Re-raise to trigger retry logic in the parent function
raise
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
yield StreamFinish()
return
except Exception as e:
@@ -1708,7 +1565,6 @@ async def _execute_long_running_tool_with_streaming(
task_id,
StreamError(errorText=str(e)),
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await _update_pending_operation(
@@ -1966,7 +1822,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamStartStep())
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response
@@ -1990,7 +1845,6 @@ async def _generate_llm_continuation_with_streaming(
# Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
await stream_registry.publish_chunk(task_id, StreamFinishStep())
if assistant_content:
# Reload session from DB to avoid race condition with user messages
@@ -2032,5 +1886,4 @@ async def _generate_llm_continuation_with_streaming(
task_id,
StreamError(errorText=f"Failed to generate response: {e}"),
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())

View File

@@ -104,24 +104,6 @@ async def create_task(
Returns:
The created ActiveTask instance (metadata only)
"""
import time
start_time = time.perf_counter()
# Build log metadata for structured logging
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"session_id": session_id,
}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
task = ActiveTask(
task_id=task_id,
session_id=session_id,
@@ -132,18 +114,10 @@ async def create_task(
)
# Store metadata in Redis
redis_start = time.perf_counter()
redis = await get_redis_async()
redis_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
)
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
hset_start = time.perf_counter()
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
@@ -157,22 +131,12 @@ async def create_task(
"created_at": task.created_at.isoformat(),
},
)
hset_time = (time.perf_counter() - hset_start) * 1000
logger.info(
f"[TIMING] redis.hset took {hset_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)
logger.debug(f"Created task {task_id} for session {session_id}")
return task
@@ -192,60 +156,26 @@ async def publish_chunk(
Returns:
The Redis Stream message ID
"""
import time
start_time = time.perf_counter()
chunk_type = type(chunk).__name__
chunk_json = chunk.model_dump_json()
message_id = "0-0"
# Build log metadata
log_meta = {
"component": "StreamRegistry",
"task_id": task_id,
"chunk_type": chunk_type,
}
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
xadd_start = time.perf_counter()
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
xadd_time = (time.perf_counter() - xadd_start) * 1000
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
total_time = (time.perf_counter() - start_time) * 1000
# Only log timing for significant chunks or slow operations
if (
chunk_type
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
or total_time > 50
):
logger.info(
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"xadd_time_ms": xadd_time,
"message_id": message_id,
}
},
)
except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error(
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
@@ -270,61 +200,24 @@ async def subscribe_to_task(
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
import time
start_time = time.perf_counter()
# Build log metadata
log_meta = {"component": "StreamRegistry", "task_id": task_id}
if user_id:
log_meta["user_id"] = user_id
logger.info(
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
)
redis_start = time.perf_counter()
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
hgetall_time = (time.perf_counter() - redis_start) * 1000
logger.info(
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
)
if not meta:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"reason": "task_not_found",
}
},
)
logger.debug(f"Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
log_meta["session_id"] = meta.get("session_id", "")
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
extra={
"json_fields": {
**log_meta,
"task_owner": task_user_id,
"reason": "access_denied",
}
},
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
)
return None
@@ -332,19 +225,7 @@ async def subscribe_to_task(
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
extra={
"json_fields": {
**log_meta,
"duration_ms": xread_time,
"task_status": task_status,
}
},
)
replayed_count = 0
replay_last_id = last_message_id
@@ -363,48 +244,19 @@ async def subscribe_to_task(
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.info(
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
extra={
"json_fields": {
**log_meta,
"n_messages_replayed": replayed_count,
"replay_last_id": replay_last_id,
}
},
)
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
logger.info(
"[TIMING] Task still running, starting _stream_listener",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
logger.info(
f"[TIMING] Task already {task_status}, adding StreamFinish",
extra={"json_fields": {**log_meta, "task_status": task_status}},
)
await subscriber_queue.put(StreamFinish())
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
f"n_messages_replayed={replayed_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"n_messages_replayed": replayed_count,
}
},
)
return subscriber_queue
@@ -412,7 +264,6 @@ async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
log_meta: dict | None = None,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
@@ -423,27 +274,10 @@ async def _stream_listener(
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
log_meta: Structured logging metadata
"""
import time
start_time = time.perf_counter()
# Use provided log_meta or build minimal one
if log_meta is None:
log_meta = {"component": "StreamRegistry", "task_id": task_id}
logger.info(
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
)
queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
messages_delivered = 0
first_message_time = None
xread_count = 0
try:
redis = await get_redis_async()
@@ -453,39 +287,9 @@ async def _stream_listener(
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
xread_start = time.perf_counter()
xread_count += 1
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
xread_time = (time.perf_counter() - xread_start) * 1000
if messages:
msg_count = sum(len(msgs) for _, msgs in messages)
logger.info(
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"n_messages": msg_count,
"duration_ms": xread_time,
}
},
)
elif xread_time > 1000:
# Only log timeouts (30s blocking)
logger.info(
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
extra={
"json_fields": {
**log_meta,
"xread_count": xread_count,
"duration_ms": xread_time,
"reason": "timeout",
}
},
)
if not messages:
# Timeout - check if task is still running
@@ -522,30 +326,10 @@ async def _stream_listener(
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
messages_delivered += 1
if first_message_time is None:
first_message_time = time.perf_counter()
elapsed = (first_message_time - start_time) * 1000
logger.info(
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"chunk_type": type(chunk).__name__,
}
},
)
except asyncio.TimeoutError:
logger.warning(
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
extra={
"json_fields": {
**log_meta,
"timeout_s": QUEUE_PUT_TIMEOUT,
"reason": "queue_full",
}
},
f"Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
)
# Send overflow error with recovery info
try:
@@ -567,44 +351,15 @@ async def _stream_listener(
# Stop listening on finish
if isinstance(chunk, StreamFinish):
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
}
},
)
return
except Exception as e:
logger.warning(
f"Error processing stream message: {e}",
extra={"json_fields": {**log_meta, "error": str(e)}},
)
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
elapsed = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed,
"messages_delivered": messages_delivered,
"reason": "cancelled",
}
},
)
logger.debug(f"Stream listener cancelled for task {task_id}")
raise # Re-raise to propagate cancellation
except Exception as e:
elapsed = (time.perf_counter() - start_time) * 1000
logger.error(
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
)
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
@@ -613,24 +368,10 @@ async def _stream_listener(
)
except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning(
"Could not deliver finish event after error",
extra={"json_fields": log_meta},
f"Could not deliver finish event for task {task_id} after error"
)
finally:
# Clean up listener task mapping on exit
total_time = (time.perf_counter() - start_time) * 1000
logger.info(
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
f"delivered={messages_delivered}, xread_count={xread_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"messages_delivered": messages_delivered,
"xread_count": xread_count,
}
},
)
_listener_tasks.pop(queue_id, None)
@@ -857,10 +598,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
ResponseType,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -874,8 +613,6 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish,
ResponseType.START_STEP.value: StreamStartStep,
ResponseType.FINISH_STEP.value: StreamFinishStep,
ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd,

View File

@@ -13,32 +13,10 @@ from backend.api.features.chat.tools.models import (
NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import BlockType, get_block
from backend.data.block import get_block
logger = logging.getLogger(__name__)
_TARGET_RESULTS = 10
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
_OVERFETCH_PAGE_SIZE = 40
# Block types that only work within graphs and cannot run standalone in CoPilot.
COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
COPILOT_EXCLUDED_BLOCK_IDS = {
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
"3b191d9f-356f-482d-8238-ba04b6d18381",
}
class FindBlockTool(BaseTool):
"""Tool for searching available blocks."""
@@ -110,7 +88,7 @@ class FindBlockTool(BaseTool):
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=_OVERFETCH_PAGE_SIZE,
page_size=10,
)
if not results:
@@ -130,90 +108,60 @@ class FindBlockTool(BaseTool):
block = get_block(block_id)
# Skip disabled blocks
if not block or block.disabled:
continue
if block and not block.disabled:
# Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception:
pass
try:
output_schema = block.output_schema.jsonschema()
except Exception:
pass
# Skip blocks excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
continue
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
# Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception as e:
logger.debug(
"Failed to generate input schema for block %s: %s",
block_id,
e,
)
try:
output_schema = block.output_schema.jsonschema()
except Exception as e:
logger.debug(
"Failed to generate output schema for block %s: %s",
block_id,
e,
)
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
)
)
)
if len(blocks) >= _TARGET_RESULTS:
break
if blocks and len(blocks) < _TARGET_RESULTS:
logger.debug(
"find_block returned %d/%d results for query '%s' "
"(filtered %d excluded/disabled blocks)",
len(blocks),
_TARGET_RESULTS,
query,
len(results) - len(blocks),
)
if not blocks:
return NoResultsResponse(

View File

@@ -1,139 +0,0 @@
"""Tests for block filtering in FindBlockTool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
from backend.api.features.chat.tools.models import BlockListResponse
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-find-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.description = f"{name} description"
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields.return_value = {}
mock.output_schema = MagicMock()
mock.output_schema.jsonschema.return_value = {}
mock.categories = []
return mock
class TestFindBlockFiltering:
"""Tests for block filtering in FindBlockTool."""
def test_excluded_block_types_contains_expected_types(self):
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
def test_excluded_block_ids_contains_smart_decision_maker(self):
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_filtered_from_results(self):
"""Verify blocks with excluded BlockTypes are filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
search_results = [
{"content_id": "input-block-id", "score": 0.9},
{"content_id": "standard-block-id", "score": 0.8},
]
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
standard_block = make_mock_block(
"standard-block-id", "HTTP Request", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
"input-block-id": input_block,
"standard-block-id": standard_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="test"
)
# Should only return the standard block, not the INPUT block
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "standard-block-id"
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_filtered_from_results(self):
"""Verify SmartDecisionMakerBlock is filtered from search results."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
search_results = [
{"content_id": smart_decision_id, "score": 0.9},
{"content_id": "normal-block-id", "score": 0.8},
]
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
normal_block = make_mock_block(
"normal-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
smart_decision_id: smart_block,
"normal-block-id": normal_block,
}.get(block_id)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID, session=session, query="decision"
)
# Should only return normal block, not SmartDecisionMakerBlock
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "normal-block-id"

View File

@@ -1,29 +0,0 @@
"""Shared helpers for chat tools."""
from typing import Any
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
"""Extract input field info from JSON schema."""
if not isinstance(input_schema, dict):
return []
exclude = exclude_fields or set()
properties = input_schema.get("properties", {})
required = set(input_schema.get("required", []))
return [
{
"name": name,
"title": schema.get("title", name),
"type": schema.get("type", "string"),
"description": schema.get("description", ""),
"required": name in required,
"default": schema.get("default"),
}
for name, schema in properties.items()
if name not in exclude
]

View File

@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
@@ -262,7 +261,7 @@ class RunAgentTool(BaseTool):
),
requirements={
"credentials": requirements_creds_list,
"inputs": get_inputs_from_schema(graph.input_schema),
"inputs": self._get_inputs_list(graph.input_schema),
"execution_modes": self._get_execution_modes(graph),
},
),
@@ -370,6 +369,22 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema."""
inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph."""
trigger_info = graph.trigger_setup_info
@@ -383,7 +398,7 @@ class RunAgentTool(BaseTool):
suffix: str,
) -> str:
"""Build a message describing available inputs for an agent."""
inputs_list = get_inputs_from_schema(graph.input_schema)
inputs_list = self._get_inputs_list(graph.input_schema)
required_names = [i["name"] for i in inputs_list if i["required"]]
optional_names = [i["name"] for i in inputs_list if not i["required"]]

View File

@@ -8,19 +8,14 @@ from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.data.block import AnyBlockSchema, get_block
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import (
BlockOutputResponse,
ErrorResponse,
@@ -29,10 +24,7 @@ from .models import (
ToolResponseBase,
UserReadiness,
)
from .utils import (
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__)
@@ -81,6 +73,91 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute(
self,
user_id: str | None,
@@ -135,24 +212,11 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
# Check if block is excluded from CoPilot (graph-only blocks)
if (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
return ErrorResponse(
message=(
f"Block '{block.name}' cannot be run directly in CoPilot. "
"This block is designed for use within graphs only."
),
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = (
await self._resolve_block_credentials(user_id, block, input_data)
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data
)
if missing_credentials:
@@ -281,75 +345,29 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
async def _resolve_block_credentials(
self,
user_id: str,
block: AnyBlockSchema,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Resolve credentials for a block by matching user's available credentials.
Args:
user_id: User ID
block: Block to resolve credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple of (matched_credentials, missing_credentials) - matched credentials
are used for block execution, missing ones indicate setup requirements.
"""
input_data = input_data or {}
requirements = self._resolve_discriminated_credentials(block, input_data)
if not requirements:
return {}, []
return await match_credentials_to_requirements(user_id, requirements)
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
def _resolve_discriminated_credentials(
self,
block: AnyBlockSchema,
input_data: dict[str, Any],
) -> dict[str, CredentialsFieldInfo]:
"""Resolve credential requirements, applying discriminator logic where needed."""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return {}
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
resolved: dict[str, CredentialsFieldInfo] = {}
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
# For host-scoped credentials, add the discriminator value
# (e.g., URL) so _credential_is_for_host can match it
effective_field_info.discriminator_values.add(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
resolved[field_name] = effective_field_info
return resolved
return inputs_list

View File

@@ -1,106 +0,0 @@
"""Tests for block execution guards in RunBlockTool."""
from unittest.mock import MagicMock, patch
import pytest
from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.data.block import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-run-block"
def make_mock_block(
block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
"""Create a mock block for testing."""
mock = MagicMock()
mock.id = block_id
mock.name = name
mock.block_type = block_type
mock.disabled = disabled
mock.input_schema = MagicMock()
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
mock.input_schema.get_credentials_fields_info.return_value = []
return mock
class TestRunBlockFiltering:
"""Tests for block execution guards in RunBlockTool."""
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_type_returns_error(self):
"""Attempting to execute a block with excluded BlockType returns error."""
session = make_session(user_id=_TEST_USER_ID)
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=input_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="input-block-id",
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
assert "designed for use within graphs only" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_excluded_block_id_returns_error(self):
"""Attempting to execute SmartDecisionMakerBlock returns error."""
session = make_session(user_id=_TEST_USER_ID)
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
smart_block = make_mock_block(
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=smart_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id=smart_decision_id,
input_data={},
)
assert isinstance(response, ErrorResponse)
assert "cannot be run directly in CoPilot" in response.message
@pytest.mark.asyncio(loop_scope="session")
async def test_non_excluded_block_passes_guard(self):
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
session = make_session(user_id=_TEST_USER_ID)
standard_block = make_mock_block(
"standard-id", "HTTP Request", BlockType.STANDARD
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
block_id="standard-id",
input_data={},
)
# Should NOT be an ErrorResponse about CoPilot exclusion
# (may be other errors like missing credentials, but not the exclusion guard)
if isinstance(response, ErrorResponse):
assert "cannot be run directly in CoPilot" not in response.message

View File

@@ -8,7 +8,6 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
HostScopedCredentials,
@@ -118,7 +117,7 @@ def build_missing_credentials_from_graph(
preserving all supported credential types for each field.
"""
matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
aggregated_fields = graph.aggregate_credentials_inputs()
aggregated_fields = graph.regular_credentials_inputs
return {
field_key: _serialize_missing_credential(field_key, field_info)
@@ -224,99 +223,6 @@ async def get_or_create_library_agent(
return library_agents[0]
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def get_user_credentials(user_id: str) -> list[Credentials]:
"""Get all available credentials for a user."""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list[Credentials],
field_info: CredentialsFieldInfo,
) -> Credentials | None:
"""Find a credential that matches the required provider, type, scopes, and host."""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if cred.type == "oauth2" and not _credential_has_required_scopes(
cred, field_info
):
continue
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred: Credentials,
) -> CredentialsMetaInput:
"""Create a CredentialsMetaInput from a matched credential."""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_user_credentials_to_graph(
user_id: str,
graph: GraphModel,
@@ -338,7 +244,7 @@ async def match_user_credentials_to_graph(
missing_creds: list[str] = []
# Get aggregated credentials requirements from the graph
aggregated_creds = graph.aggregate_credentials_inputs()
aggregated_creds = graph.regular_credentials_inputs
logger.debug(
f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
)
@@ -425,6 +331,8 @@ def _credential_has_required_scopes(
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)

View File

@@ -0,0 +1,78 @@
"""Tests for chat tools utility functions."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.model import CredentialsFieldInfo
def _make_regular_field() -> CredentialsFieldInfo:
return CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["github"],
"credentials_types": ["api_key"],
"is_auto_credential": False,
},
by_alias=True,
)
def test_build_missing_credentials_excludes_auto_creds():
"""
build_missing_credentials_from_graph() should use regular_credentials_inputs
and thus exclude auto_credentials from the "missing" set.
"""
from backend.api.features.chat.tools.utils import (
build_missing_credentials_from_graph,
)
regular_field = _make_regular_field()
mock_graph = MagicMock()
# regular_credentials_inputs should only return the non-auto field
mock_graph.regular_credentials_inputs = {
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
}
result = build_missing_credentials_from_graph(mock_graph, matched_credentials=None)
# Should include the regular credential
assert "github_api_key" in result
# Should NOT include the auto_credential (not in regular_credentials_inputs)
assert "google_oauth2" not in result
@pytest.mark.asyncio
async def test_match_user_credentials_excludes_auto_creds():
"""
match_user_credentials_to_graph() should use regular_credentials_inputs
and thus exclude auto_credentials from matching.
"""
from backend.api.features.chat.tools.utils import match_user_credentials_to_graph
regular_field = _make_regular_field()
mock_graph = MagicMock()
mock_graph.id = "test-graph"
# regular_credentials_inputs returns only non-auto fields
mock_graph.regular_credentials_inputs = {
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
}
# Mock the credentials manager to return no credentials
with patch(
"backend.api.features.chat.tools.utils.IntegrationCredentialsManager"
) as MockCredsMgr:
mock_store = AsyncMock()
mock_store.get_all_creds.return_value = []
MockCredsMgr.return_value.store = mock_store
matched, missing = await match_user_credentials_to_graph(
user_id="test-user", graph=mock_graph
)
# No credentials available, so github should be missing
assert len(matched) == 0
assert len(missing) == 1
assert "github_api_key" in missing[0]

View File

@@ -1103,7 +1103,7 @@ async def create_preset_from_graph_execution(
raise NotFoundError(
f"Graph #{graph_execution.graph_id} not found or accessible"
)
elif len(graph.aggregate_credentials_inputs()) > 0:
elif len(graph.regular_credentials_inputs) > 0:
raise ValueError(
f"Graph execution #{graph_exec_id} can't be turned into a preset "
"because it was run before this feature existed "

View File

@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
try:
webset = await aexa.websets.get(id=input_data.external_id)
webset = aexa.websets.get(id=input_data.external_id)
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
yield "webset", webset_result
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
count=input_data.search_count,
)
webset = await aexa.websets.create(
webset = aexa.websets.create(
params=CreateWebsetParameters(
search=search_params,
external_id=input_data.external_id,
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
if input_data.metadata is not None:
payload["metadata"] = input_data.metadata
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
status_str = (
sdk_webset.status.value
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.websets.list(
response = aexa.websets.list(
cursor=input_data.cursor,
limit=input_data.limit,
)
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
sdk_webset = aexa.websets.get(id=input_data.webset_id)
status_str = (
sdk_webset.status.value
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
status_str = (
deleted_webset.status.value
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
status_str = (
canceled_webset.status.value
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
entity["description"] = input_data.entity_description
payload["entity"] = entity
sdk_preview = await aexa.websets.preview(params=payload)
sdk_preview = aexa.websets.preview(params=payload)
preview = PreviewWebsetModel.from_sdk(sdk_preview)
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
status = (
webset.status.value
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
# Extract basic info
webset_id = webset.id
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
total_items = 0
if input_data.include_sample_items and input_data.sample_size > 0:
items_response = await aexa.websets.items.list(
items_response = aexa.websets.items.list(
webset_id=input_data.webset_id, limit=input_data.sample_size
)
sample_items_data = [
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Get webset details
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
status = (
webset.status.value

View File

@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_enrichment = await aexa.websets.enrichments.create(
sdk_enrichment = aexa.websets.enrichments.create(
webset_id=input_data.webset_id, params=payload
)
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
items_enriched = 0
while time.time() - poll_start < input_data.polling_timeout:
current_enrich = await aexa.websets.enrichments.get(
current_enrich = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=enrichment_id
)
current_status = (
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
if current_status in ["completed", "failed", "cancelled"]:
# Estimate items from webset searches
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
if webset.searches:
for search in webset.searches:
if search.progress:
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_enrichment = await aexa.websets.enrichments.get(
sdk_enrichment = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_enrichment = await aexa.websets.enrichments.delete(
deleted_enrichment = aexa.websets.enrichments.delete(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
canceled_enrichment = await aexa.websets.enrichments.cancel(
canceled_enrichment = aexa.websets.enrichments.cancel(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
# Try to estimate how many items were enriched before cancellation
items_enriched = 0
items_response = await aexa.websets.items.list(
items_response = aexa.websets.items.list(
webset_id=input_data.webset_id, limit=100
)

View File

@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
def _create_test_mock():
"""Create test mocks for the AsyncExa SDK."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock
# Create mock SDK import object
mock_import = MagicMock()
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
return {
"_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(
imports=MagicMock(create=AsyncMock(return_value=mock_import))
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
)
)
}
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
if input_data.metadata:
payload["metadata"] = input_data.metadata
sdk_import = await aexa.websets.imports.create(
sdk_import = aexa.websets.imports.create(
params=payload, csv_data=input_data.csv_data
)
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
import_obj = ImportModel.from_sdk(sdk_import)
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.websets.imports.list(
response = aexa.websets.imports.list(
cursor=input_data.cursor,
limit=input_data.limit,
)
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_import = await aexa.websets.imports.delete(
import_id=input_data.import_id
)
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
yield "import_id", deleted_import.id
yield "success", "true"
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
}
)
# Create async iterator for list_all
async def async_item_iterator(*args, **kwargs):
for item in [mock_item1, mock_item2]:
yield item
# Create mock iterator
mock_items = [mock_item1, mock_item2]
return {
"_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
websets=MagicMock(
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
)
)
}
@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
webset_id=input_data.webset_id, limit=input_data.max_items
)
async for sdk_item in item_iterator:
for sdk_item in item_iterator:
if len(all_items) >= input_data.max_items:
break

View File

@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_item = await aexa.websets.items.get(
sdk_item = aexa.websets.items.get(
webset_id=input_data.webset_id, id=input_data.item_id
)
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
response = None
while time.time() - start_time < input_data.wait_timeout:
response = await aexa.websets.items.list(
response = aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.cursor,
limit=input_data.limit,
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
interval = min(interval * 1.2, 10)
if not response:
response = await aexa.websets.items.list(
response = aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.cursor,
limit=input_data.limit,
)
else:
response = await aexa.websets.items.list(
response = aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.cursor,
limit=input_data.limit,
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_item = await aexa.websets.items.delete(
deleted_item = aexa.websets.items.delete(
webset_id=input_data.webset_id, id=input_data.item_id
)
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
webset_id=input_data.webset_id, limit=input_data.max_items
)
async for sdk_item in item_iterator:
for sdk_item in item_iterator:
if len(all_items) >= input_data.max_items:
break
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
entity_type = "unknown"
if webset.searches:
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
# Get sample items if requested
sample_items: List[WebsetItemModel] = []
if input_data.sample_size > 0:
items_response = await aexa.websets.items.list(
items_response = aexa.websets.items.list(
webset_id=input_data.webset_id, limit=input_data.sample_size
)
# Convert to our stable models
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Get items starting from cursor
response = await aexa.websets.items.list(
response = aexa.websets.items.list(
webset_id=input_data.webset_id,
cursor=input_data.since_cursor,
limit=input_data.max_items,

View File

@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
def _create_test_mock():
"""Create test mocks for the AsyncExa SDK."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock
# Create mock SDK monitor object
mock_monitor = MagicMock()
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
return {
"_get_client": lambda *args, **kwargs: MagicMock(
websets=MagicMock(
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
)
)
}
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
if input_data.metadata:
payload["metadata"] = input_data.metadata
sdk_monitor = await aexa.websets.monitors.create(params=payload)
sdk_monitor = aexa.websets.monitors.create(params=payload)
monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
monitor = MonitorModel.from_sdk(sdk_monitor)
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
if input_data.metadata is not None:
payload["metadata"] = input_data.metadata
sdk_monitor = await aexa.websets.monitors.update(
sdk_monitor = aexa.websets.monitors.update(
monitor_id=input_data.monitor_id, params=payload
)
@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
deleted_monitor = await aexa.websets.monitors.delete(
monitor_id=input_data.monitor_id
)
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
yield "monitor_id", deleted_monitor.id
yield "success", "true"
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
response = await aexa.websets.monitors.list(
response = aexa.websets.monitors.list(
cursor=input_data.cursor,
limit=input_data.limit,
webset_id=input_data.webset_id,

View File

@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
WebsetTargetStatus.IDLE,
WebsetTargetStatus.ANY_COMPLETE,
]:
final_webset = await aexa.websets.wait_until_idle(
final_webset = aexa.websets.wait_until_idle(
id=input_data.webset_id,
timeout=input_data.timeout,
poll_interval=input_data.check_interval,
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
interval = input_data.check_interval
while time.time() - start_time < input_data.timeout:
# Get current webset status
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
current_status = (
webset.status.value
if hasattr(webset.status, "value")
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
# Timeout reached
elapsed = time.time() - start_time
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
final_status = (
webset.status.value
if hasattr(webset.status, "value")
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
try:
while time.time() - start_time < input_data.timeout:
# Get current search status using SDK
search = await aexa.websets.searches.get(
search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
elapsed = time.time() - start_time
# Get last known status
search = await aexa.websets.searches.get(
search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
final_status = (
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
try:
while time.time() - start_time < input_data.timeout:
# Get current enrichment status using SDK
enrichment = await aexa.websets.enrichments.get(
enrichment = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
elapsed = time.time() - start_time
# Get last known status
enrichment = await aexa.websets.enrichments.get(
enrichment = aexa.websets.enrichments.get(
webset_id=input_data.webset_id, id=input_data.enrichment_id
)
final_status = (
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
) -> tuple[list[SampleEnrichmentModel], int]:
"""Get sample enriched data and count."""
# Get a few items to see enrichment results using SDK
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
sample_data: list[SampleEnrichmentModel] = []
enriched_count = 0

View File

@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_search = await aexa.websets.searches.create(
sdk_search = aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
poll_start = time.time()
while time.time() - poll_start < input_data.polling_timeout:
current_search = await aexa.websets.searches.get(
current_search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=search_id
)
current_status = (
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
sdk_search = await aexa.websets.searches.get(
sdk_search = aexa.websets.searches.get(
webset_id=input_data.webset_id, id=input_data.search_id
)
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
# Use AsyncExa SDK
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
canceled_search = await aexa.websets.searches.cancel(
canceled_search = aexa.websets.searches.cancel(
webset_id=input_data.webset_id, id=input_data.search_id
)
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
# Get webset to check existing searches
webset = await aexa.websets.get(id=input_data.webset_id)
webset = aexa.websets.get(id=input_data.webset_id)
# Look for existing search with same query
existing_search = None
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
if input_data.entity_type != SearchEntityType.AUTO:
payload["entity"] = {"type": input_data.entity_type.value}
sdk_search = await aexa.websets.searches.create(
sdk_search = aexa.websets.searches.create(
webset_id=input_data.webset_id, params=payload
)

View File

@@ -531,12 +531,12 @@ class LLMResponse(BaseModel):
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | anthropic.Omit:
) -> Iterable[ToolParam] | anthropic.NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
if not openai_tools or len(openai_tools) == 0:
return anthropic.omit
return anthropic.NOT_GIVEN
anthropic_tools = []
for tool in openai_tools:
@@ -596,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
) -> bool | openai.Omit:
):
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
if llm_model.startswith("o") or parallel_tool_calls is None:
return openai.omit
return openai.NOT_GIVEN
return parallel_tool_calls

View File

@@ -319,6 +319,8 @@ class BlockSchema(BaseModel):
"credentials_provider": [config.get("provider", "google")],
"credentials_types": [config.get("type", "oauth2")],
"credentials_scopes": config.get("scopes"),
"is_auto_credential": True,
"input_field_name": info["field_name"],
}
result[kwarg_name] = CredentialsFieldInfo.model_validate(
auto_schema, by_alias=True

View File

@@ -1,8 +1,9 @@
import logging
import queue
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import (
TYPE_CHECKING,
Annotated,
@@ -1199,16 +1200,12 @@ class NodeExecutionEntry(BaseModel):
class ExecutionQueue(Generic[T]):
"""
Thread-safe queue for managing node execution within a single graph execution.
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
threads within the same process. If migrating back to ProcessPoolExecutor,
replace with multiprocessing.Manager().Queue() for cross-process safety.
Queue for managing the execution of agents.
This will be shared between different processes
"""
def __init__(self):
# Thread-safe queue (not multiprocessing) — see class docstring
self.queue: queue.Queue[T] = queue.Queue()
self.queue = Manager().Queue()
def add(self, execution: T) -> T:
self.queue.put(execution)
@@ -1223,7 +1220,7 @@ class ExecutionQueue(Generic[T]):
def get_or_none(self) -> T | None:
try:
return self.queue.get_nowait()
except queue.Empty:
except Empty:
return None

View File

@@ -1,58 +0,0 @@
"""Tests for ExecutionQueue thread-safety."""
import queue
import threading
from backend.data.execution import ExecutionQueue
def test_execution_queue_uses_stdlib_queue():
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
q = ExecutionQueue()
assert isinstance(q.queue, queue.Queue)
def test_basic_operations():
"""Test add, get, empty, and get_or_none."""
q = ExecutionQueue()
assert q.empty() is True
assert q.get_or_none() is None
result = q.add("item1")
assert result == "item1"
assert q.empty() is False
item = q.get()
assert item == "item1"
assert q.empty() is True
def test_thread_safety():
"""Test concurrent access from multiple threads."""
q = ExecutionQueue()
results = []
num_items = 100
def producer():
for i in range(num_items):
q.add(f"item_{i}")
def consumer():
count = 0
while count < num_items:
item = q.get_or_none()
if item is not None:
results.append(item)
count += 1
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)
producer_thread.start()
consumer_thread.start()
producer_thread.join(timeout=5)
consumer_thread.join(timeout=5)
assert len(results) == num_items

View File

@@ -447,8 +447,7 @@ class GraphModel(Graph, GraphMeta):
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
graph_credentials_inputs = self.regular_credentials_inputs
logger.debug(
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
f"{graph_credentials_inputs}"
@@ -604,6 +603,28 @@ class GraphModel(Graph, GraphMeta):
for key, (field_info, node_field_pairs) in combined.items()
}
@property
def regular_credentials_inputs(
self,
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
"""Credentials that need explicit user mapping (CredentialsMetaInput fields)."""
return {
k: v
for k, v in self.aggregate_credentials_inputs().items()
if not v[0].is_auto_credential
}
@property
def auto_credentials_inputs(
self,
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
"""Credentials embedded in file fields (_credentials_id), resolved at execution time."""
return {
k: v
for k, v in self.aggregate_credentials_inputs().items()
if v[0].is_auto_credential
}
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
"""
Reassigns all IDs in the graph to new UUIDs.
@@ -654,6 +675,16 @@ class GraphModel(Graph, GraphMeta):
) and graph_id in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
# Clear auto-credentials references (e.g., _credentials_id in
# GoogleDriveFile fields) so the new user must re-authenticate
# with their own account
for node in graph.nodes:
if not node.input_default:
continue
for key, value in node.input_default.items():
if isinstance(value, dict) and "_credentials_id" in value:
del value["_credentials_id"]
def validate_graph(
self,
for_run: bool = False,

View File

@@ -463,3 +463,329 @@ def test_node_credentials_optional_with_other_metadata():
assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node"
# ============================================================================
# Tests for CredentialsFieldInfo.combine() field propagation
def test_combine_preserves_is_auto_credential_flag():
"""
CredentialsFieldInfo.combine() must propagate is_auto_credential and
input_field_name to the combined result. Regression test for reviewer
finding that combine() dropped these fields.
"""
from backend.data.model import CredentialsFieldInfo
auto_field = CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["google"],
"credentials_types": ["oauth2"],
"credentials_scopes": ["drive.readonly"],
"is_auto_credential": True,
"input_field_name": "spreadsheet",
},
by_alias=True,
)
# combine() takes *args of (field_info, key) tuples
combined = CredentialsFieldInfo.combine(
(auto_field, ("node-1", "credentials")),
(auto_field, ("node-2", "credentials")),
)
assert len(combined) == 1
group_key = next(iter(combined))
combined_info, combined_keys = combined[group_key]
assert combined_info.is_auto_credential is True
assert combined_info.input_field_name == "spreadsheet"
assert combined_keys == {("node-1", "credentials"), ("node-2", "credentials")}
def test_combine_preserves_regular_credential_defaults():
"""Regular credentials should have is_auto_credential=False after combine()."""
from backend.data.model import CredentialsFieldInfo
regular_field = CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["github"],
"credentials_types": ["api_key"],
"is_auto_credential": False,
},
by_alias=True,
)
combined = CredentialsFieldInfo.combine(
(regular_field, ("node-1", "credentials")),
)
group_key = next(iter(combined))
combined_info, _ = combined[group_key]
assert combined_info.is_auto_credential is False
assert combined_info.input_field_name is None
# ============================================================================
# Tests for _reassign_ids credential clearing (Fix 3: SECRT-1772)
def test_reassign_ids_clears_credentials_id():
"""
[SECRT-1772] _reassign_ids should clear _credentials_id from
GoogleDriveFile-style input_default fields so forked agents
don't retain the original creator's credential references.
"""
from backend.data.graph import GraphModel
node = Node(
id="node-1",
block_id=StoreValueBlock().id,
input_default={
"spreadsheet": {
"_credentials_id": "original-cred-id",
"id": "file-123",
"name": "test.xlsx",
"mimeType": "application/vnd.google-apps.spreadsheet",
"url": "https://docs.google.com/spreadsheets/d/file-123",
},
},
)
graph = Graph(
id="test-graph",
name="Test",
description="Test",
nodes=[node],
links=[],
)
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
# _credentials_id key should be removed (not set to None) so that
# _acquire_auto_credentials correctly errors instead of treating it as chained data
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
def test_reassign_ids_preserves_non_credential_fields():
"""
Regression guard: _reassign_ids should NOT modify non-credential fields
like name, mimeType, id, url.
"""
from backend.data.graph import GraphModel
node = Node(
id="node-1",
block_id=StoreValueBlock().id,
input_default={
"spreadsheet": {
"_credentials_id": "cred-abc",
"id": "file-123",
"name": "test.xlsx",
"mimeType": "application/vnd.google-apps.spreadsheet",
"url": "https://docs.google.com/spreadsheets/d/file-123",
},
},
)
graph = Graph(
id="test-graph",
name="Test",
description="Test",
nodes=[node],
links=[],
)
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
field = graph.nodes[0].input_default["spreadsheet"]
assert field["id"] == "file-123"
assert field["name"] == "test.xlsx"
assert field["mimeType"] == "application/vnd.google-apps.spreadsheet"
assert field["url"] == "https://docs.google.com/spreadsheets/d/file-123"
def test_reassign_ids_handles_no_credentials():
"""
Regression guard: _reassign_ids should not error when input_default
has no dict fields with _credentials_id.
"""
from backend.data.graph import GraphModel
node = Node(
id="node-1",
block_id=StoreValueBlock().id,
input_default={
"input": "some value",
"another_input": 42,
},
)
graph = Graph(
id="test-graph",
name="Test",
description="Test",
nodes=[node],
links=[],
)
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
# Should not error, fields unchanged
assert graph.nodes[0].input_default["input"] == "some value"
assert graph.nodes[0].input_default["another_input"] == 42
def test_reassign_ids_handles_multiple_credential_fields():
"""
[SECRT-1772] When a node has multiple dict fields with _credentials_id,
ALL of them should be cleared.
"""
from backend.data.graph import GraphModel
node = Node(
id="node-1",
block_id=StoreValueBlock().id,
input_default={
"spreadsheet": {
"_credentials_id": "cred-1",
"id": "file-1",
"name": "file1.xlsx",
},
"doc_file": {
"_credentials_id": "cred-2",
"id": "file-2",
"name": "file2.docx",
},
"plain_input": "not a dict",
},
)
graph = Graph(
id="test-graph",
name="Test",
description="Test",
nodes=[node],
links=[],
)
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
assert "_credentials_id" not in graph.nodes[0].input_default["doc_file"]
assert graph.nodes[0].input_default["plain_input"] == "not a dict"
# ============================================================================
# Tests for discriminate() field propagation
def test_discriminate_preserves_is_auto_credential_flag():
"""
CredentialsFieldInfo.discriminate() must propagate is_auto_credential and
input_field_name to the discriminated result. Regression test for
discriminate() dropping these fields (same class of bug as combine()).
"""
from backend.data.model import CredentialsFieldInfo
auto_field = CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["google", "openai"],
"credentials_types": ["oauth2"],
"credentials_scopes": ["drive.readonly"],
"is_auto_credential": True,
"input_field_name": "spreadsheet",
"discriminator": "model",
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
},
by_alias=True,
)
discriminated = auto_field.discriminate("gemini")
assert discriminated.is_auto_credential is True
assert discriminated.input_field_name == "spreadsheet"
assert discriminated.provider == frozenset(["google"])
def test_discriminate_preserves_regular_credential_defaults():
"""Regular credentials should have is_auto_credential=False after discriminate()."""
from backend.data.model import CredentialsFieldInfo
regular_field = CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["google", "openai"],
"credentials_types": ["api_key"],
"is_auto_credential": False,
"discriminator": "model",
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
},
by_alias=True,
)
discriminated = regular_field.discriminate("gpt-4")
assert discriminated.is_auto_credential is False
assert discriminated.input_field_name is None
assert discriminated.provider == frozenset(["openai"])
# ============================================================================
# Tests for credentials_input_schema excluding auto_credentials
def test_credentials_input_schema_excludes_auto_creds():
"""
GraphModel.credentials_input_schema should exclude auto_credentials
(is_auto_credential=True) from the schema. Auto_credentials are
transparently resolved at execution time via file picker data.
"""
from datetime import datetime, timezone
from unittest.mock import PropertyMock, patch
from backend.data.graph import GraphModel, NodeModel
from backend.data.model import CredentialsFieldInfo
regular_field_info = CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["github"],
"credentials_types": ["api_key"],
"is_auto_credential": False,
},
by_alias=True,
)
graph = GraphModel(
id="test-graph",
version=1,
name="Test",
description="Test",
user_id="test-user",
created_at=datetime.now(timezone.utc),
nodes=[
NodeModel(
id="node-1",
block_id=StoreValueBlock().id,
input_default={},
graph_id="test-graph",
graph_version=1,
),
],
links=[],
)
# Mock regular_credentials_inputs to return only the non-auto field (3-tuple)
regular_only = {
"github_credentials": (
regular_field_info,
{("node-1", "credentials")},
True,
),
}
with patch.object(
type(graph),
"regular_credentials_inputs",
new_callable=PropertyMock,
return_value=regular_only,
):
schema = graph.credentials_input_schema
field_names = set(schema.get("properties", {}).keys())
# Should include regular credential but NOT auto_credential
assert "github_credentials" in field_names
assert "google_credentials" not in field_names

View File

@@ -571,6 +571,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
discriminator_values: set[Any] = Field(default_factory=set)
is_auto_credential: bool = False
input_field_name: Optional[str] = None
@classmethod
def combine(
@@ -651,6 +653,9 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
+ "_credentials"
)
# Propagate is_auto_credential from the combined field.
# All fields in a group should share the same is_auto_credential
# value since auto and regular credentials serve different purposes.
result[group_key] = (
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
@@ -659,6 +664,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
discriminator_values=set(all_discriminator_values),
is_auto_credential=combined.is_auto_credential,
input_field_name=combined.input_field_name,
),
combined_keys,
)
@@ -684,6 +691,8 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
discriminator=self.discriminator,
discriminator_mapping=self.discriminator_mapping,
discriminator_values=self.discriminator_values,
is_auto_credential=self.is_auto_credential,
input_field_name=self.input_field_name,
)

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from enum import Enum
@@ -226,10 +225,6 @@ class SyncRabbitMQ(RabbitMQBase):
class AsyncRabbitMQ(RabbitMQBase):
"""Asynchronous RabbitMQ client"""
def __init__(self, config: RabbitMQConfig):
super().__init__(config)
self._reconnect_lock: asyncio.Lock | None = None
@property
def is_connected(self) -> bool:
return bool(self._connection and not self._connection.is_closed)
@@ -240,17 +235,7 @@ class AsyncRabbitMQ(RabbitMQBase):
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
async def connect(self):
if self.is_connected and self._channel and not self._channel.is_closed:
return
if (
self.is_connected
and self._connection
and (self._channel is None or self._channel.is_closed)
):
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)
await self.declare_infrastructure()
if self.is_connected:
return
self._connection = await aio_pika.connect_robust(
@@ -306,46 +291,24 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name
)
@property
def _lock(self) -> asyncio.Lock:
if self._reconnect_lock is None:
self._reconnect_lock = asyncio.Lock()
return self._reconnect_lock
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
"""Get a valid channel, reconnecting if the current one is stale.
Uses a lock to prevent concurrent reconnection attempts from racing.
"""
if self.is_ready:
return self._channel # type: ignore # is_ready guarantees non-None
async with self._lock:
# Double-check after acquiring lock
if self.is_ready:
return self._channel # type: ignore
self._channel = None
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel
async def _publish_once(
@func_retry
async def publish_message(
self,
routing_key: str,
message: str,
exchange: Optional[Exchange] = None,
persistent: bool = True,
) -> None:
channel = await self._ensure_channel()
if not self.is_ready:
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
if exchange:
exchange_obj = await channel.get_exchange(exchange.name)
exchange_obj = await self._channel.get_exchange(exchange.name)
else:
exchange_obj = channel.default_exchange
exchange_obj = self._channel.default_exchange
await exchange_obj.publish(
aio_pika.Message(
@@ -359,23 +322,9 @@ class AsyncRabbitMQ(RabbitMQBase):
routing_key=routing_key,
)
@func_retry
async def publish_message(
self,
routing_key: str,
message: str,
exchange: Optional[Exchange] = None,
persistent: bool = True,
) -> None:
try:
await self._publish_once(routing_key, message, exchange, persistent)
except aio_pika.exceptions.ChannelInvalidStateError:
logger.warning(
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
)
async with self._lock:
self._channel = None
await self._publish_once(routing_key, message, exchange, persistent)
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
return await self._ensure_channel()
if not self.is_ready:
await self.connect()
if self._channel is None:
raise RuntimeError("Channel should be established after connect")
return self._channel

View File

@@ -92,6 +92,7 @@ from .utils import (
block_usage_cost,
create_execution_queue_config,
execution_usage_cost,
parse_auto_credential_field,
validate_exec,
)
@@ -172,6 +173,61 @@ def execute_graph(
T = TypeVar("T")
async def _acquire_auto_credentials(
input_model: type[BlockSchema],
input_data: dict[str, Any],
creds_manager: "IntegrationCredentialsManager",
user_id: str,
) -> tuple[dict[str, Any], list[AsyncRedisLock]]:
"""
Resolve auto_credentials from GoogleDriveFileField-style inputs.
Returns:
(extra_exec_kwargs, locks): kwargs to inject into block execution, and
credential locks to release after execution completes.
"""
extra_exec_kwargs: dict[str, Any] = {}
locks: list[AsyncRedisLock] = []
# NOTE: If a block ever has multiple auto-credential fields, a ValueError
# on a later field will strand locks acquired for earlier fields. They'll
# auto-expire via Redis TTL, but add a try/except to release partial locks
# if that becomes a real scenario.
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
field_name = info["field_name"]
field_data = input_data.get(field_name)
# Use shared helper to parse the field
parsed = parse_auto_credential_field(
field_name=field_name,
info=info,
field_data=field_data,
field_present_in_input=field_name in input_data,
)
if parsed.error:
raise ValueError(parsed.error)
if parsed.cred_id:
# Credential ID provided - acquire credentials
try:
credentials, lock = await creds_manager.acquire(user_id, parsed.cred_id)
locks.append(lock)
extra_exec_kwargs[kwarg_name] = credentials
except ValueError:
raise ValueError(
f"{parsed.provider.capitalize()} credentials for "
f"'{parsed.file_name}' in field '{parsed.field_name}' are not "
f"available in your account. "
f"This can happen if the agent was created by another "
f"user or the credentials were deleted. "
f"Please open the agent in the builder and re-select "
f"the file to authenticate with your own account."
)
return extra_exec_kwargs, locks
async def execute_node(
node: Node,
data: NodeExecutionEntry,
@@ -271,41 +327,14 @@ async def execute_node(
extra_exec_kwargs[field_name] = credentials
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
field_name = info["field_name"]
field_data = input_data.get(field_name)
if field_data and isinstance(field_data, dict):
# Check if _credentials_id key exists in the field data
if "_credentials_id" in field_data:
cred_id = field_data["_credentials_id"]
if cred_id:
# Credential ID provided - acquire credentials
provider = info.get("config", {}).get(
"provider", "external service"
)
file_name = field_data.get("name", "selected file")
try:
credentials, lock = await creds_manager.acquire(
user_id, cred_id
)
creds_locks.append(lock)
extra_exec_kwargs[kwarg_name] = credentials
except ValueError:
# Credential was deleted or doesn't exist
raise ValueError(
f"Authentication expired for '{file_name}' in field '{field_name}'. "
f"The saved {provider.capitalize()} credentials no longer exist. "
f"Please re-select the file to re-authenticate."
)
# else: _credentials_id is explicitly None, skip credentials (for chained data)
else:
# _credentials_id key missing entirely - this is an error
provider = info.get("config", {}).get("provider", "external service")
file_name = field_data.get("name", "selected file")
raise ValueError(
f"Authentication missing for '{file_name}' in field '{field_name}'. "
f"Please re-select the file to authenticate with {provider.capitalize()}."
)
auto_extra_kwargs, auto_locks = await _acquire_auto_credentials(
input_model=input_model,
input_data=input_data,
creds_manager=creds_manager,
user_id=user_id,
)
extra_exec_kwargs.update(auto_extra_kwargs)
creds_locks.extend(auto_locks)
output_size = 0

View File

@@ -0,0 +1,320 @@
"""
Tests for auto_credentials handling in execute_node().
These test the _acquire_auto_credentials() helper function extracted from
execute_node() (manager.py lines 273-308).
"""
import pytest
from pytest_mock import MockerFixture
@pytest.fixture
def google_drive_file_data():
return {
"valid": {
"_credentials_id": "cred-id-123",
"id": "file-123",
"name": "test.xlsx",
"mimeType": "application/vnd.google-apps.spreadsheet",
},
"chained": {
"_credentials_id": None,
"id": "file-456",
"name": "chained.xlsx",
"mimeType": "application/vnd.google-apps.spreadsheet",
},
"missing_key": {
"id": "file-789",
"name": "bad.xlsx",
"mimeType": "application/vnd.google-apps.spreadsheet",
},
}
@pytest.fixture
def mock_input_model(mocker: MockerFixture):
"""Create a mock input model with get_auto_credentials_fields() returning one field."""
input_model = mocker.MagicMock()
input_model.get_auto_credentials_fields.return_value = {
"credentials": {
"field_name": "spreadsheet",
"config": {
"provider": "google",
"type": "oauth2",
"scopes": ["https://www.googleapis.com/auth/drive.readonly"],
},
}
}
return input_model
@pytest.fixture
def mock_creds_manager(mocker: MockerFixture):
manager = mocker.AsyncMock()
mock_lock = mocker.AsyncMock()
mock_creds = mocker.MagicMock()
mock_creds.id = "cred-id-123"
mock_creds.provider = "google"
manager.acquire.return_value = (mock_creds, mock_lock)
return manager, mock_creds, mock_lock
@pytest.mark.asyncio
async def test_auto_credentials_happy_path(
mocker: MockerFixture,
google_drive_file_data,
mock_input_model,
mock_creds_manager,
):
"""When field_data has a valid _credentials_id, credentials should be acquired."""
from backend.executor.manager import _acquire_auto_credentials
manager, mock_creds, mock_lock = mock_creds_manager
input_data = {"spreadsheet": google_drive_file_data["valid"]}
extra_kwargs, locks = await _acquire_auto_credentials(
input_model=mock_input_model,
input_data=input_data,
creds_manager=manager,
user_id="user-1",
)
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
assert extra_kwargs["credentials"] == mock_creds
assert mock_lock in locks
@pytest.mark.asyncio
async def test_auto_credentials_field_none_static_raises(
mocker: MockerFixture,
mock_input_model,
mock_creds_manager,
):
"""
[THE BUG FIX TEST — OPEN-2895]
When field_data is None and the key IS in input_data (user didn't select a file),
should raise ValueError instead of silently skipping.
"""
from backend.executor.manager import _acquire_auto_credentials
manager, _, _ = mock_creds_manager
# Key is present but value is None = user didn't select a file
input_data = {"spreadsheet": None}
with pytest.raises(ValueError, match="No file selected"):
await _acquire_auto_credentials(
input_model=mock_input_model,
input_data=input_data,
creds_manager=manager,
user_id="user-1",
)
@pytest.mark.asyncio
async def test_auto_credentials_field_absent_skips(
mocker: MockerFixture,
mock_input_model,
mock_creds_manager,
):
"""
When the field key is NOT in input_data at all (upstream connection),
should skip without error.
"""
from backend.executor.manager import _acquire_auto_credentials
manager, _, _ = mock_creds_manager
# Key not present = connected from upstream block
input_data = {}
extra_kwargs, locks = await _acquire_auto_credentials(
input_model=mock_input_model,
input_data=input_data,
creds_manager=manager,
user_id="user-1",
)
manager.acquire.assert_not_called()
assert "credentials" not in extra_kwargs
assert locks == []
@pytest.mark.asyncio
async def test_auto_credentials_chained_cred_id_none(
mocker: MockerFixture,
google_drive_file_data,
mock_input_model,
mock_creds_manager,
):
"""
When _credentials_id is explicitly None (chained data from upstream),
should skip credential acquisition.
"""
from backend.executor.manager import _acquire_auto_credentials
manager, _, _ = mock_creds_manager
input_data = {"spreadsheet": google_drive_file_data["chained"]}
extra_kwargs, locks = await _acquire_auto_credentials(
input_model=mock_input_model,
input_data=input_data,
creds_manager=manager,
user_id="user-1",
)
manager.acquire.assert_not_called()
assert "credentials" not in extra_kwargs
@pytest.mark.asyncio
async def test_auto_credentials_missing_cred_id_key_raises(
mocker: MockerFixture,
google_drive_file_data,
mock_input_model,
mock_creds_manager,
):
"""
When _credentials_id key is missing entirely from field_data dict,
should raise ValueError.
"""
from backend.executor.manager import _acquire_auto_credentials
manager, _, _ = mock_creds_manager
input_data = {"spreadsheet": google_drive_file_data["missing_key"]}
with pytest.raises(ValueError, match="Authentication missing"):
await _acquire_auto_credentials(
input_model=mock_input_model,
input_data=input_data,
creds_manager=manager,
user_id="user-1",
)
@pytest.mark.asyncio
async def test_auto_credentials_ownership_mismatch_error(
mocker: MockerFixture,
google_drive_file_data,
mock_input_model,
mock_creds_manager,
):
"""
[SECRT-1772] When acquire() raises ValueError (credential belongs to another user),
the error message should mention 'not available' (not 'expired').
"""
from backend.executor.manager import _acquire_auto_credentials
manager, _, _ = mock_creds_manager
manager.acquire.side_effect = ValueError(
"Credentials #cred-id-123 for user #user-2 not found"
)
input_data = {"spreadsheet": google_drive_file_data["valid"]}
with pytest.raises(ValueError, match="not available in your account"):
await _acquire_auto_credentials(
input_model=mock_input_model,
input_data=input_data,
creds_manager=manager,
user_id="user-2",
)
@pytest.mark.asyncio
async def test_auto_credentials_deleted_credential_error(
mocker: MockerFixture,
google_drive_file_data,
mock_input_model,
mock_creds_manager,
):
"""
[SECRT-1772] When acquire() raises ValueError (credential was deleted),
the error message should mention 'not available' (not 'expired').
"""
from backend.executor.manager import _acquire_auto_credentials
manager, _, _ = mock_creds_manager
manager.acquire.side_effect = ValueError(
"Credentials #cred-id-123 for user #user-1 not found"
)
input_data = {"spreadsheet": google_drive_file_data["valid"]}
with pytest.raises(ValueError, match="not available in your account"):
await _acquire_auto_credentials(
input_model=mock_input_model,
input_data=input_data,
creds_manager=manager,
user_id="user-1",
)
@pytest.mark.asyncio
async def test_auto_credentials_lock_appended(
mocker: MockerFixture,
google_drive_file_data,
mock_input_model,
mock_creds_manager,
):
"""Lock from acquire() should be included in returned locks list."""
from backend.executor.manager import _acquire_auto_credentials
manager, _, mock_lock = mock_creds_manager
input_data = {"spreadsheet": google_drive_file_data["valid"]}
extra_kwargs, locks = await _acquire_auto_credentials(
input_model=mock_input_model,
input_data=input_data,
creds_manager=manager,
user_id="user-1",
)
assert len(locks) == 1
assert locks[0] is mock_lock
@pytest.mark.asyncio
async def test_auto_credentials_multiple_fields(
mocker: MockerFixture,
mock_creds_manager,
):
"""When there are multiple auto_credentials fields, only valid ones should acquire."""
from backend.executor.manager import _acquire_auto_credentials
manager, mock_creds, mock_lock = mock_creds_manager
input_model = mocker.MagicMock()
input_model.get_auto_credentials_fields.return_value = {
"credentials": {
"field_name": "spreadsheet",
"config": {"provider": "google", "type": "oauth2"},
},
"credentials2": {
"field_name": "doc_file",
"config": {"provider": "google", "type": "oauth2"},
},
}
input_data = {
"spreadsheet": {
"_credentials_id": "cred-id-123",
"id": "file-1",
"name": "file1.xlsx",
},
"doc_file": {
"_credentials_id": None,
"id": "file-2",
"name": "chained.doc",
},
}
extra_kwargs, locks = await _acquire_auto_credentials(
input_model=input_model,
input_data=input_data,
creds_manager=manager,
user_id="user-1",
)
# Only the first field should have acquired credentials
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
assert "credentials" in extra_kwargs
assert "credentials2" not in extra_kwargs
assert len(locks) == 1

View File

@@ -4,7 +4,7 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import Mapping, Optional, cast
from typing import Any, Mapping, Optional, cast
from pydantic import BaseModel, JsonValue, ValidationError
@@ -55,6 +55,87 @@ from backend.util.type import convert
config = Config()
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
# ============ Auto-Credentials Helpers ============ #
class AutoCredentialFieldInfo(BaseModel):
"""Parsed info from an auto-credential field (e.g., GoogleDriveFileField)."""
cred_id: str | None
"""The credential ID to use, or None if not provided."""
provider: str
"""The provider name (e.g., 'google')."""
file_name: str
"""The display name for error messages."""
field_name: str
"""The original field name in the schema."""
error: str | None = None
"""Validation error message, if any."""
def parse_auto_credential_field(
field_name: str,
info: dict,
field_data: Any,
*,
field_present_in_input: bool = True,
) -> AutoCredentialFieldInfo:
"""
Parse auto-credential field data and extract credential info.
This is shared logic used by both credential acquisition (manager.py)
and credential validation (utils.py).
Args:
field_name: The name of the field in the schema
info: The auto_credentials field info from get_auto_credentials_fields()
field_data: The actual field data from input
field_present_in_input: Whether the field key exists in input_data
Returns:
AutoCredentialFieldInfo with parsed data and any validation errors
"""
provider = info.get("config", {}).get("provider", "external service")
file_name = (
field_data.get("name", "selected file")
if isinstance(field_data, dict)
else "selected file"
)
result = AutoCredentialFieldInfo(
cred_id=None,
provider=provider,
file_name=file_name,
field_name=field_name,
)
if field_data and isinstance(field_data, dict):
if "_credentials_id" not in field_data:
# Key removed (e.g., on fork) — needs re-auth
result.error = (
f"Authentication missing for '{file_name}' in field "
f"'{field_name}'. Please re-select the file to authenticate "
f"with {provider.capitalize()}."
)
else:
cred_id = field_data.get("_credentials_id")
if cred_id:
result.cred_id = cred_id
# else: _credentials_id is explicitly None, skip (chained data)
elif field_data is None and not field_present_in_input:
# Field not in input_data at all = connected from upstream block, skip
pass
elif field_present_in_input:
# field_data is None/empty but key IS in input_data = user didn't select
result.error = (
f"No file selected for '{field_name}'. "
f"Please select a file to provide "
f"{provider.capitalize()} authentication."
)
return result
# ============ Resource Helpers ============ #
@@ -259,7 +340,8 @@ async def _validate_node_input_credentials(
# Find any fields of type CredentialsMetaInput
credentials_fields = block.input_schema.get_credentials_fields()
if not credentials_fields:
auto_credentials_fields = block.input_schema.get_auto_credentials_fields()
if not credentials_fields and not auto_credentials_fields:
continue
# Track if any credential field is missing for this node
@@ -339,6 +421,52 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue
# Validate auto-credentials (GoogleDriveFileField-based)
# These have _credentials_id embedded in the file field data
if auto_credentials_fields:
for _kwarg_name, info in auto_credentials_fields.items():
field_name = info["field_name"]
# Check input_default and nodes_input_masks for the field value
field_value = node.input_default.get(field_name)
if nodes_input_masks and node.id in nodes_input_masks:
field_value = nodes_input_masks[node.id].get(
field_name, field_value
)
# Use shared helper to parse the field
parsed = parse_auto_credential_field(
field_name=field_name,
info=info,
field_data=field_value,
field_present_in_input=True, # For validation, assume present
)
if parsed.error:
has_missing_credentials = True
credential_errors[node.id][field_name] = parsed.error
continue
if parsed.cred_id:
# Validate that credentials exist and are accessible
try:
creds_store = get_integration_credentials_store()
creds = await creds_store.get_creds_by_id(
user_id, parsed.cred_id
)
except Exception as e:
has_missing_credentials = True
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
continue
if not creds:
has_missing_credentials = True
credential_errors[node.id][field_name] = (
"The saved credentials are not available "
"for your account. Please re-select the file to "
"authenticate with your own account."
)
# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
@@ -370,8 +498,9 @@ def make_node_credentials_input_map(
"""
result: dict[str, dict[str, JsonValue]] = {}
# Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs()
# Only map regular credentials (not auto_credentials, which are resolved
# at execution time from _credentials_id in file field data)
graph_cred_inputs = graph.regular_credentials_inputs
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
# Best-effort map: skip missing items

View File

@@ -907,3 +907,335 @@ async def test_stop_graph_execution_cascades_to_child_with_reviews(
# Verify both parent and child status updates
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
# ============================================================================
# Tests for auto_credentials validation in _validate_node_input_credentials
# (Fix 3: SECRT-1772 + Fix 4: Path 4)
# ============================================================================
@pytest.mark.asyncio
async def test_validate_node_input_credentials_auto_creds_valid(
mocker: MockerFixture,
):
"""
[SECRT-1772] When a node has auto_credentials with a valid _credentials_id
that exists in the store, validation should pass without errors.
"""
from backend.executor.utils import _validate_node_input_credentials
mock_node = mocker.MagicMock()
mock_node.id = "node-with-auto-creds"
mock_node.credentials_optional = False
mock_node.input_default = {
"spreadsheet": {
"_credentials_id": "valid-cred-id",
"id": "file-123",
"name": "test.xlsx",
}
}
mock_block = mocker.MagicMock()
# No regular credentials fields
mock_block.input_schema.get_credentials_fields.return_value = {}
# Has auto_credentials fields
mock_block.input_schema.get_auto_credentials_fields.return_value = {
"credentials": {
"field_name": "spreadsheet",
"config": {"provider": "google", "type": "oauth2"},
}
}
mock_node.block = mock_block
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Mock the credentials store to return valid credentials
mock_store = mocker.MagicMock()
mock_creds = mocker.MagicMock()
mock_creds.id = "valid-cred-id"
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=mock_creds)
mocker.patch(
"backend.executor.utils.get_integration_credentials_store",
return_value=mock_store,
)
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user",
nodes_input_masks=None,
)
assert mock_node.id not in errors
assert mock_node.id not in nodes_to_skip
@pytest.mark.asyncio
async def test_validate_node_input_credentials_auto_creds_missing(
mocker: MockerFixture,
):
"""
[SECRT-1772] When a node has auto_credentials with a _credentials_id
that doesn't exist for the current user, validation should report an error.
"""
from backend.executor.utils import _validate_node_input_credentials
mock_node = mocker.MagicMock()
mock_node.id = "node-with-bad-auto-creds"
mock_node.credentials_optional = False
mock_node.input_default = {
"spreadsheet": {
"_credentials_id": "other-users-cred-id",
"id": "file-123",
"name": "test.xlsx",
}
}
mock_block = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {}
mock_block.input_schema.get_auto_credentials_fields.return_value = {
"credentials": {
"field_name": "spreadsheet",
"config": {"provider": "google", "type": "oauth2"},
}
}
mock_node.block = mock_block
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Mock the credentials store to return None (cred not found for this user)
mock_store = mocker.MagicMock()
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=None)
mocker.patch(
"backend.executor.utils.get_integration_credentials_store",
return_value=mock_store,
)
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="different-user",
nodes_input_masks=None,
)
assert mock_node.id in errors
assert "spreadsheet" in errors[mock_node.id]
assert "not available" in errors[mock_node.id]["spreadsheet"].lower()
@pytest.mark.asyncio
async def test_validate_node_input_credentials_both_regular_and_auto(
mocker: MockerFixture,
):
"""
[SECRT-1772] A node that has BOTH regular credentials AND auto_credentials
should have both validated.
"""
from backend.executor.utils import _validate_node_input_credentials
mock_node = mocker.MagicMock()
mock_node.id = "node-with-both-creds"
mock_node.credentials_optional = False
mock_node.input_default = {
"credentials": {
"id": "regular-cred-id",
"provider": "github",
"type": "api_key",
},
"spreadsheet": {
"_credentials_id": "auto-cred-id",
"id": "file-123",
"name": "test.xlsx",
},
}
mock_credentials_field_type = mocker.MagicMock()
mock_credentials_meta = mocker.MagicMock()
mock_credentials_meta.id = "regular-cred-id"
mock_credentials_meta.provider = "github"
mock_credentials_meta.type = "api_key"
mock_credentials_field_type.model_validate.return_value = mock_credentials_meta
mock_block = mocker.MagicMock()
# Regular credentials field
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type,
}
# Auto-credentials field
mock_block.input_schema.get_auto_credentials_fields.return_value = {
"auto_credentials": {
"field_name": "spreadsheet",
"config": {"provider": "google", "type": "oauth2"},
}
}
mock_node.block = mock_block
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Mock the credentials store to return valid credentials for both
mock_store = mocker.MagicMock()
mock_regular_creds = mocker.MagicMock()
mock_regular_creds.id = "regular-cred-id"
mock_regular_creds.provider = "github"
mock_regular_creds.type = "api_key"
mock_auto_creds = mocker.MagicMock()
mock_auto_creds.id = "auto-cred-id"
def get_creds_side_effect(user_id, cred_id):
if cred_id == "regular-cred-id":
return mock_regular_creds
elif cred_id == "auto-cred-id":
return mock_auto_creds
return None
mock_store.get_creds_by_id = mocker.AsyncMock(side_effect=get_creds_side_effect)
mocker.patch(
"backend.executor.utils.get_integration_credentials_store",
return_value=mock_store,
)
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user",
nodes_input_masks=None,
)
# Both should validate successfully - no errors
assert mock_node.id not in errors
assert mock_node.id not in nodes_to_skip
@pytest.mark.asyncio
async def test_validate_node_input_credentials_auto_creds_skipped_when_none(
mocker: MockerFixture,
):
"""
When a node has auto_credentials but the field value has _credentials_id=None
(e.g., from upstream connection), validation should skip it without error.
"""
from backend.executor.utils import _validate_node_input_credentials
mock_node = mocker.MagicMock()
mock_node.id = "node-with-chained-auto-creds"
mock_node.credentials_optional = False
mock_node.input_default = {
"spreadsheet": {
"_credentials_id": None,
"id": "file-123",
"name": "test.xlsx",
}
}
mock_block = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {}
mock_block.input_schema.get_auto_credentials_fields.return_value = {
"credentials": {
"field_name": "spreadsheet",
"config": {"provider": "google", "type": "oauth2"},
}
}
mock_node.block = mock_block
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user",
nodes_input_masks=None,
)
# No error - chained data with None cred_id is valid
assert mock_node.id not in errors
# ============================================================================
# Tests for CredentialsFieldInfo auto_credential tag (Fix 4: Path 4)
# ============================================================================
def test_credentials_field_info_auto_credential_tag():
"""
[Path 4] CredentialsFieldInfo should support is_auto_credential and
input_field_name fields for distinguishing auto from regular credentials.
"""
from backend.data.model import CredentialsFieldInfo
# Regular credential should have is_auto_credential=False by default
regular = CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["github"],
"credentials_types": ["api_key"],
},
by_alias=True,
)
assert regular.is_auto_credential is False
assert regular.input_field_name is None
# Auto credential should have is_auto_credential=True
auto = CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["google"],
"credentials_types": ["oauth2"],
"is_auto_credential": True,
"input_field_name": "spreadsheet",
},
by_alias=True,
)
assert auto.is_auto_credential is True
assert auto.input_field_name == "spreadsheet"
def test_make_node_credentials_input_map_excludes_auto_creds(
mocker: MockerFixture,
):
"""
[Path 4] make_node_credentials_input_map should only include regular credentials,
not auto_credentials (which are resolved at execution time).
"""
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.executor.utils import make_node_credentials_input_map
from backend.integrations.providers import ProviderName
# Create a mock graph with aggregate_credentials_inputs that returns
# both regular and auto credentials
mock_graph = mocker.MagicMock()
regular_field_info = CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["github"],
"credentials_types": ["api_key"],
"is_auto_credential": False,
},
by_alias=True,
)
# Mock regular_credentials_inputs property (auto_credentials are excluded)
mock_graph.regular_credentials_inputs = {
"github_creds": (regular_field_info, {("node-1", "credentials")}, True),
}
graph_credentials_input = {
"github_creds": CredentialsMetaInput(
id="cred-123",
provider=ProviderName("github"),
type="api_key",
),
}
result = make_node_credentials_input_map(mock_graph, graph_credentials_input)
# Regular credentials should be mapped
assert "node-1" in result
assert "credentials" in result["node-1"]
# Auto credentials should NOT appear in the result
# (they would have been mapped to the kwarg_name "credentials" not "spreadsheet")
for node_id, fields in result.items():
for field_name, value in fields.items():
# Verify no auto-credential phantom entries
if isinstance(value, dict):
assert "_credentials_id" not in value

View File

@@ -342,14 +342,6 @@ async def store_media_file(
if not target_path.is_file():
raise ValueError(f"Local file does not exist: {target_path}")
# Virus scan the local file before any further processing
local_content = target_path.read_bytes()
if len(local_content) > MAX_FILE_SIZE_BYTES:
raise ValueError(
f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
)
await scan_content_safe(local_content, filename=sanitized_file)
# Return based on requested format
if return_format == "for_local_processing":
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL

View File

@@ -247,100 +247,3 @@ class TestFileCloudIntegration:
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
@pytest.mark.asyncio
async def test_store_media_file_local_path_scanned(self):
"""Test that local file paths are scanned for viruses."""
graph_exec_id = "test-exec-123"
local_file = "test_video.mp4"
file_content = b"fake video content"
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.Path"
) as mock_path_class:
# Mock cloud storage handler - not a cloud path
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
mock_handler_getter.return_value = mock_handler
# Mock virus scanner
mock_scan.return_value = None
# Mock file system operations
mock_base_path = MagicMock()
mock_target_path = MagicMock()
mock_resolved_path = MagicMock()
mock_path_class.return_value = mock_base_path
mock_base_path.mkdir = MagicMock()
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
mock_target_path.resolve.return_value = mock_resolved_path
mock_resolved_path.is_relative_to.return_value = True
mock_resolved_path.is_file.return_value = True
mock_resolved_path.read_bytes.return_value = file_content
mock_resolved_path.relative_to.return_value = Path(local_file)
mock_resolved_path.name = local_file
result = await store_media_file(
file=MediaFileType(local_file),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# Verify virus scan was called for local file
mock_scan.assert_called_once_with(file_content, filename=local_file)
# Result should be the relative path
assert str(result) == local_file
@pytest.mark.asyncio
async def test_store_media_file_local_path_virus_detected(self):
"""Test that infected local files raise VirusDetectedError."""
from backend.api.features.store.exceptions import VirusDetectedError
graph_exec_id = "test-exec-123"
local_file = "infected.exe"
file_content = b"malicious content"
with patch(
"backend.util.file.get_cloud_storage_handler"
) as mock_handler_getter, patch(
"backend.util.file.scan_content_safe"
) as mock_scan, patch(
"backend.util.file.Path"
) as mock_path_class:
# Mock cloud storage handler - not a cloud path
mock_handler = MagicMock()
mock_handler.is_cloud_path.return_value = False
mock_handler_getter.return_value = mock_handler
# Mock virus scanner to detect virus
mock_scan.side_effect = VirusDetectedError(
"EICAR-Test-File", "File rejected due to virus detection"
)
# Mock file system operations
mock_base_path = MagicMock()
mock_target_path = MagicMock()
mock_resolved_path = MagicMock()
mock_path_class.return_value = mock_base_path
mock_base_path.mkdir = MagicMock()
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
mock_target_path.resolve.return_value = mock_resolved_path
mock_resolved_path.is_relative_to.return_value = True
mock_resolved_path.is_file.return_value = True
mock_resolved_path.read_bytes.return_value = file_content
with pytest.raises(VirusDetectedError):
await store_media_file(
file=MediaFileType(local_file),
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)

File diff suppressed because it is too large Load Diff

View File

@@ -12,16 +12,16 @@ python = ">=3.10,<3.14"
aio-pika = "^9.5.5"
aiohttp = "^3.10.0"
aiodns = "^3.5.0"
anthropic = "^0.79.0"
anthropic = "^0.59.0"
apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" }
click = "^8.2.0"
cryptography = "^46.0"
cryptography = "^45.0"
discord-py = "^2.5.2"
e2b-code-interpreter = "^1.5.2"
elevenlabs = "^1.50.0"
fastapi = "^0.128.5"
fastapi = "^0.116.1"
feedparser = "^6.0.11"
flake8 = "^7.3.0"
google-api-python-client = "^2.177.0"
@@ -35,10 +35,10 @@ jinja2 = "^3.1.6"
jsonref = "^1.1.0"
jsonschema = "^4.25.0"
langfuse = "^3.11.0"
launchdarkly-server-sdk = "^9.14.1"
launchdarkly-server-sdk = "^9.12.0"
mem0ai = "^0.1.115"
moviepy = "^2.1.2"
ollama = "^0.6.1"
ollama = "^0.5.1"
openai = "^1.97.1"
orjson = "^3.10.0"
pika = "^1.3.2"
@@ -48,16 +48,16 @@ postmarker = "^1.0"
praw = "~7.8.1"
prisma = "^0.15.0"
rank-bm25 = "^0.2.2"
prometheus-client = "^0.24.1"
prometheus-client = "^0.22.1"
prometheus-fastapi-instrumentator = "^7.0.0"
psutil = "^7.0.0"
psycopg2-binary = "^2.9.10"
pydantic = { extras = ["email"], version = "^2.12.5" }
pydantic-settings = "^2.12.0"
pydantic = { extras = ["email"], version = "^2.11.7" }
pydantic-settings = "^2.10.1"
pytest = "^8.4.1"
pytest-asyncio = "^1.1.0"
python-dotenv = "^1.1.1"
python-multipart = "^0.0.22"
python-multipart = "^0.0.20"
redis = "^6.2.0"
regex = "^2025.9.18"
replicate = "^1.0.6"
@@ -65,11 +65,11 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
sqlalchemy = "^2.0.40"
strenum = "^0.4.9"
stripe = "^11.5.0"
supabase = "2.27.3"
tenacity = "^9.1.4"
supabase = "2.17.0"
tenacity = "^9.1.2"
todoist-api-python = "^2.1.7"
tweepy = "^4.16.0"
uvicorn = { extras = ["standard"], version = "^0.40.0" }
uvicorn = { extras = ["standard"], version = "^0.35.0" }
websockets = "^15.0"
youtube-transcript-api = "^1.2.1"
yt-dlp = "2025.12.08"
@@ -77,7 +77,7 @@ zerobouncesdk = "^1.1.2"
# NOTE: please insert new dependencies in their alphabetical location
pytest-snapshot = "^0.9.0"
aiofiles = "^24.1.0"
tiktoken = "^0.12.0"
tiktoken = "^0.9.0"
aioclamd = "^1.0.0"
setuptools = "^80.9.0"
gcloud-aio-storage = "^9.5.0"
@@ -95,13 +95,13 @@ black = "^24.10.0"
faker = "^38.2.0"
httpx = "^0.28.1"
isort = "^5.13.2"
poethepoet = "^0.41.0"
poethepoet = "^0.37.0"
pre-commit = "^4.4.0"
pyright = "^1.1.407"
pytest-mock = "^3.15.1"
pytest-watcher = "^0.6.3"
pytest-watcher = "^0.4.2"
requests = "^2.32.5"
ruff = "^0.15.0"
ruff = "^0.14.5"
# NOTE: please insert new dependencies in their alphabetical location
[build-system]

View File

@@ -25,10 +25,6 @@ RUN if [ -f .env.production ]; then \
cp .env.default .env; \
fi
RUN pnpm run generate:api
# Disable source-map generation in Docker builds to halve webpack memory usage.
# Source maps are only useful when SENTRY_AUTH_TOKEN is set (Vercel deploys);
# the Docker image never uploads them, so generating them just wastes RAM.
ENV NEXT_PUBLIC_SOURCEMAPS="false"
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=4096" pnpm build; else NODE_OPTIONS="--max-old-space-size=4096" pnpm build; fi

View File

@@ -1,12 +1,8 @@
import { withSentryConfig } from "@sentry/nextjs";
// Allow Docker builds to skip source-map generation (halves memory usage).
// Defaults to true so Vercel/local builds are unaffected.
const enableSourceMaps = process.env.NEXT_PUBLIC_SOURCEMAPS !== "false";
/** @type {import('next').NextConfig} */
const nextConfig = {
productionBrowserSourceMaps: enableSourceMaps,
productionBrowserSourceMaps: true,
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
serverExternalPackages: [
"@opentelemetry/instrumentation",
@@ -100,7 +96,7 @@ export default isDevelopmentBuild
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
sourcemaps: {
disable: !enableSourceMaps,
disable: false,
assets: [".next/**/*.js", ".next/**/*.js.map"],
ignore: ["**/node_modules/**"],
deleteSourcemapsAfterUpload: false, // Source is public anyway :)

View File

@@ -30,7 +30,6 @@
"defaults"
],
"dependencies": {
"@ai-sdk/react": "3.0.61",
"@faker-js/faker": "10.0.0",
"@hookform/resolvers": "5.2.2",
"@next/third-parties": "15.4.6",
@@ -61,10 +60,6 @@
"@rjsf/utils": "6.1.2",
"@rjsf/validator-ajv8": "6.1.2",
"@sentry/nextjs": "10.27.0",
"@streamdown/cjk": "1.0.1",
"@streamdown/code": "1.0.1",
"@streamdown/math": "1.0.1",
"@streamdown/mermaid": "1.0.1",
"@supabase/ssr": "0.7.0",
"@supabase/supabase-js": "2.78.0",
"@tanstack/react-query": "5.90.6",
@@ -73,7 +68,6 @@
"@vercel/analytics": "1.5.0",
"@vercel/speed-insights": "1.2.0",
"@xyflow/react": "12.9.2",
"ai": "6.0.59",
"boring-avatars": "1.11.2",
"class-variance-authority": "0.7.1",
"clsx": "2.1.1",
@@ -108,7 +102,7 @@
"react-markdown": "9.0.3",
"react-modal": "3.16.3",
"react-shepherd": "6.1.9",
"react-window": "2.2.0",
"react-window": "1.8.11",
"recharts": "3.3.0",
"rehype-autolink-headings": "7.1.0",
"rehype-highlight": "7.0.2",
@@ -118,11 +112,9 @@
"remark-math": "6.0.0",
"shepherd.js": "14.5.1",
"sonner": "2.0.7",
"streamdown": "2.1.0",
"tailwind-merge": "2.6.0",
"tailwind-scrollbar": "3.1.0",
"tailwindcss-animate": "1.0.7",
"use-stick-to-bottom": "1.1.2",
"uuid": "11.1.0",
"vaul": "1.1.2",
"zod": "3.25.76",
@@ -148,7 +140,7 @@
"@types/react": "18.3.17",
"@types/react-dom": "18.3.5",
"@types/react-modal": "3.16.3",
"@types/react-window": "2.0.0",
"@types/react-window": "1.8.8",
"@vitejs/plugin-react": "5.1.2",
"axe-playwright": "2.2.2",
"chromatic": "13.3.3",

File diff suppressed because it is too large Load Diff

View File

@@ -70,10 +70,10 @@ export const HorizontalScroll: React.FC<HorizontalScrollAreaProps> = ({
{children}
</div>
{canScrollLeft && (
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-background via-background/80 to-background/0" />
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-white via-white/80 to-white/0" />
)}
{canScrollRight && (
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-background via-background/80 to-background/0" />
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-white via-white/80 to-white/0" />
)}
{canScrollLeft && (
<button

View File

@@ -1,74 +0,0 @@
"use client";
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { UIDataTypes, UIMessage, UITools } from "ai";
import { LayoutGroup, motion } from "framer-motion";
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
import { EmptySession } from "../EmptySession/EmptySession";
export interface ChatContainerProps {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
status: string;
error: Error | undefined;
sessionId: string | null;
isLoadingSession: boolean;
isCreatingSession: boolean;
onCreateSession: () => void | Promise<string>;
onSend: (message: string) => void | Promise<void>;
onStop: () => void;
}
export const ChatContainer = ({
messages,
status,
error,
sessionId,
isLoadingSession,
isCreatingSession,
onCreateSession,
onSend,
onStop,
}: ChatContainerProps) => {
const inputLayoutId = "copilot-2-chat-input";
return (
<CopilotChatActionsProvider onSend={onSend}>
<LayoutGroup id="copilot-2-chat-layout">
<div className="flex h-full min-h-0 w-full flex-col bg-[#f8f8f9] px-2 lg:px-0">
{sessionId ? (
<div className="mx-auto flex h-full min-h-0 w-full max-w-3xl flex-col">
<ChatMessagesContainer
messages={messages}
status={status}
error={error}
isLoading={isLoadingSession}
/>
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.3 }}
className="relative px-3 pb-2 pt-2"
>
<div className="pointer-events-none absolute left-0 right-0 top-[-18px] z-10 h-6 bg-gradient-to-b from-transparent to-[#f8f8f9]" />
<ChatInput
inputId="chat-input-session"
onSend={onSend}
disabled={status === "streaming"}
isStreaming={status === "streaming"}
onStop={onStop}
placeholder="What else can I help with?"
/>
</motion.div>
</div>
) : (
<EmptySession
inputLayoutId={inputLayoutId}
isCreatingSession={isCreatingSession}
onCreateSession={onCreateSession}
onSend={onSend}
/>
)}
</div>
</LayoutGroup>
</CopilotChatActionsProvider>
);
};

View File

@@ -1,274 +0,0 @@
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import {
Conversation,
ConversationContent,
ConversationScrollButton,
} from "@/components/ai-elements/conversation";
import {
Message,
MessageContent,
MessageResponse,
} from "@/components/ai-elements/message";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useState } from "react";
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
import { FindAgentsTool } from "../../tools/FindAgents/FindAgents";
import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
// ---------------------------------------------------------------------------
// Workspace media support
// ---------------------------------------------------------------------------
/**
* Resolve workspace:// URLs in markdown text to proxy download URLs.
* Detects MIME type from the hash fragment (e.g. workspace://id#video/mp4)
* and prefixes the alt text with "video:" so the custom img component can
* render a <video> element instead.
*/
function resolveWorkspaceUrls(text: string): string {
return text.replace(
/!\[([^\]]*)\]\(workspace:\/\/([^)#\s]+)(?:#([^)\s]*))?\)/g,
(_match, alt: string, fileId: string, mimeHint?: string) => {
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
const url = `/api/proxy${apiPath}`;
if (mimeHint?.startsWith("video/")) {
return `![video:${alt || "Video"}](${url})`;
}
return `![${alt || "Image"}](${url})`;
},
);
}
/**
* Custom img component for Streamdown that renders <video> elements
* for workspace video files (detected via "video:" alt-text prefix).
* Falls back to <video> when an <img> fails to load for workspace files.
*/
function WorkspaceMediaImage(props: React.JSX.IntrinsicElements["img"]) {
const { src, alt, ...rest } = props;
const [imgFailed, setImgFailed] = useState(false);
const isWorkspace = src?.includes("/workspace/files/") ?? false;
if (!src) return null;
if (alt?.startsWith("video:") || (imgFailed && isWorkspace)) {
return (
<span className="my-2 inline-block">
<video
controls
className="h-auto max-w-full rounded-md border border-zinc-200"
preload="metadata"
>
<source src={src} />
Your browser does not support the video tag.
</video>
</span>
);
}
return (
// eslint-disable-next-line @next/next/no-img-element
<img
src={src}
alt={alt || "Image"}
className="h-auto max-w-full rounded-md border border-zinc-200"
loading="lazy"
onError={() => {
if (isWorkspace) setImgFailed(true);
}}
{...rest}
/>
);
}
/** Stable components override for Streamdown (avoids re-creating on every render). */
const STREAMDOWN_COMPONENTS = { img: WorkspaceMediaImage };
const THINKING_PHRASES = [
"Thinking...",
"Considering this...",
"Working through this...",
"Analyzing your request...",
"Reasoning...",
"Looking into it...",
"Processing your request...",
"Mulling this over...",
"Piecing it together...",
"On it...",
];
function getRandomPhrase() {
return THINKING_PHRASES[Math.floor(Math.random() * THINKING_PHRASES.length)];
}
interface ChatMessagesContainerProps {
messages: UIMessage<unknown, UIDataTypes, UITools>[];
status: string;
error: Error | undefined;
isLoading: boolean;
}
export const ChatMessagesContainer = ({
messages,
status,
error,
isLoading,
}: ChatMessagesContainerProps) => {
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
useEffect(() => {
if (status === "submitted") {
setThinkingPhrase(getRandomPhrase());
}
}, [status]);
const lastMessage = messages[messages.length - 1];
const lastAssistantHasVisibleContent =
lastMessage?.role === "assistant" &&
lastMessage.parts.some(
(p) =>
(p.type === "text" && p.text.trim().length > 0) ||
p.type.startsWith("tool-"),
);
const showThinking =
status === "submitted" ||
(status === "streaming" && !lastAssistantHasVisibleContent);
return (
<Conversation className="min-h-0 flex-1">
<ConversationContent className="gap-6 px-3 py-6">
{isLoading && messages.length === 0 && (
<div className="flex flex-1 items-center justify-center">
<LoadingSpinner size="large" className="text-neutral-400" />
</div>
)}
{messages.map((message, messageIndex) => {
const isLastAssistant =
messageIndex === messages.length - 1 &&
message.role === "assistant";
const messageHasVisibleContent = message.parts.some(
(p) =>
(p.type === "text" && p.text.trim().length > 0) ||
p.type.startsWith("tool-"),
);
return (
<Message from={message.role} key={message.id}>
<MessageContent
className={
"text-[1rem] leading-relaxed " +
"group-[.is-user]:rounded-xl group-[.is-user]:bg-purple-100 group-[.is-user]:px-3 group-[.is-user]:py-2.5 group-[.is-user]:text-slate-900 group-[.is-user]:[border-bottom-right-radius:0] " +
"group-[.is-assistant]:bg-transparent group-[.is-assistant]:text-slate-900"
}
>
{message.parts.map((part, i) => {
switch (part.type) {
case "text":
return (
<MessageResponse
key={`${message.id}-${i}`}
components={STREAMDOWN_COMPONENTS}
>
{resolveWorkspaceUrls(part.text)}
</MessageResponse>
);
case "tool-find_block":
return (
<FindBlocksTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-find_agent":
case "tool-find_library_agent":
return (
<FindAgentsTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-search_docs":
case "tool-get_doc_page":
return (
<SearchDocsTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-run_block":
return (
<RunBlockTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-run_agent":
case "tool-schedule_agent":
return (
<RunAgentTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-create_agent":
return (
<CreateAgentTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-edit_agent":
return (
<EditAgentTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
case "tool-view_agent_output":
return (
<ViewAgentOutputTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
default:
return null;
}
})}
{isLastAssistant &&
!messageHasVisibleContent &&
showThinking && (
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
{thinkingPhrase}
</span>
)}
</MessageContent>
</Message>
);
})}
{showThinking && lastMessage?.role !== "assistant" && (
<Message from="assistant">
<MessageContent className="text-[1rem] leading-relaxed">
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
{thinkingPhrase}
</span>
</MessageContent>
</Message>
)}
{error && (
<div className="rounded-lg bg-red-50 p-3 text-red-600">
Error: {error.message}
</div>
)}
</ConversationContent>
<ConversationScrollButton />
</Conversation>
);
};

View File

@@ -1,188 +0,0 @@
"use client";
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import { Button } from "@/components/atoms/Button/Button";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import {
Sidebar,
SidebarContent,
SidebarFooter,
SidebarHeader,
SidebarTrigger,
useSidebar,
} from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
import { motion } from "framer-motion";
import { parseAsString, useQueryState } from "nuqs";
export function ChatSidebar() {
const { state } = useSidebar();
const isCollapsed = state === "collapsed";
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
const { data: sessionsResponse, isLoading: isLoadingSessions } =
useGetV2ListSessions({ limit: 50 });
const sessions =
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
function handleNewChat() {
setSessionId(null);
}
function handleSelectSession(id: string) {
setSessionId(id);
}
function formatDate(dateString: string) {
const date = new Date(dateString);
const now = new Date();
const diffMs = now.getTime() - date.getTime();
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
if (diffDays === 0) return "Today";
if (diffDays === 1) return "Yesterday";
if (diffDays < 7) return `${diffDays} days ago`;
const day = date.getDate();
const ordinal =
day % 10 === 1 && day !== 11
? "st"
: day % 10 === 2 && day !== 12
? "nd"
: day % 10 === 3 && day !== 13
? "rd"
: "th";
const month = date.toLocaleDateString("en-US", { month: "short" });
const year = date.getFullYear();
return `${day}${ordinal} ${month} ${year}`;
}
return (
<Sidebar
variant="inset"
collapsible="icon"
className="!top-[50px] !h-[calc(100vh-50px)] border-r border-zinc-100 px-0"
>
{isCollapsed && (
<SidebarHeader
className={cn(
"flex",
isCollapsed
? "flex-row items-center justify-between gap-y-4 md:flex-col md:items-start md:justify-start"
: "flex-row items-center justify-between",
)}
>
<motion.div
key={isCollapsed ? "header-collapsed" : "header-expanded"}
className="flex flex-col items-center gap-3 pt-4"
initial={{ opacity: 0, filter: "blur(3px)" }}
animate={{ opacity: 1, filter: "blur(0px)" }}
transition={{ type: "spring", bounce: 0.2 }}
>
<div className="flex flex-col items-center gap-2">
<SidebarTrigger />
<Button
variant="ghost"
onClick={handleNewChat}
style={{ minWidth: "auto", width: "auto" }}
>
<PlusCircleIcon className="!size-5" />
<span className="sr-only">New Chat</span>
</Button>
</div>
</motion.div>
</SidebarHeader>
)}
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.1 }}
className="flex items-center justify-between px-3"
>
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-6">
<SidebarTrigger />
</div>
</motion.div>
)}
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.15 }}
className="mt-4 flex flex-col gap-1"
>
{isLoadingSessions ? (
<div className="flex items-center justify-center py-4">
<LoadingSpinner size="small" className="text-neutral-400" />
</div>
) : sessions.length === 0 ? (
<p className="py-4 text-center text-sm text-neutral-500">
No conversations yet
</p>
) : (
sessions.map((session) => (
<button
key={session.id}
onClick={() => handleSelectSession(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
session.id === sessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || `Untitled chat`}
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
))
)}
</motion.div>
)}
</SidebarContent>
{!isCollapsed && sessionId && (
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.2 }}
>
<Button
variant="primary"
size="small"
onClick={handleNewChat}
className="w-full"
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
>
New Chat
</Button>
</motion.div>
</SidebarFooter>
)}
</Sidebar>
);
}

View File

@@ -1,16 +0,0 @@
"use client";
import { CopilotChatActionsContext } from "./useCopilotChatActions";
interface Props {
onSend: (message: string) => void | Promise<void>;
children: React.ReactNode;
}
export function CopilotChatActionsProvider({ onSend, children }: Props) {
return (
<CopilotChatActionsContext.Provider value={{ onSend }}>
{children}
</CopilotChatActionsContext.Provider>
);
}

View File

@@ -1,23 +0,0 @@
"use client";
import { createContext, useContext } from "react";
interface CopilotChatActions {
onSend: (message: string) => void | Promise<void>;
}
const CopilotChatActionsContext = createContext<CopilotChatActions | null>(
null,
);
export function useCopilotChatActions(): CopilotChatActions {
const ctx = useContext(CopilotChatActionsContext);
if (!ctx) {
throw new Error(
"useCopilotChatActions must be used within CopilotChatActionsProvider",
);
}
return ctx;
}
export { CopilotChatActionsContext };

View File

@@ -0,0 +1,99 @@
"use client";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Text } from "@/components/atoms/Text/Text";
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
import type { ReactNode } from "react";
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { useCopilotShell } from "./useCopilotShell";
interface Props {
children: ReactNode;
}
export function CopilotShell({ children }: Props) {
const {
isMobile,
isDrawerOpen,
isLoading,
isCreatingSession,
isLoggedIn,
hasActiveSession,
sessions,
currentSessionId,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleNewChatClick,
handleSessionClick,
hasNextPage,
isFetchingNextPage,
fetchNextPage,
} = useCopilotShell();
if (!isLoggedIn) {
return (
<div className="flex h-full items-center justify-center">
<ChatLoader />
</div>
);
}
return (
<div
className="flex overflow-hidden bg-[#EFEFF0]"
style={{ height: `calc(100vh - ${NAVBAR_HEIGHT_PX}px)` }}
>
{!isMobile && (
<DesktopSidebar
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
hasActiveSession={Boolean(hasActiveSession)}
/>
)}
<div className="relative flex min-h-0 flex-1 flex-col">
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<div className="flex min-h-0 flex-1 flex-col">
{isCreatingSession ? (
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
<div className="flex flex-col items-center gap-4">
<ChatLoader />
<Text variant="body" className="text-zinc-500">
Creating your chat...
</Text>
</div>
</div>
) : (
children
)}
</div>
</div>
{isMobile && (
<MobileDrawer
isOpen={isDrawerOpen}
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={handleSessionClick}
onFetchNextPage={fetchNextPage}
onNewChat={handleNewChatClick}
onClose={handleCloseDrawer}
onOpenChange={handleDrawerOpenChange}
hasActiveSession={Boolean(hasActiveSession)}
/>
)}
</div>
);
}

View File

@@ -0,0 +1,70 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { Plus } from "@phosphor-icons/react";
import { SessionsList } from "../SessionsList/SessionsList";
interface Props {
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
hasNextPage: boolean;
isFetchingNextPage: boolean;
onSelectSession: (sessionId: string) => void;
onFetchNextPage: () => void;
onNewChat: () => void;
hasActiveSession: boolean;
}
export function DesktopSidebar({
sessions,
currentSessionId,
isLoading,
hasNextPage,
isFetchingNextPage,
onSelectSession,
onFetchNextPage,
onNewChat,
hasActiveSession,
}: Props) {
return (
<aside className="flex h-full w-80 flex-col border-r border-zinc-100 bg-zinc-50">
<div className="shrink-0 px-6 py-4">
<Text variant="h3" size="body-medium">
Your chats
</Text>
</div>
<div
className={cn(
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
scrollbarStyles,
)}
>
<SessionsList
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={onSelectSession}
onFetchNextPage={onFetchNextPage}
/>
</div>
{hasActiveSession && (
<div className="shrink-0 bg-zinc-50 p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<Plus width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</aside>
);
}

View File

@@ -0,0 +1,91 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { PlusIcon, X } from "@phosphor-icons/react";
import { Drawer } from "vaul";
import { SessionsList } from "../SessionsList/SessionsList";
interface Props {
isOpen: boolean;
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
hasNextPage: boolean;
isFetchingNextPage: boolean;
onSelectSession: (sessionId: string) => void;
onFetchNextPage: () => void;
onNewChat: () => void;
onClose: () => void;
onOpenChange: (open: boolean) => void;
hasActiveSession: boolean;
}
export function MobileDrawer({
isOpen,
sessions,
currentSessionId,
isLoading,
hasNextPage,
isFetchingNextPage,
onSelectSession,
onFetchNextPage,
onNewChat,
onClose,
onOpenChange,
hasActiveSession,
}: Props) {
return (
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
<Drawer.Portal>
<Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
<Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
<div className="shrink-0 border-b border-zinc-200 p-4">
<div className="flex items-center justify-between">
<Drawer.Title className="text-lg font-semibold text-zinc-800">
Your chats
</Drawer.Title>
<Button
variant="icon"
size="icon"
aria-label="Close sessions"
onClick={onClose}
>
<X width="1.25rem" height="1.25rem" />
</Button>
</div>
</div>
<div
className={cn(
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
scrollbarStyles,
)}
>
<SessionsList
sessions={sessions}
currentSessionId={currentSessionId}
isLoading={isLoading}
hasNextPage={hasNextPage}
isFetchingNextPage={isFetchingNextPage}
onSelectSession={onSelectSession}
onFetchNextPage={onFetchNextPage}
/>
</div>
{hasActiveSession && (
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<PlusIcon width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</Drawer.Content>
</Drawer.Portal>
</Drawer.Root>
);
}

View File

@@ -0,0 +1,24 @@
import { useState } from "react";
export function useMobileDrawer() {
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
const handleOpenDrawer = () => {
setIsDrawerOpen(true);
};
const handleCloseDrawer = () => {
setIsDrawerOpen(false);
};
const handleDrawerOpenChange = (open: boolean) => {
setIsDrawerOpen(open);
};
return {
isDrawerOpen,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
};
}

View File

@@ -0,0 +1,80 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { InfiniteList } from "@/components/molecules/InfiniteList/InfiniteList";
import { cn } from "@/lib/utils";
import { getSessionTitle } from "../../helpers";
interface Props {
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
hasNextPage: boolean;
isFetchingNextPage: boolean;
onSelectSession: (sessionId: string) => void;
onFetchNextPage: () => void;
}
export function SessionsList({
sessions,
currentSessionId,
isLoading,
hasNextPage,
isFetchingNextPage,
onSelectSession,
onFetchNextPage,
}: Props) {
if (isLoading) {
return (
<div className="space-y-1">
{Array.from({ length: 5 }).map((_, i) => (
<div key={i} className="rounded-lg px-3 py-2.5">
<Skeleton className="h-5 w-full" />
</div>
))}
</div>
);
}
if (sessions.length === 0) {
return (
<div className="flex h-full items-center justify-center">
<Text variant="body" className="text-zinc-500">
You don&apos;t have previous chats
</Text>
</div>
);
}
return (
<InfiniteList
items={sessions}
hasMore={hasNextPage}
isFetchingMore={isFetchingNextPage}
onEndReached={onFetchNextPage}
className="space-y-1"
renderItem={(session) => {
const isActive = session.id === currentSessionId;
return (
<button
onClick={() => onSelectSession(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
)}
>
<Text
variant="body"
className={cn(
"font-normal",
isActive ? "text-zinc-600" : "text-zinc-800",
)}
>
{getSessionTitle(session)}
</Text>
</button>
);
}}
/>
);
}

View File

@@ -0,0 +1,91 @@
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { okData } from "@/app/api/helpers";
import { useEffect, useState } from "react";
const PAGE_SIZE = 50;
export interface UseSessionsPaginationArgs {
enabled: boolean;
}
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
const [offset, setOffset] = useState(0);
const [accumulatedSessions, setAccumulatedSessions] = useState<
SessionSummaryResponse[]
>([]);
const [totalCount, setTotalCount] = useState<number | null>(null);
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
{ limit: PAGE_SIZE, offset },
{
query: {
enabled: enabled && offset >= 0,
},
},
);
useEffect(() => {
const responseData = okData(data);
if (responseData) {
const newSessions = responseData.sessions;
const total = responseData.total;
setTotalCount(total);
if (offset === 0) {
setAccumulatedSessions(newSessions);
} else {
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
}
} else if (!enabled) {
setAccumulatedSessions([]);
setTotalCount(null);
}
}, [data, offset, enabled]);
const hasNextPage =
totalCount !== null && accumulatedSessions.length < totalCount;
const areAllSessionsLoaded =
totalCount !== null &&
accumulatedSessions.length >= totalCount &&
!isFetching &&
!isLoading;
useEffect(() => {
if (
hasNextPage &&
!isFetching &&
!isLoading &&
!isError &&
totalCount !== null
) {
setOffset((prev) => prev + PAGE_SIZE);
}
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
const fetchNextPage = () => {
if (hasNextPage && !isFetching) {
setOffset((prev) => prev + PAGE_SIZE);
}
};
const reset = () => {
// Only reset the offset - keep existing sessions visible during refetch
// The effect will replace sessions when new data arrives at offset 0
setOffset(0);
};
return {
sessions: accumulatedSessions,
isLoading,
isFetching,
hasNextPage,
areAllSessionsLoaded,
totalCount,
fetchNextPage,
reset,
};
}

View File

@@ -0,0 +1,106 @@
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { format, formatDistanceToNow, isToday } from "date-fns";
export function convertSessionDetailToSummary(session: SessionDetailResponse) {
return {
id: session.id,
created_at: session.created_at,
updated_at: session.updated_at,
title: undefined,
};
}
export function filterVisibleSessions(sessions: SessionSummaryResponse[]) {
const fiveMinutesAgo = Date.now() - 5 * 60 * 1000;
return sessions.filter((session) => {
const hasBeenUpdated = session.updated_at !== session.created_at;
if (hasBeenUpdated) return true;
const isRecentlyCreated =
new Date(session.created_at).getTime() > fiveMinutesAgo;
return isRecentlyCreated;
});
}
export function getSessionTitle(session: SessionSummaryResponse) {
if (session.title) return session.title;
const isNewSession = session.updated_at === session.created_at;
if (isNewSession) {
const createdDate = new Date(session.created_at);
if (isToday(createdDate)) {
return "Today";
}
return format(createdDate, "MMM d, yyyy");
}
return "Untitled Chat";
}
export function getSessionUpdatedLabel(session: SessionSummaryResponse) {
if (!session.updated_at) return "";
return formatDistanceToNow(new Date(session.updated_at), { addSuffix: true });
}
export function mergeCurrentSessionIntoList(
accumulatedSessions: SessionSummaryResponse[],
currentSessionId: string | null,
currentSessionData: SessionDetailResponse | null | undefined,
recentlyCreatedSessions?: Map<string, SessionSummaryResponse>,
) {
const filteredSessions: SessionSummaryResponse[] = [];
const addedIds = new Set<string>();
if (accumulatedSessions.length > 0) {
const visibleSessions = filterVisibleSessions(accumulatedSessions);
if (currentSessionId) {
const currentInAll = accumulatedSessions.find(
(s) => s.id === currentSessionId,
);
if (currentInAll) {
const isInVisible = visibleSessions.some(
(s) => s.id === currentSessionId,
);
if (!isInVisible) {
filteredSessions.push(currentInAll);
addedIds.add(currentInAll.id);
}
}
}
for (const session of visibleSessions) {
if (!addedIds.has(session.id)) {
filteredSessions.push(session);
addedIds.add(session.id);
}
}
}
if (currentSessionId && currentSessionData) {
if (!addedIds.has(currentSessionId)) {
const summarySession = convertSessionDetailToSummary(currentSessionData);
filteredSessions.unshift(summarySession);
addedIds.add(currentSessionId);
}
}
if (recentlyCreatedSessions) {
for (const [sessionId, sessionData] of recentlyCreatedSessions) {
if (!addedIds.has(sessionId)) {
filteredSessions.unshift(sessionData);
addedIds.add(sessionId);
}
}
}
return filteredSessions;
}
export function getCurrentSessionId(searchParams: URLSearchParams) {
return searchParams.get("sessionId");
}

View File

@@ -0,0 +1,124 @@
"use client";
import {
getGetV2GetSessionQueryKey,
getGetV2ListSessionsQueryKey,
useGetV2GetSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { okData } from "@/app/api/helpers";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import { usePathname, useSearchParams } from "next/navigation";
import { useCopilotStore } from "../../copilot-page-store";
import { useCopilotSessionId } from "../../useCopilotSessionId";
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
import { getCurrentSessionId } from "./helpers";
import { useShellSessionList } from "./useShellSessionList";
export function useCopilotShell() {
const pathname = usePathname();
const searchParams = useSearchParams();
const queryClient = useQueryClient();
const breakpoint = useBreakpoint();
const { isLoggedIn } = useSupabase();
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
const isOnHomepage = pathname === "/copilot";
const paramSessionId = searchParams.get("sessionId");
const {
isDrawerOpen,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
} = useMobileDrawer();
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
const currentSessionId = getCurrentSessionId(searchParams);
const { data: currentSessionData } = useGetV2GetSession(
currentSessionId || "",
{
query: {
enabled: !!currentSessionId,
select: okData,
},
},
);
const {
sessions,
isLoading,
isSessionsFetching,
hasNextPage,
fetchNextPage,
resetPagination,
recentlyCreatedSessionsRef,
} = useShellSessionList({
paginationEnabled,
currentSessionId,
currentSessionData,
isOnHomepage,
paramSessionId,
});
const stopStream = useChatStore((s) => s.stopStream);
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
stopStream(currentSessionId);
}
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
});
}
setUrlSessionId(sessionId, { shallow: false });
if (isMobile) handleCloseDrawer();
}
function handleNewChatClick() {
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
stopStream(currentSessionId);
}
resetPagination();
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
setUrlSessionId(null, { shallow: false });
if (isMobile) handleCloseDrawer();
}
return {
isMobile,
isDrawerOpen,
isLoggedIn,
hasActiveSession:
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
isLoading: isLoading || isCreatingSession,
isCreatingSession,
sessions,
currentSessionId: urlSessionId,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleNewChatClick,
handleSessionClick,
hasNextPage,
isFetchingNextPage: isSessionsFetching,
fetchNextPage,
};
}

View File

@@ -0,0 +1,113 @@
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useRef } from "react";
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
import {
convertSessionDetailToSummary,
filterVisibleSessions,
mergeCurrentSessionIntoList,
} from "./helpers";
interface UseShellSessionListArgs {
paginationEnabled: boolean;
currentSessionId: string | null;
currentSessionData: SessionDetailResponse | null | undefined;
isOnHomepage: boolean;
paramSessionId: string | null;
}
export function useShellSessionList({
paginationEnabled,
currentSessionId,
currentSessionData,
isOnHomepage,
paramSessionId,
}: UseShellSessionListArgs) {
const queryClient = useQueryClient();
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
const {
sessions: accumulatedSessions,
isLoading: isSessionsLoading,
isFetching: isSessionsFetching,
hasNextPage,
fetchNextPage,
reset: resetPagination,
} = useSessionsPagination({
enabled: paginationEnabled,
});
const recentlyCreatedSessionsRef = useRef<
Map<string, SessionSummaryResponse>
>(new Map());
useEffect(() => {
if (isOnHomepage && !paramSessionId) {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
}
}, [isOnHomepage, paramSessionId, queryClient]);
useEffect(() => {
if (currentSessionId && currentSessionData) {
const isNewSession =
currentSessionData.updated_at === currentSessionData.created_at;
const isNotInAccumulated = !accumulatedSessions.some(
(s) => s.id === currentSessionId,
);
if (isNewSession || isNotInAccumulated) {
const summary = convertSessionDetailToSummary(currentSessionData);
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
}
}
}, [currentSessionId, currentSessionData, accumulatedSessions]);
useEffect(() => {
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
if (accumulatedSessions.some((s) => s.id === sessionId)) {
recentlyCreatedSessionsRef.current.delete(sessionId);
}
}
}, [accumulatedSessions]);
useEffect(() => {
const unsubscribe = onStreamComplete(() => {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
});
return unsubscribe;
}, [onStreamComplete, queryClient]);
const sessions = useMemo(
() =>
mergeCurrentSessionIntoList(
accumulatedSessions,
currentSessionId,
currentSessionData,
recentlyCreatedSessionsRef.current,
),
[accumulatedSessions, currentSessionId, currentSessionData],
);
const visibleSessions = useMemo(
() => filterVisibleSessions(sessions),
[sessions],
);
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
return {
sessions: visibleSessions,
isLoading,
isSessionsFetching,
hasNextPage,
fetchNextPage,
resetPagination,
recentlyCreatedSessionsRef,
};
}

View File

@@ -1,111 +0,0 @@
"use client";
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { SpinnerGapIcon } from "@phosphor-icons/react";
import { motion } from "framer-motion";
import { useEffect, useState } from "react";
import {
getGreetingName,
getInputPlaceholder,
getQuickActions,
} from "./helpers";
interface Props {
inputLayoutId: string;
isCreatingSession: boolean;
onCreateSession: () => void | Promise<string>;
onSend: (message: string) => void | Promise<void>;
}
export function EmptySession({
inputLayoutId,
isCreatingSession,
onSend,
}: Props) {
const { user } = useSupabase();
const greetingName = getGreetingName(user);
const quickActions = getQuickActions();
const [loadingAction, setLoadingAction] = useState<string | null>(null);
const [inputPlaceholder, setInputPlaceholder] = useState(
getInputPlaceholder(),
);
useEffect(() => {
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
}, [window.innerWidth]);
async function handleQuickActionClick(action: string) {
if (isCreatingSession || loadingAction) return;
setLoadingAction(action);
try {
await onSend(action);
} finally {
setLoadingAction(null);
}
}
return (
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-0 py-5 md:px-6 md:py-10">
<motion.div
className="w-full max-w-3xl text-center"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.3 }}
>
<div className="mx-auto max-w-3xl">
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
Hey, <span className="text-violet-600">{greetingName}</span>
</Text>
<Text variant="h3" className="mb-8 !font-normal">
Tell me about your work I&apos;ll find what to automate.
</Text>
<div className="mb-6">
<motion.div
layoutId={inputLayoutId}
transition={{ type: "spring", bounce: 0.2, duration: 0.65 }}
className="w-full px-2"
>
<ChatInput
inputId="chat-input-empty"
onSend={onSend}
disabled={isCreatingSession}
placeholder={inputPlaceholder}
className="w-full"
/>
</motion.div>
</div>
</div>
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{quickActions.map((action) => (
<Button
key={action}
type="button"
variant="outline"
size="small"
onClick={() => void handleQuickActionClick(action)}
disabled={isCreatingSession || loadingAction !== null}
aria-busy={loadingAction === action}
leftIcon={
loadingAction === action ? (
<SpinnerGapIcon
className="h-4 w-4 animate-spin"
weight="bold"
/>
) : null
}
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
>
{action}
</Button>
))}
</div>
</motion.div>
</div>
);
}

View File

@@ -1,140 +0,0 @@
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils";
import { PlusIcon, SpinnerGapIcon, X } from "@phosphor-icons/react";
import { Drawer } from "vaul";
interface Props {
isOpen: boolean;
sessions: SessionSummaryResponse[];
currentSessionId: string | null;
isLoading: boolean;
onSelectSession: (sessionId: string) => void;
onNewChat: () => void;
onClose: () => void;
onOpenChange: (open: boolean) => void;
}
function formatDate(dateString: string) {
const date = new Date(dateString);
const now = new Date();
const diffMs = now.getTime() - date.getTime();
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
if (diffDays === 0) return "Today";
if (diffDays === 1) return "Yesterday";
if (diffDays < 7) return `${diffDays} days ago`;
const day = date.getDate();
const ordinal =
day % 10 === 1 && day !== 11
? "st"
: day % 10 === 2 && day !== 12
? "nd"
: day % 10 === 3 && day !== 13
? "rd"
: "th";
const month = date.toLocaleDateString("en-US", { month: "short" });
const year = date.getFullYear();
return `${day}${ordinal} ${month} ${year}`;
}
export function MobileDrawer({
isOpen,
sessions,
currentSessionId,
isLoading,
onSelectSession,
onNewChat,
onClose,
onOpenChange,
}: Props) {
return (
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
<Drawer.Portal>
<Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
<Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
<div className="shrink-0 border-b border-zinc-200 px-4 py-2">
<div className="flex items-center justify-between">
<Drawer.Title className="text-lg font-semibold text-zinc-800">
Your chats
</Drawer.Title>
<Button
variant="icon"
size="icon"
aria-label="Close sessions"
onClick={onClose}
>
<X width="1rem" height="1rem" />
</Button>
</div>
</div>
<div
className={cn(
"flex min-h-0 flex-1 flex-col gap-1 overflow-y-auto px-3 py-3",
scrollbarStyles,
)}
>
{isLoading ? (
<div className="flex items-center justify-center py-4">
<SpinnerGapIcon className="h-5 w-5 animate-spin text-neutral-400" />
</div>
) : sessions.length === 0 ? (
<p className="py-4 text-center text-sm text-neutral-500">
No conversations yet
</p>
) : (
sessions.map((session) => (
<button
key={session.id}
onClick={() => onSelectSession(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
session.id === currentSessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === currentSessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || "Untitled chat"}
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
))
)}
</div>
{currentSessionId && (
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<Button
variant="primary"
size="small"
onClick={onNewChat}
className="w-full"
leftIcon={<PlusIcon width="1rem" height="1rem" />}
>
New Chat
</Button>
</div>
)}
</Drawer.Content>
</Drawer.Portal>
</Drawer.Root>
);
}

View File

@@ -1,54 +0,0 @@
import { cn } from "@/lib/utils";
import { AnimatePresence, motion } from "framer-motion";
interface Props {
text: string;
className?: string;
}
export function MorphingTextAnimation({ text, className }: Props) {
const letters = text.split("");
return (
<div className={cn(className)}>
<AnimatePresence mode="popLayout" initial={false}>
<motion.div key={text} className="whitespace-nowrap">
<motion.span className="inline-flex overflow-hidden">
{letters.map((char, index) => (
<motion.span
key={`${text}-${index}`}
initial={{
opacity: 0,
y: 8,
rotateX: "80deg",
filter: "blur(6px)",
}}
animate={{
opacity: 1,
y: 0,
rotateX: "0deg",
filter: "blur(0px)",
}}
exit={{
opacity: 0,
y: -8,
rotateX: "-80deg",
filter: "blur(6px)",
}}
style={{ willChange: "transform" }}
transition={{
delay: 0.015 * index,
type: "spring",
bounce: 0.5,
}}
className="inline-block"
>
{char === " " ? "\u00A0" : char}
</motion.span>
))}
</motion.span>
</motion.div>
</AnimatePresence>
</div>
);
}

View File

@@ -1,69 +0,0 @@
.loader {
position: relative;
animation: rotate 1s infinite;
}
.loader::before,
.loader::after {
border-radius: 50%;
content: "";
display: block;
/* 40% of container size */
height: 40%;
width: 40%;
}
.loader::before {
animation: ball1 1s infinite;
background-color: #a1a1aa; /* zinc-400 */
box-shadow: calc(var(--spacing)) 0 0 #18181b; /* zinc-900 */
margin-bottom: calc(var(--gap));
}
.loader::after {
animation: ball2 1s infinite;
background-color: #18181b; /* zinc-900 */
box-shadow: calc(var(--spacing)) 0 0 #a1a1aa; /* zinc-400 */
}
@keyframes rotate {
0% {
transform: rotate(0deg) scale(0.8);
}
50% {
transform: rotate(360deg) scale(1.2);
}
100% {
transform: rotate(720deg) scale(0.8);
}
}
@keyframes ball1 {
0% {
box-shadow: calc(var(--spacing)) 0 0 #18181b;
}
50% {
box-shadow: 0 0 0 #18181b;
margin-bottom: 0;
transform: translate(calc(var(--spacing) / 2), calc(var(--spacing) / 2));
}
100% {
box-shadow: calc(var(--spacing)) 0 0 #18181b;
margin-bottom: calc(var(--gap));
}
}
@keyframes ball2 {
0% {
box-shadow: calc(var(--spacing)) 0 0 #a1a1aa;
}
50% {
box-shadow: 0 0 0 #a1a1aa;
margin-top: calc(var(--ball-size) * -1);
transform: translate(calc(var(--spacing) / 2), calc(var(--spacing) / 2));
}
100% {
box-shadow: calc(var(--spacing)) 0 0 #a1a1aa;
margin-top: 0;
}
}

View File

@@ -1,28 +0,0 @@
import { cn } from "@/lib/utils";
import styles from "./OrbitLoader.module.css";
interface Props {
size?: number;
className?: string;
}
export function OrbitLoader({ size = 24, className }: Props) {
const ballSize = Math.round(size * 0.4);
const spacing = Math.round(size * 0.6);
const gap = Math.round(size * 0.2);
return (
<div
className={cn(styles.loader, className)}
style={
{
width: size,
height: size,
"--ball-size": `${ballSize}px`,
"--spacing": `${spacing}px`,
"--gap": `${gap}px`,
} as React.CSSProperties
}
/>
);
}

View File

@@ -1,26 +0,0 @@
import { cn } from "@/lib/utils";
interface Props {
value: number;
label?: string;
className?: string;
}
export function ProgressBar({ value, label, className }: Props) {
const clamped = Math.min(100, Math.max(0, value));
return (
<div className={cn("flex flex-col gap-1.5", className)}>
<div className="flex items-center justify-between text-xs text-neutral-500">
<span>{label ?? "Working on it..."}</span>
<span>{Math.round(clamped)}%</span>
</div>
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
<div
className="h-full rounded-full bg-neutral-900 transition-[width] duration-300 ease-out"
style={{ width: `${clamped}%` }}
/>
</div>
</div>
);
}

View File

@@ -1,34 +0,0 @@
.loader {
position: relative;
display: inline-block;
flex-shrink: 0;
}
.loader::before,
.loader::after {
content: "";
box-sizing: border-box;
width: 100%;
height: 100%;
border-radius: 50%;
background: currentColor;
position: absolute;
left: 0;
top: 0;
animation: ripple 2s linear infinite;
}
.loader::after {
animation-delay: 1s;
}
@keyframes ripple {
0% {
transform: scale(0);
opacity: 1;
}
100% {
transform: scale(1);
opacity: 0;
}
}

View File

@@ -1,16 +0,0 @@
import { cn } from "@/lib/utils";
import styles from "./PulseLoader.module.css";
interface Props {
size?: number;
className?: string;
}
export function PulseLoader({ size = 24, className }: Props) {
return (
<div
className={cn(styles.loader, className)}
style={{ width: size, height: size }}
/>
);
}

View File

@@ -1,57 +0,0 @@
.loader {
position: relative;
display: inline-block;
flex-shrink: 0;
transform: rotateZ(45deg);
perspective: 1000px;
border-radius: 50%;
color: currentColor;
}
.loader::before,
.loader::after {
content: "";
display: block;
position: absolute;
top: 0;
left: 0;
width: inherit;
height: inherit;
border-radius: 50%;
transform: rotateX(70deg);
animation: spin 1s linear infinite;
}
.loader::after {
color: var(--spinner-accent, #a855f7);
transform: rotateY(70deg);
animation-delay: 0.4s;
}
@keyframes spin {
0%,
100% {
box-shadow: 0.2em 0 0 0 currentColor;
}
12% {
box-shadow: 0.2em 0.2em 0 0 currentColor;
}
25% {
box-shadow: 0 0.2em 0 0 currentColor;
}
37% {
box-shadow: -0.2em 0.2em 0 0 currentColor;
}
50% {
box-shadow: -0.2em 0 0 0 currentColor;
}
62% {
box-shadow: -0.2em -0.2em 0 0 currentColor;
}
75% {
box-shadow: 0 -0.2em 0 0 currentColor;
}
87% {
box-shadow: 0.2em -0.2em 0 0 currentColor;
}
}

View File

@@ -1,16 +0,0 @@
import { cn } from "@/lib/utils";
import styles from "./SpinnerLoader.module.css";
interface Props {
size?: number;
className?: string;
}
export function SpinnerLoader({ size = 24, className }: Props) {
return (
<div
className={cn(styles.loader, className)}
style={{ width: size, height: size }}
/>
);
}

View File

@@ -1,235 +0,0 @@
import { Link } from "@/components/atoms/Link/Link";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
/* ------------------------------------------------------------------ */
/* Layout */
/* ------------------------------------------------------------------ */
export function ContentGrid({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return <div className={cn("grid gap-2", className)}>{children}</div>;
}
/* ------------------------------------------------------------------ */
/* Card */
/* ------------------------------------------------------------------ */
export function ContentCard({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<div
className={cn(
"rounded-lg bg-gradient-to-r from-purple-500/30 to-blue-500/30 p-[1px]",
className,
)}
>
<div className="rounded-lg bg-neutral-100 p-3">{children}</div>
</div>
);
}
/** Flex row with a left content area (`children`) and an optional rightside `action`. */
export function ContentCardHeader({
children,
action,
className,
}: {
children: React.ReactNode;
action?: React.ReactNode;
className?: string;
}) {
return (
<div className={cn("flex items-start justify-between gap-2", className)}>
<div className="min-w-0">{children}</div>
{action}
</div>
);
}
export function ContentCardTitle({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text
variant="body-medium"
className={cn("truncate text-zinc-800", className)}
>
{children}
</Text>
);
}
export function ContentCardSubtitle({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text
variant="small"
className={cn("mt-0.5 truncate font-mono text-zinc-800", className)}
>
{children}
</Text>
);
}
export function ContentCardDescription({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text variant="body" className={cn("mt-2 text-zinc-800", className)}>
{children}
</Text>
);
}
/* ------------------------------------------------------------------ */
/* Text */
/* ------------------------------------------------------------------ */
export function ContentMessage({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text variant="body" className={cn("text-zinc-800", className)}>
{children}
</Text>
);
}
export function ContentHint({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text variant="small" className={cn("text-neutral-500", className)}>
{children}
</Text>
);
}
/* ------------------------------------------------------------------ */
/* Code / data */
/* ------------------------------------------------------------------ */
export function ContentCodeBlock({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<pre
className={cn(
"whitespace-pre-wrap rounded-lg border bg-black p-3 text-xs text-neutral-200",
className,
)}
>
{children}
</pre>
);
}
/* ------------------------------------------------------------------ */
/* Inline elements */
/* ------------------------------------------------------------------ */
export function ContentBadge({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return (
<Text
variant="small"
as="span"
className={cn(
"shrink-0 rounded-full border bg-muted px-2 py-0.5 text-[11px] text-zinc-800",
className,
)}
>
{children}
</Text>
);
}
export function ContentLink({
href,
children,
className,
...rest
}: Omit<React.ComponentProps<typeof Link>, "className"> & {
className?: string;
}) {
return (
<Link
variant="primary"
isExternal
href={href}
className={cn("shrink-0 text-xs text-purple-500", className)}
{...rest}
>
{children}
</Link>
);
}
/* ------------------------------------------------------------------ */
/* Lists */
/* ------------------------------------------------------------------ */
export function ContentSuggestionsList({
items,
max = 5,
className,
}: {
items: string[];
max?: number;
className?: string;
}) {
if (items.length === 0) return null;
return (
<ul
className={cn(
"mt-2 list-disc space-y-1 pl-5 font-sans text-[0.75rem] leading-[1.125rem] text-zinc-800",
className,
)}
>
{items.slice(0, max).map((s) => (
<li key={s}>{s}</li>
))}
</ul>
);
}

View File

@@ -1,102 +0,0 @@
"use client";
import { cn } from "@/lib/utils";
import { CaretDownIcon } from "@phosphor-icons/react";
import { AnimatePresence, motion, useReducedMotion } from "framer-motion";
import { useId } from "react";
import { useToolAccordion } from "./useToolAccordion";
interface Props {
icon: React.ReactNode;
title: React.ReactNode;
titleClassName?: string;
description?: React.ReactNode;
children: React.ReactNode;
className?: string;
defaultExpanded?: boolean;
expanded?: boolean;
onExpandedChange?: (expanded: boolean) => void;
}
export function ToolAccordion({
icon,
title,
titleClassName,
description,
children,
className,
defaultExpanded,
expanded,
onExpandedChange,
}: Props) {
const shouldReduceMotion = useReducedMotion();
const contentId = useId();
const { isExpanded, toggle } = useToolAccordion({
expanded,
defaultExpanded,
onExpandedChange,
});
return (
<div
className={cn(
"mt-2 w-full rounded-lg border border-slate-200 bg-slate-100 px-3 py-2",
className,
)}
>
<button
type="button"
aria-expanded={isExpanded}
aria-controls={contentId}
onClick={toggle}
className="flex w-full items-center justify-between gap-3 py-1 text-left"
>
<div className="flex min-w-0 items-center gap-3">
<span className="flex shrink-0 items-center text-gray-800">
{icon}
</span>
<div className="min-w-0">
<p
className={cn(
"truncate text-sm font-medium text-gray-800",
titleClassName,
)}
>
{title}
</p>
{description && (
<p className="truncate text-xs text-slate-800">{description}</p>
)}
</div>
</div>
<CaretDownIcon
className={cn(
"h-4 w-4 shrink-0 text-slate-500 transition-transform",
isExpanded && "rotate-180",
)}
weight="bold"
/>
</button>
<AnimatePresence initial={false}>
{isExpanded && (
<motion.div
id={contentId}
initial={{ height: 0, opacity: 0, filter: "blur(10px)" }}
animate={{ height: "auto", opacity: 1, filter: "blur(0px)" }}
exit={{ height: 0, opacity: 0, filter: "blur(10px)" }}
transition={
shouldReduceMotion
? { duration: 0 }
: { type: "spring", bounce: 0.35, duration: 0.55 }
}
className="overflow-hidden"
style={{ willChange: "height, opacity, filter" }}
>
<div className="pb-2 pt-3">{children}</div>
</motion.div>
)}
</AnimatePresence>
</div>
);
}

View File

@@ -1,32 +0,0 @@
import { useState } from "react";
interface UseToolAccordionOptions {
expanded?: boolean;
defaultExpanded?: boolean;
onExpandedChange?: (expanded: boolean) => void;
}
interface UseToolAccordionResult {
isExpanded: boolean;
toggle: () => void;
}
export function useToolAccordion({
expanded,
defaultExpanded = false,
onExpandedChange,
}: UseToolAccordionOptions): UseToolAccordionResult {
const [uncontrolledExpanded, setUncontrolledExpanded] =
useState(defaultExpanded);
const isControlled = typeof expanded === "boolean";
const isExpanded = isControlled ? expanded : uncontrolledExpanded;
function toggle() {
const next = !isExpanded;
if (!isControlled) setUncontrolledExpanded(next);
onExpandedChange?.(next);
}
return { isExpanded, toggle };
}

View File

@@ -0,0 +1,56 @@
"use client";
import { create } from "zustand";
interface CopilotStoreState {
isStreaming: boolean;
isSwitchingSession: boolean;
isCreatingSession: boolean;
isInterruptModalOpen: boolean;
pendingAction: (() => void) | null;
}
interface CopilotStoreActions {
setIsStreaming: (isStreaming: boolean) => void;
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
setIsCreatingSession: (isCreating: boolean) => void;
openInterruptModal: (onConfirm: () => void) => void;
confirmInterrupt: () => void;
cancelInterrupt: () => void;
}
type CopilotStore = CopilotStoreState & CopilotStoreActions;
export const useCopilotStore = create<CopilotStore>((set, get) => ({
isStreaming: false,
isSwitchingSession: false,
isCreatingSession: false,
isInterruptModalOpen: false,
pendingAction: null,
setIsStreaming(isStreaming) {
set({ isStreaming });
},
setIsSwitchingSession(isSwitchingSession) {
set({ isSwitchingSession });
},
setIsCreatingSession(isCreatingSession) {
set({ isCreatingSession });
},
openInterruptModal(onConfirm) {
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
},
confirmInterrupt() {
const { pendingAction } = get();
set({ isInterruptModalOpen: false, pendingAction: null });
if (pendingAction) pendingAction();
},
cancelInterrupt() {
set({ isInterruptModalOpen: false, pendingAction: null });
},
}));

View File

@@ -1,26 +1,6 @@
import { User } from "@supabase/supabase-js";
import type { User } from "@supabase/supabase-js";
export function getInputPlaceholder(width?: number) {
if (!width) return "What's your role and what eats up most of your day?";
if (width < 500) {
return "I'm a chef and I hate...";
}
if (width <= 1080) {
return "What's your role and what eats up most of your day?";
}
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}
export function getQuickActions() {
return [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
];
}
export function getGreetingName(user?: User | null) {
export function getGreetingName(user?: User | null): string {
if (!user) return "there";
const metadata = user.user_metadata as Record<string, unknown> | undefined;
const fullName = metadata?.full_name;
@@ -36,3 +16,30 @@ export function getGreetingName(user?: User | null) {
}
return "there";
}
export function buildCopilotChatUrl(prompt: string): string {
const trimmed = prompt.trim();
if (!trimmed) return "/copilot/chat";
const encoded = encodeURIComponent(trimmed);
return `/copilot/chat?prompt=${encoded}`;
}
export function getQuickActions(): string[] {
return [
"I don't know where to start, just ask me stuff",
"I do the same thing every week and it's killing me",
"Help me find where I'm wasting my time",
];
}
export function getInputPlaceholder(width?: number) {
if (!width) return "What's your role and what eats up most of your day?";
if (width < 500) {
return "I'm a chef and I hate...";
}
if (width <= 1080) {
return "What's your role and what eats up most of your day?";
}
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
}

View File

@@ -1,128 +0,0 @@
import type { UIMessage, UIDataTypes, UITools } from "ai";
interface SessionChatMessage {
role: string;
content: string | null;
tool_call_id: string | null;
tool_calls: unknown[] | null;
}
function coerceSessionChatMessages(
rawMessages: unknown[],
): SessionChatMessage[] {
return rawMessages
.map((m) => {
if (!m || typeof m !== "object") return null;
const msg = m as Record<string, unknown>;
const role = typeof msg.role === "string" ? msg.role : null;
if (!role) return null;
return {
role,
content:
typeof msg.content === "string"
? msg.content
: msg.content == null
? null
: String(msg.content),
tool_call_id:
typeof msg.tool_call_id === "string"
? msg.tool_call_id
: msg.tool_call_id == null
? null
: String(msg.tool_call_id),
tool_calls: Array.isArray(msg.tool_calls) ? msg.tool_calls : null,
};
})
.filter((m): m is SessionChatMessage => m !== null);
}
function safeJsonParse(value: string): unknown {
try {
return JSON.parse(value) as unknown;
} catch {
return value;
}
}
function toToolInput(rawArguments: unknown): unknown {
if (typeof rawArguments === "string") {
const trimmed = rawArguments.trim();
return trimmed ? safeJsonParse(trimmed) : {};
}
if (rawArguments && typeof rawArguments === "object") return rawArguments;
return {};
}
export function convertChatSessionMessagesToUiMessages(
sessionId: string,
rawMessages: unknown[],
): UIMessage<unknown, UIDataTypes, UITools>[] {
const messages = coerceSessionChatMessages(rawMessages);
const toolOutputsByCallId = new Map<string, unknown>();
for (const msg of messages) {
if (msg.role !== "tool") continue;
if (!msg.tool_call_id) continue;
if (msg.content == null) continue;
toolOutputsByCallId.set(msg.tool_call_id, msg.content);
}
const uiMessages: UIMessage<unknown, UIDataTypes, UITools>[] = [];
messages.forEach((msg, index) => {
if (msg.role === "tool") return;
if (msg.role !== "user" && msg.role !== "assistant") return;
const parts: UIMessage<unknown, UIDataTypes, UITools>["parts"] = [];
if (typeof msg.content === "string" && msg.content.trim()) {
parts.push({ type: "text", text: msg.content, state: "done" });
}
if (msg.role === "assistant" && Array.isArray(msg.tool_calls)) {
for (const rawToolCall of msg.tool_calls) {
if (!rawToolCall || typeof rawToolCall !== "object") continue;
const toolCall = rawToolCall as {
id?: unknown;
function?: { name?: unknown; arguments?: unknown };
};
const toolCallId = String(toolCall.id ?? "").trim();
const toolName = String(toolCall.function?.name ?? "").trim();
if (!toolCallId || !toolName) continue;
const input = toToolInput(toolCall.function?.arguments);
const output = toolOutputsByCallId.get(toolCallId);
if (output !== undefined) {
parts.push({
type: `tool-${toolName}`,
toolCallId,
state: "output-available",
input,
output: typeof output === "string" ? safeJsonParse(output) : output,
});
} else {
parts.push({
type: `tool-${toolName}`,
toolCallId,
state: "input-available",
input,
});
}
}
}
if (parts.length === 0) return;
uiMessages.push({
id: `${sessionId}-${index}`,
role: msg.role,
parts,
});
});
return uiMessages;
}

View File

@@ -0,0 +1,13 @@
"use client";
import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
import { Flag } from "@/services/feature-flags/use-get-flag";
import { type ReactNode } from "react";
import { CopilotShell } from "./components/CopilotShell/CopilotShell";
export default function CopilotLayout({ children }: { children: ReactNode }) {
return (
<FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
<CopilotShell>{children}</CopilotShell>
</FeatureFlagPage>
);
}

View File

@@ -1,69 +1,149 @@
"use client";
import { SidebarProvider } from "@/components/ui/sidebar";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { Button } from "@/components/atoms/Button/Button";
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useEffect, useState } from "react";
import { useCopilotStore } from "./copilot-page-store";
import { getInputPlaceholder } from "./helpers";
import { useCopilotPage } from "./useCopilotPage";
export default function Page() {
export default function CopilotPage() {
const { state, handlers } = useCopilotPage();
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
const [inputPlaceholder, setInputPlaceholder] = useState(
getInputPlaceholder(),
);
useEffect(() => {
const handleResize = () => {
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
};
handleResize();
window.addEventListener("resize", handleResize);
return () => window.removeEventListener("resize", handleResize);
}, []);
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
state;
const {
sessionId,
messages,
status,
error,
stop,
isLoadingSession,
isCreatingSession,
createSession,
onSend,
// Mobile drawer
isMobile,
isDrawerOpen,
sessions,
isLoadingSessions,
handleOpenDrawer,
handleCloseDrawer,
handleDrawerOpenChange,
handleSelectSession,
handleNewChat,
} = useCopilotPage();
handleQuickAction,
startChatWithPrompt,
handleSessionNotFound,
handleStreamingChange,
} = handlers;
if (hasSession) {
return (
<div className="flex h-full flex-col">
<Chat
className="flex-1"
initialPrompt={initialPrompt}
onSessionNotFound={handleSessionNotFound}
onStreamingChange={handleStreamingChange}
/>
<Dialog
title="Interrupt current chat?"
styling={{ maxWidth: 300, width: "100%" }}
controlled={{
isOpen: isInterruptModalOpen,
set: (open) => {
if (!open) cancelInterrupt();
},
}}
onClose={cancelInterrupt}
>
<Dialog.Content>
<div className="flex flex-col gap-4">
<Text variant="body">
The current chat response will be interrupted. Are you sure you
want to continue?
</Text>
<Dialog.Footer>
<Button
type="button"
variant="outline"
onClick={cancelInterrupt}
>
Cancel
</Button>
<Button
type="button"
variant="primary"
onClick={confirmInterrupt}
>
Continue
</Button>
</Dialog.Footer>
</div>
</Dialog.Content>
</Dialog>
</div>
);
}
return (
<SidebarProvider
defaultOpen={true}
className="h-[calc(100vh-72px)] min-h-0"
>
{!isMobile && <ChatSidebar />}
<div className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0">
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
<div className="flex-1 overflow-hidden">
<ChatContainer
messages={messages}
status={status}
error={error}
sessionId={sessionId}
isLoadingSession={isLoadingSession}
isCreatingSession={isCreatingSession}
onCreateSession={createSession}
onSend={onSend}
onStop={stop}
/>
</div>
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
<div className="w-full text-center">
{isLoading ? (
<div className="mx-auto max-w-2xl">
<Skeleton className="mx-auto mb-3 h-8 w-64" />
<Skeleton className="mx-auto mb-8 h-6 w-80" />
<div className="mb-8">
<Skeleton className="mx-auto h-14 w-full rounded-lg" />
</div>
<div className="flex flex-wrap items-center justify-center gap-3">
{Array.from({ length: 4 }).map((_, i) => (
<Skeleton key={i} className="h-9 w-48 rounded-md" />
))}
</div>
</div>
) : (
<>
<div className="mx-auto max-w-3xl">
<Text
variant="h3"
className="mb-1 !text-[1.375rem] text-zinc-700"
>
Hey, <span className="text-violet-600">{greetingName}</span>
</Text>
<Text variant="h3" className="mb-8 !font-normal">
Tell me about your work I&apos;ll find what to automate.
</Text>
<div className="mb-6">
<ChatInput
onSend={startChatWithPrompt}
placeholder={inputPlaceholder}
/>
</div>
</div>
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{quickActions.map((action) => (
<Button
key={action}
type="button"
variant="outline"
size="small"
onClick={() => handleQuickAction(action)}
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
>
{action}
</Button>
))}
</div>
</>
)}
</div>
{isMobile && (
<MobileDrawer
isOpen={isDrawerOpen}
sessions={sessions}
currentSessionId={sessionId}
isLoading={isLoadingSessions}
onSelectSession={handleSelectSession}
onNewChat={handleNewChat}
onClose={handleCloseDrawer}
onOpenChange={handleDrawerOpenChange}
/>
)}
</SidebarProvider>
</div>
);
}

View File

@@ -1,237 +0,0 @@
"use client";
import { WarningDiamondIcon } from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
import { ProgressBar } from "../../components/ProgressBar/ProgressBar";
import {
ContentCardDescription,
ContentCodeBlock,
ContentGrid,
ContentHint,
ContentLink,
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import { useAsymptoticProgress } from "../../hooks/useAsymptoticProgress";
import {
ClarificationQuestionsCard,
ClarifyingQuestion,
} from "./components/ClarificationQuestionsCard";
import {
AccordionIcon,
formatMaybeJson,
getAnimationText,
getCreateAgentToolOutput,
isAgentPreviewOutput,
isAgentSavedOutput,
isClarificationNeededOutput,
isErrorOutput,
isOperationInProgressOutput,
isOperationPendingOutput,
isOperationStartedOutput,
ToolIcon,
truncateText,
type CreateAgentToolOutput,
} from "./helpers";
export interface CreateAgentToolPart {
type: string;
toolCallId: string;
state: ToolUIPart["state"];
input?: unknown;
output?: unknown;
}
interface Props {
part: CreateAgentToolPart;
}
function getAccordionMeta(output: CreateAgentToolOutput): {
icon: React.ReactNode;
title: React.ReactNode;
titleClassName?: string;
description?: string;
} {
const icon = <AccordionIcon />;
if (isAgentSavedOutput(output)) {
return { icon, title: output.agent_name };
}
if (isAgentPreviewOutput(output)) {
return {
icon,
title: output.agent_name,
description: `${output.node_count} block${output.node_count === 1 ? "" : "s"}`,
};
}
if (isClarificationNeededOutput(output)) {
const questions = output.questions ?? [];
return {
icon,
title: "Needs clarification",
description: `${questions.length} question${questions.length === 1 ? "" : "s"}`,
};
}
if (
isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output)
) {
return {
icon: <OrbitLoader size={32} />,
title: "Creating agent, this may take a few minutes. Sit back and relax.",
};
}
return {
icon: (
<WarningDiamondIcon size={32} weight="light" className="text-red-500" />
),
title: "Error",
titleClassName: "text-red-500",
};
}
export function CreateAgentTool({ part }: Props) {
const text = getAnimationText(part);
const { onSend } = useCopilotChatActions();
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const output = getCreateAgentToolOutput(part);
const isError =
part.state === "output-error" || (!!output && isErrorOutput(output));
const isOperating =
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output));
const progress = useAsymptoticProgress(isOperating);
const hasExpandableContent =
part.state === "output-available" &&
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output) ||
isAgentPreviewOutput(output) ||
isAgentSavedOutput(output) ||
isClarificationNeededOutput(output) ||
isErrorOutput(output));
function handleClarificationAnswers(answers: Record<string, string>) {
const questions =
output && isClarificationNeededOutput(output)
? (output.questions ?? [])
: [];
const contextMessage = questions
.map((q) => {
const answer = answers[q.keyword] || "";
return `> ${q.question}\n\n${answer}`;
})
.join("\n\n");
onSend(
`**Here are my answers:**\n\n${contextMessage}\n\nPlease proceed with creating the agent.`,
);
}
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{hasExpandableContent && output && (
<ToolAccordion
{...getAccordionMeta(output)}
defaultExpanded={isOperating || isClarificationNeededOutput(output)}
>
{isOperating && (
<ContentGrid>
<ProgressBar value={progress} className="max-w-[280px]" />
<ContentHint>
This could take a few minutes, grab a coffee
</ContentHint>
</ContentGrid>
)}
{isAgentSavedOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
<div className="flex flex-wrap gap-2">
<ContentLink href={output.library_agent_link}>
Open in library
</ContentLink>
<ContentLink href={output.agent_page_link}>
Open in builder
</ContentLink>
</div>
<ContentCodeBlock>
{truncateText(
formatMaybeJson({ agent_id: output.agent_id }),
800,
)}
</ContentCodeBlock>
</ContentGrid>
)}
{isAgentPreviewOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.description?.trim() && (
<ContentCardDescription>
{output.description}
</ContentCardDescription>
)}
<ContentCodeBlock>
{truncateText(formatMaybeJson(output.agent_json), 1600)}
</ContentCodeBlock>
</ContentGrid>
)}
{isClarificationNeededOutput(output) && (
<ClarificationQuestionsCard
questions={(output.questions ?? []).map((q) => {
const item: ClarifyingQuestion = {
question: q.question,
keyword: q.keyword,
};
const example =
typeof q.example === "string" && q.example.trim()
? q.example.trim()
: null;
if (example) item.example = example;
return item;
})}
message={output.message}
onSubmitAnswers={handleClarificationAnswers}
/>
)}
{isErrorOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.error && (
<ContentCodeBlock>
{formatMaybeJson(output.error)}
</ContentCodeBlock>
)}
{output.details && (
<ContentCodeBlock>
{formatMaybeJson(output.details)}
</ContentCodeBlock>
)}
</ContentGrid>
)}
</ToolAccordion>
)}
</div>
);
}

View File

@@ -1,186 +0,0 @@
import type { AgentPreviewResponse } from "@/app/api/__generated__/models/agentPreviewResponse";
import type { AgentSavedResponse } from "@/app/api/__generated__/models/agentSavedResponse";
import type { ClarificationNeededResponse } from "@/app/api/__generated__/models/clarificationNeededResponse";
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
import type { OperationInProgressResponse } from "@/app/api/__generated__/models/operationInProgressResponse";
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import {
PlusCircleIcon,
PlusIcon,
WarningDiamondIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
export type CreateAgentToolOutput =
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| ErrorResponse;
function parseOutput(output: unknown): CreateAgentToolOutput | null {
if (!output) return null;
if (typeof output === "string") {
const trimmed = output.trim();
if (!trimmed) return null;
try {
return parseOutput(JSON.parse(trimmed) as unknown);
} catch {
return null;
}
}
if (typeof output === "object") {
const type = (output as { type?: unknown }).type;
if (
type === ResponseType.operation_started ||
type === ResponseType.operation_pending ||
type === ResponseType.operation_in_progress ||
type === ResponseType.agent_preview ||
type === ResponseType.agent_saved ||
type === ResponseType.clarification_needed ||
type === ResponseType.error
) {
return output as CreateAgentToolOutput;
}
if ("operation_id" in output && "tool_name" in output)
return output as OperationStartedResponse | OperationPendingResponse;
if ("tool_call_id" in output) return output as OperationInProgressResponse;
if ("agent_json" in output && "agent_name" in output)
return output as AgentPreviewResponse;
if ("agent_id" in output && "library_agent_id" in output)
return output as AgentSavedResponse;
if ("questions" in output) return output as ClarificationNeededResponse;
if ("error" in output || "details" in output)
return output as ErrorResponse;
}
return null;
}
export function getCreateAgentToolOutput(
part: unknown,
): CreateAgentToolOutput | null {
if (!part || typeof part !== "object") return null;
return parseOutput((part as { output?: unknown }).output);
}
export function isOperationStartedOutput(
output: CreateAgentToolOutput,
): output is OperationStartedResponse {
return (
output.type === ResponseType.operation_started ||
("operation_id" in output && "tool_name" in output)
);
}
export function isOperationPendingOutput(
output: CreateAgentToolOutput,
): output is OperationPendingResponse {
return output.type === ResponseType.operation_pending;
}
export function isOperationInProgressOutput(
output: CreateAgentToolOutput,
): output is OperationInProgressResponse {
return (
output.type === ResponseType.operation_in_progress ||
"tool_call_id" in output
);
}
export function isAgentPreviewOutput(
output: CreateAgentToolOutput,
): output is AgentPreviewResponse {
return output.type === ResponseType.agent_preview || "agent_json" in output;
}
export function isAgentSavedOutput(
output: CreateAgentToolOutput,
): output is AgentSavedResponse {
return (
output.type === ResponseType.agent_saved || "agent_page_link" in output
);
}
export function isClarificationNeededOutput(
output: CreateAgentToolOutput,
): output is ClarificationNeededResponse {
return (
output.type === ResponseType.clarification_needed || "questions" in output
);
}
export function isErrorOutput(
output: CreateAgentToolOutput,
): output is ErrorResponse {
return output.type === ResponseType.error || "error" in output;
}
export function getAnimationText(part: {
state: ToolUIPart["state"];
input?: unknown;
output?: unknown;
}): string {
switch (part.state) {
case "input-streaming":
case "input-available":
return "Creating a new agent";
case "output-available": {
const output = parseOutput(part.output);
if (!output) return "Creating a new agent";
if (isOperationStartedOutput(output)) return "Agent creation started";
if (isOperationPendingOutput(output)) return "Agent creation in progress";
if (isOperationInProgressOutput(output))
return "Agent creation already in progress";
if (isAgentSavedOutput(output)) return `Saved "${output.agent_name}"`;
if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
if (isClarificationNeededOutput(output)) return "Needs clarification";
return "Error creating agent";
}
case "output-error":
return "Error creating agent";
default:
return "Creating a new agent";
}
}
export function ToolIcon({
isStreaming,
isError,
}: {
isStreaming?: boolean;
isError?: boolean;
}) {
if (isError) {
return (
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
);
}
if (isStreaming) {
return <OrbitLoader size={24} />;
}
return <PlusIcon size={14} weight="regular" className="text-neutral-400" />;
}
export function AccordionIcon() {
return <PlusCircleIcon size={32} weight="light" />;
}
export function formatMaybeJson(value: unknown): string {
if (typeof value === "string") return value;
try {
return JSON.stringify(value, null, 2);
} catch {
return String(value);
}
}
export function truncateText(text: string, maxChars: number): string {
const trimmed = text.trim();
if (trimmed.length <= maxChars) return trimmed;
return `${trimmed.slice(0, maxChars).trimEnd()}`;
}

View File

@@ -1,234 +0,0 @@
"use client";
import { WarningDiamondIcon } from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
import { ProgressBar } from "../../components/ProgressBar/ProgressBar";
import {
ContentCardDescription,
ContentCodeBlock,
ContentGrid,
ContentHint,
ContentLink,
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import { useAsymptoticProgress } from "../../hooks/useAsymptoticProgress";
import {
ClarificationQuestionsCard,
ClarifyingQuestion,
} from "../CreateAgent/components/ClarificationQuestionsCard";
import {
AccordionIcon,
formatMaybeJson,
getAnimationText,
getEditAgentToolOutput,
isAgentPreviewOutput,
isAgentSavedOutput,
isClarificationNeededOutput,
isErrorOutput,
isOperationInProgressOutput,
isOperationPendingOutput,
isOperationStartedOutput,
ToolIcon,
truncateText,
type EditAgentToolOutput,
} from "./helpers";
export interface EditAgentToolPart {
type: string;
toolCallId: string;
state: ToolUIPart["state"];
input?: unknown;
output?: unknown;
}
interface Props {
part: EditAgentToolPart;
}
function getAccordionMeta(output: EditAgentToolOutput): {
icon: React.ReactNode;
title: string;
titleClassName?: string;
description?: string;
} {
const icon = <AccordionIcon />;
if (isAgentSavedOutput(output)) {
return { icon, title: output.agent_name };
}
if (isAgentPreviewOutput(output)) {
return {
icon,
title: output.agent_name,
description: `${output.node_count} block${output.node_count === 1 ? "" : "s"}`,
};
}
if (isClarificationNeededOutput(output)) {
const questions = output.questions ?? [];
return {
icon,
title: "Needs clarification",
description: `${questions.length} question${questions.length === 1 ? "" : "s"}`,
};
}
if (
isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output)
) {
return { icon: <OrbitLoader size={32} />, title: "Editing agent" };
}
return {
icon: (
<WarningDiamondIcon size={32} weight="light" className="text-red-500" />
),
title: "Error",
titleClassName: "text-red-500",
};
}
export function EditAgentTool({ part }: Props) {
const text = getAnimationText(part);
const { onSend } = useCopilotChatActions();
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const output = getEditAgentToolOutput(part);
const isError =
part.state === "output-error" || (!!output && isErrorOutput(output));
const isOperating =
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output));
const progress = useAsymptoticProgress(isOperating);
const hasExpandableContent =
part.state === "output-available" &&
!!output &&
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output) ||
isAgentPreviewOutput(output) ||
isAgentSavedOutput(output) ||
isClarificationNeededOutput(output) ||
isErrorOutput(output));
function handleClarificationAnswers(answers: Record<string, string>) {
const questions =
output && isClarificationNeededOutput(output)
? (output.questions ?? [])
: [];
const contextMessage = questions
.map((q) => {
const answer = answers[q.keyword] || "";
return `> ${q.question}\n\n${answer}`;
})
.join("\n\n");
onSend(
`**Here are my answers:**\n\n${contextMessage}\n\nPlease proceed with editing the agent.`,
);
}
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{hasExpandableContent && output && (
<ToolAccordion
{...getAccordionMeta(output)}
defaultExpanded={isOperating || isClarificationNeededOutput(output)}
>
{isOperating && (
<ContentGrid>
<ProgressBar value={progress} className="max-w-[280px]" />
<ContentHint>
This could take a few minutes, grab a coffee
</ContentHint>
</ContentGrid>
)}
{isAgentSavedOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
<div className="flex flex-wrap gap-2">
<ContentLink href={output.library_agent_link}>
Open in library
</ContentLink>
<ContentLink href={output.agent_page_link}>
Open in builder
</ContentLink>
</div>
<ContentCodeBlock>
{truncateText(
formatMaybeJson({ agent_id: output.agent_id }),
800,
)}
</ContentCodeBlock>
</ContentGrid>
)}
{isAgentPreviewOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.description?.trim() && (
<ContentCardDescription>
{output.description}
</ContentCardDescription>
)}
<ContentCodeBlock>
{truncateText(formatMaybeJson(output.agent_json), 1600)}
</ContentCodeBlock>
</ContentGrid>
)}
{isClarificationNeededOutput(output) && (
<ClarificationQuestionsCard
questions={(output.questions ?? []).map((q) => {
const item: ClarifyingQuestion = {
question: q.question,
keyword: q.keyword,
};
const example =
typeof q.example === "string" && q.example.trim()
? q.example.trim()
: null;
if (example) item.example = example;
return item;
})}
message={output.message}
onSubmitAnswers={handleClarificationAnswers}
/>
)}
{isErrorOutput(output) && (
<ContentGrid>
<ContentMessage>{output.message}</ContentMessage>
{output.error && (
<ContentCodeBlock>
{formatMaybeJson(output.error)}
</ContentCodeBlock>
)}
{output.details && (
<ContentCodeBlock>
{formatMaybeJson(output.details)}
</ContentCodeBlock>
)}
</ContentGrid>
)}
</ToolAccordion>
)}
</div>
);
}

View File

@@ -1,188 +0,0 @@
import type { AgentPreviewResponse } from "@/app/api/__generated__/models/agentPreviewResponse";
import type { AgentSavedResponse } from "@/app/api/__generated__/models/agentSavedResponse";
import type { ClarificationNeededResponse } from "@/app/api/__generated__/models/clarificationNeededResponse";
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
import type { OperationInProgressResponse } from "@/app/api/__generated__/models/operationInProgressResponse";
import type { OperationPendingResponse } from "@/app/api/__generated__/models/operationPendingResponse";
import type { OperationStartedResponse } from "@/app/api/__generated__/models/operationStartedResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import {
NotePencilIcon,
PencilLineIcon,
WarningDiamondIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
export type EditAgentToolOutput =
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| ErrorResponse;
function parseOutput(output: unknown): EditAgentToolOutput | null {
if (!output) return null;
if (typeof output === "string") {
const trimmed = output.trim();
if (!trimmed) return null;
try {
return parseOutput(JSON.parse(trimmed) as unknown);
} catch {
return null;
}
}
if (typeof output === "object") {
const type = (output as { type?: unknown }).type;
if (
type === ResponseType.operation_started ||
type === ResponseType.operation_pending ||
type === ResponseType.operation_in_progress ||
type === ResponseType.agent_preview ||
type === ResponseType.agent_saved ||
type === ResponseType.clarification_needed ||
type === ResponseType.error
) {
return output as EditAgentToolOutput;
}
if ("operation_id" in output && "tool_name" in output)
return output as OperationStartedResponse | OperationPendingResponse;
if ("tool_call_id" in output) return output as OperationInProgressResponse;
if ("agent_json" in output && "agent_name" in output)
return output as AgentPreviewResponse;
if ("agent_id" in output && "library_agent_id" in output)
return output as AgentSavedResponse;
if ("questions" in output) return output as ClarificationNeededResponse;
if ("error" in output || "details" in output)
return output as ErrorResponse;
}
return null;
}
export function getEditAgentToolOutput(
part: unknown,
): EditAgentToolOutput | null {
if (!part || typeof part !== "object") return null;
return parseOutput((part as { output?: unknown }).output);
}
export function isOperationStartedOutput(
output: EditAgentToolOutput,
): output is OperationStartedResponse {
return (
output.type === ResponseType.operation_started ||
("operation_id" in output && "tool_name" in output)
);
}
export function isOperationPendingOutput(
output: EditAgentToolOutput,
): output is OperationPendingResponse {
return output.type === ResponseType.operation_pending;
}
export function isOperationInProgressOutput(
output: EditAgentToolOutput,
): output is OperationInProgressResponse {
return (
output.type === ResponseType.operation_in_progress ||
"tool_call_id" in output
);
}
export function isAgentPreviewOutput(
output: EditAgentToolOutput,
): output is AgentPreviewResponse {
return output.type === ResponseType.agent_preview || "agent_json" in output;
}
export function isAgentSavedOutput(
output: EditAgentToolOutput,
): output is AgentSavedResponse {
return (
output.type === ResponseType.agent_saved || "agent_page_link" in output
);
}
export function isClarificationNeededOutput(
output: EditAgentToolOutput,
): output is ClarificationNeededResponse {
return (
output.type === ResponseType.clarification_needed || "questions" in output
);
}
export function isErrorOutput(
output: EditAgentToolOutput,
): output is ErrorResponse {
return output.type === ResponseType.error || "error" in output;
}
export function getAnimationText(part: {
state: ToolUIPart["state"];
input?: unknown;
output?: unknown;
}): string {
switch (part.state) {
case "input-streaming":
case "input-available":
return "Editing the agent";
case "output-available": {
const output = parseOutput(part.output);
if (!output) return "Editing the agent";
if (isOperationStartedOutput(output)) return "Agent update started";
if (isOperationPendingOutput(output)) return "Agent update in progress";
if (isOperationInProgressOutput(output))
return "Agent update already in progress";
if (isAgentSavedOutput(output)) return `Saved "${output.agent_name}"`;
if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
if (isClarificationNeededOutput(output)) return "Needs clarification";
return "Error editing agent";
}
case "output-error":
return "Error editing agent";
default:
return "Editing the agent";
}
}
export function ToolIcon({
isStreaming,
isError,
}: {
isStreaming?: boolean;
isError?: boolean;
}) {
if (isError) {
return (
<WarningDiamondIcon size={14} weight="regular" className="text-red-500" />
);
}
if (isStreaming) {
return <OrbitLoader size={24} />;
}
return (
<PencilLineIcon size={14} weight="regular" className="text-neutral-400" />
);
}
export function AccordionIcon() {
return <NotePencilIcon size={32} weight="light" />;
}
export function formatMaybeJson(value: unknown): string {
if (typeof value === "string") return value;
try {
return JSON.stringify(value, null, 2);
} catch {
return String(value);
}
}
export function truncateText(text: string, maxChars: number): string {
const trimmed = text.trim();
if (trimmed.length <= maxChars) return trimmed;
return `${trimmed.slice(0, maxChars).trimEnd()}`;
}

View File

@@ -1,127 +0,0 @@
"use client";
import { ToolUIPart } from "ai";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import {
ContentBadge,
ContentCard,
ContentCardDescription,
ContentCardHeader,
ContentCardTitle,
ContentGrid,
ContentLink,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import {
AccordionIcon,
getAgentHref,
getAnimationText,
getFindAgentsOutput,
getSourceLabelFromToolType,
isAgentsFoundOutput,
isErrorOutput,
ToolIcon,
} from "./helpers";
export interface FindAgentsToolPart {
type: string;
toolCallId: string;
state: ToolUIPart["state"];
input?: unknown;
output?: unknown;
}
interface Props {
part: FindAgentsToolPart;
}
export function FindAgentsTool({ part }: Props) {
const text = getAnimationText(part);
const output = getFindAgentsOutput(part);
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError =
part.state === "output-error" || (!!output && isErrorOutput(output));
const query =
typeof part.input === "object" && part.input !== null
? String((part.input as { query?: unknown }).query ?? "").trim()
: "";
const agentsFoundOutput =
part.state === "output-available" && output && isAgentsFoundOutput(output)
? output
: null;
const hasAgents =
!!agentsFoundOutput &&
agentsFoundOutput.agents.length > 0 &&
(typeof agentsFoundOutput.count !== "number" ||
agentsFoundOutput.count > 0);
const totalCount = agentsFoundOutput ? agentsFoundOutput.count : 0;
const { source } = getSourceLabelFromToolType(part.type);
const scopeText =
source === "library"
? "in your library"
: source === "marketplace"
? "in marketplace"
: "";
const accordionDescription = `Found ${totalCount}${scopeText ? ` ${scopeText}` : ""}${
query ? ` for "${query}"` : ""
}`;
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon
toolType={part.type}
isStreaming={isStreaming}
isError={isError}
/>
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{hasAgents && agentsFoundOutput && (
<ToolAccordion
icon={<AccordionIcon toolType={part.type} />}
title="Agent results"
description={accordionDescription}
>
<ContentGrid className="sm:grid-cols-2">
{agentsFoundOutput.agents.map((agent) => {
const href = getAgentHref(agent);
const agentSource =
agent.source === "library"
? "Library"
: agent.source === "marketplace"
? "Marketplace"
: null;
return (
<ContentCard key={agent.id}>
<ContentCardHeader
action={
href ? <ContentLink href={href}>Open</ContentLink> : null
}
>
<div className="flex items-center gap-2">
<ContentCardTitle>{agent.name}</ContentCardTitle>
{agentSource && (
<ContentBadge>{agentSource}</ContentBadge>
)}
</div>
<ContentCardDescription className="mt-1 line-clamp-2">
{agent.description}
</ContentCardDescription>
</ContentCardHeader>
</ContentCard>
);
})}
</ContentGrid>
</ToolAccordion>
)}
</div>
);
}

View File

@@ -1,187 +0,0 @@
import type { AgentInfo } from "@/app/api/__generated__/models/agentInfo";
import type { AgentsFoundResponse } from "@/app/api/__generated__/models/agentsFoundResponse";
import type { ErrorResponse } from "@/app/api/__generated__/models/errorResponse";
import type { NoResultsResponse } from "@/app/api/__generated__/models/noResultsResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import {
FolderOpenIcon,
MagnifyingGlassIcon,
SquaresFourIcon,
StorefrontIcon,
} from "@phosphor-icons/react";
import { ToolUIPart } from "ai";
export interface FindAgentInput {
query: string;
}
export type FindAgentsOutput =
| AgentsFoundResponse
| NoResultsResponse
| ErrorResponse;
export type FindAgentsToolType =
| "tool-find_agent"
| "tool-find_library_agent"
| (string & {});
function parseOutput(output: unknown): FindAgentsOutput | null {
if (!output) return null;
if (typeof output === "string") {
const trimmed = output.trim();
if (!trimmed) return null;
try {
return parseOutput(JSON.parse(trimmed) as unknown);
} catch {
return null;
}
}
if (typeof output === "object") {
const type = (output as { type?: unknown }).type;
if (
type === ResponseType.agents_found ||
type === ResponseType.no_results ||
type === ResponseType.error
) {
return output as FindAgentsOutput;
}
if ("agents" in output && "count" in output)
return output as AgentsFoundResponse;
if ("suggestions" in output && !("error" in output))
return output as NoResultsResponse;
if ("error" in output || "details" in output)
return output as ErrorResponse;
}
return null;
}
export function getFindAgentsOutput(part: unknown): FindAgentsOutput | null {
if (!part || typeof part !== "object") return null;
return parseOutput((part as { output?: unknown }).output);
}
export function isAgentsFoundOutput(
output: FindAgentsOutput,
): output is AgentsFoundResponse {
return output.type === ResponseType.agents_found || "agents" in output;
}
export function isNoResultsOutput(
output: FindAgentsOutput,
): output is NoResultsResponse {
return (
output.type === ResponseType.no_results ||
("suggestions" in output && !("error" in output))
);
}
export function isErrorOutput(
output: FindAgentsOutput,
): output is ErrorResponse {
return output.type === ResponseType.error || "error" in output;
}
export function getSourceLabelFromToolType(toolType?: FindAgentsToolType): {
source: "marketplace" | "library" | "unknown";
label: string;
} {
if (toolType === "tool-find_library_agent") {
return { source: "library", label: "Library" };
}
if (toolType === "tool-find_agent") {
return { source: "marketplace", label: "Marketplace" };
}
return { source: "unknown", label: "Agents" };
}
export function getAnimationText(part: {
type?: FindAgentsToolType;
state: ToolUIPart["state"];
input?: unknown;
output?: unknown;
}): string {
const { source } = getSourceLabelFromToolType(part.type);
const query = (part.input as FindAgentInput | undefined)?.query?.trim();
// Action phrase matching legacy ToolCallMessage
const actionPhrase =
source === "library"
? "Looking for library agents"
: "Looking for agents in the marketplace";
const queryText = query ? ` matching "${query}"` : "";
switch (part.state) {
case "input-streaming":
case "input-available":
return `${actionPhrase}${queryText}`;
case "output-available": {
const output = parseOutput(part.output);
if (!output) {
return `${actionPhrase}${queryText}`;
}
if (isNoResultsOutput(output)) {
return `No agents found${queryText}`;
}
if (isAgentsFoundOutput(output)) {
const count = output.count ?? output.agents?.length ?? 0;
return `Found ${count} agent${count === 1 ? "" : "s"}${queryText}`;
}
if (isErrorOutput(output)) {
return `Error finding agents${queryText}`;
}
return `${actionPhrase}${queryText}`;
}
case "output-error":
return `Error finding agents${queryText}`;
default:
return actionPhrase;
}
}
export function getAgentHref(agent: AgentInfo): string | null {
if (agent.source === "library") {
return `/library/agents/${encodeURIComponent(agent.id)}`;
}
const [creator, slug, ...rest] = agent.id.split("/");
if (!creator || !slug || rest.length > 0) return null;
return `/marketplace/agent/${encodeURIComponent(creator)}/${encodeURIComponent(slug)}`;
}
export function ToolIcon({
toolType,
isStreaming,
isError,
}: {
toolType?: FindAgentsToolType;
isStreaming?: boolean;
isError?: boolean;
}) {
const { source } = getSourceLabelFromToolType(toolType);
const IconComponent =
source === "library" ? MagnifyingGlassIcon : SquaresFourIcon;
return (
<IconComponent
size={14}
weight="regular"
className={
isError
? "text-red-500"
: isStreaming
? "text-neutral-500"
: "text-neutral-400"
}
/>
);
}
export function AccordionIcon({ toolType }: { toolType?: FindAgentsToolType }) {
const { source } = getSourceLabelFromToolType(toolType);
const IconComponent = source === "library" ? FolderOpenIcon : StorefrontIcon;
return <IconComponent size={32} weight="light" />;
}

View File

@@ -1,92 +0,0 @@
"use client";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import {
ContentCard,
ContentCardDescription,
ContentCardTitle,
} from "../../components/ToolAccordion/AccordionContent";
import type { BlockListResponse } from "@/app/api/__generated__/models/blockListResponse";
import type { BlockInfoSummary } from "@/app/api/__generated__/models/blockInfoSummary";
import { ToolUIPart } from "ai";
import { HorizontalScroll } from "@/app/(platform)/build/components/NewControlPanel/NewBlockMenu/HorizontalScroll";
import {
AccordionIcon,
getAnimationText,
parseOutput,
ToolIcon,
} from "./helpers";
export interface FindBlockInput {
query: string;
}
export type FindBlockOutput = BlockListResponse;
export interface FindBlockToolPart {
type: string;
toolName?: string;
toolCallId: string;
state: ToolUIPart["state"];
input?: FindBlockInput | unknown;
output?: string | FindBlockOutput | unknown;
title?: string;
}
interface Props {
part: FindBlockToolPart;
}
function BlockCard({ block }: { block: BlockInfoSummary }) {
return (
<ContentCard className="w-48 shrink-0">
<ContentCardTitle>{block.name}</ContentCardTitle>
<ContentCardDescription className="mt-1 line-clamp-2">
{block.description}
</ContentCardDescription>
</ContentCard>
);
}
export function FindBlocksTool({ part }: Props) {
const text = getAnimationText(part);
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError = part.state === "output-error";
const parsed =
part.state === "output-available" ? parseOutput(part.output) : null;
const hasBlocks = !!parsed && parsed.blocks.length > 0;
const query = (part.input as FindBlockInput | undefined)?.query?.trim();
const accordionDescription = parsed
? `Found ${parsed.count} block${parsed.count === 1 ? "" : "s"}${query ? ` for "${query}"` : ""}`
: undefined;
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{hasBlocks && parsed && (
<ToolAccordion
icon={<AccordionIcon />}
title="Block results"
description={accordionDescription}
>
<HorizontalScroll dependencyList={[parsed.blocks.length]}>
{parsed.blocks.map((block) => (
<BlockCard key={block.id} block={block} />
))}
</HorizontalScroll>
</ToolAccordion>
)}
</div>
);
}

View File

@@ -1,75 +0,0 @@
import type { BlockListResponse } from "@/app/api/__generated__/models/blockListResponse";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import { CubeIcon, PackageIcon } from "@phosphor-icons/react";
import { FindBlockInput, FindBlockToolPart } from "./FindBlocks";
export function parseOutput(output: unknown): BlockListResponse | null {
if (!output) return null;
if (typeof output === "string") {
const trimmed = output.trim();
if (!trimmed) return null;
try {
return parseOutput(JSON.parse(trimmed) as unknown);
} catch {
return null;
}
}
if (typeof output === "object") {
const type = (output as { type?: unknown }).type;
if (type === ResponseType.block_list || "blocks" in output) {
return output as BlockListResponse;
}
}
return null;
}
export function getAnimationText(part: FindBlockToolPart): string {
const query = (part.input as FindBlockInput | undefined)?.query?.trim();
const queryText = query ? ` matching "${query}"` : "";
switch (part.state) {
case "input-streaming":
case "input-available":
return `Searching for blocks${queryText}`;
case "output-available": {
const parsed = parseOutput(part.output);
if (parsed) {
return `Found ${parsed.count} block${parsed.count === 1 ? "" : "s"}${queryText}`;
}
return `Searching for blocks${queryText}`;
}
case "output-error":
return `Error finding blocks${queryText}`;
default:
return "Searching for blocks";
}
}
export function ToolIcon({
isStreaming,
isError,
}: {
isStreaming?: boolean;
isError?: boolean;
}) {
return (
<PackageIcon
size={14}
weight="regular"
className={
isError
? "text-red-500"
: isStreaming
? "text-neutral-500"
: "text-neutral-400"
}
/>
);
}
export function AccordionIcon() {
return <CubeIcon size={32} weight="light" />;
}

View File

@@ -1,93 +0,0 @@
"use client";
import type { ToolUIPart } from "ai";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
import {
getAccordionMeta,
getAnimationText,
getRunAgentToolOutput,
isRunAgentAgentDetailsOutput,
isRunAgentErrorOutput,
isRunAgentExecutionStartedOutput,
isRunAgentNeedLoginOutput,
isRunAgentSetupRequirementsOutput,
ToolIcon,
} from "./helpers";
import { ExecutionStartedCard } from "./components/ExecutionStartedCard/ExecutionStartedCard";
import { AgentDetailsCard } from "./components/AgentDetailsCard/AgentDetailsCard";
import { SetupRequirementsCard } from "./components/SetupRequirementsCard/SetupRequirementsCard";
import { ErrorCard } from "./components/ErrorCard/ErrorCard";
export interface RunAgentToolPart {
type: string;
toolCallId: string;
state: ToolUIPart["state"];
input?: unknown;
output?: unknown;
}
interface Props {
part: RunAgentToolPart;
}
export function RunAgentTool({ part }: Props) {
const text = getAnimationText(part);
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const output = getRunAgentToolOutput(part);
const isError =
part.state === "output-error" ||
(!!output && isRunAgentErrorOutput(output));
const hasExpandableContent =
part.state === "output-available" &&
!!output &&
(isRunAgentExecutionStartedOutput(output) ||
isRunAgentAgentDetailsOutput(output) ||
isRunAgentSetupRequirementsOutput(output) ||
isRunAgentNeedLoginOutput(output) ||
isRunAgentErrorOutput(output));
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<ToolIcon isStreaming={isStreaming} isError={isError} />
<MorphingTextAnimation
text={text}
className={isError ? "text-red-500" : undefined}
/>
</div>
{hasExpandableContent && output && (
<ToolAccordion
{...getAccordionMeta(output)}
defaultExpanded={
isRunAgentExecutionStartedOutput(output) ||
isRunAgentSetupRequirementsOutput(output) ||
isRunAgentAgentDetailsOutput(output)
}
>
{isRunAgentExecutionStartedOutput(output) && (
<ExecutionStartedCard output={output} />
)}
{isRunAgentAgentDetailsOutput(output) && (
<AgentDetailsCard output={output} />
)}
{isRunAgentSetupRequirementsOutput(output) && (
<SetupRequirementsCard output={output} />
)}
{isRunAgentNeedLoginOutput(output) && (
<ContentMessage>{output.message}</ContentMessage>
)}
{isRunAgentErrorOutput(output) && <ErrorCard output={output} />}
</ToolAccordion>
)}
</div>
);
}

View File

@@ -1,116 +0,0 @@
"use client";
import type { AgentDetailsResponse } from "@/app/api/__generated__/models/agentDetailsResponse";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { FormRenderer } from "@/components/renderers/InputRenderer/FormRenderer";
import { AnimatePresence, motion } from "framer-motion";
import { useState } from "react";
import { useCopilotChatActions } from "../../../../components/CopilotChatActionsProvider/useCopilotChatActions";
import { ContentMessage } from "../../../../components/ToolAccordion/AccordionContent";
import { buildInputSchema } from "./helpers";
interface Props {
output: AgentDetailsResponse;
}
export function AgentDetailsCard({ output }: Props) {
const { onSend } = useCopilotChatActions();
const [showInputForm, setShowInputForm] = useState(false);
const [inputValues, setInputValues] = useState<Record<string, unknown>>({});
function handleRunWithExamples() {
onSend(
`Run the agent "${output.agent.name}" with placeholder/example values so I can test it.`,
);
}
function handleRunWithInputs() {
const nonEmpty = Object.fromEntries(
Object.entries(inputValues).filter(
([, v]) => v !== undefined && v !== null && v !== "",
),
);
onSend(
`Run the agent "${output.agent.name}" with these inputs: ${JSON.stringify(nonEmpty, null, 2)}`,
);
setShowInputForm(false);
setInputValues({});
}
return (
<div className="grid gap-2">
<ContentMessage>
Run this agent with example values or your own inputs.
</ContentMessage>
<div className="flex gap-2 pt-4">
<Button size="small" className="w-fit" onClick={handleRunWithExamples}>
Run with example values
</Button>
<Button
variant="outline"
size="small"
className="w-fit"
onClick={() => setShowInputForm((prev) => !prev)}
>
Run with my inputs
</Button>
</div>
<AnimatePresence initial={false}>
{showInputForm && buildInputSchema(output.agent.inputs) && (
<motion.div
initial={{ height: 0, opacity: 0, filter: "blur(6px)" }}
animate={{ height: "auto", opacity: 1, filter: "blur(0px)" }}
exit={{ height: 0, opacity: 0, filter: "blur(6px)" }}
transition={{
height: { type: "spring", bounce: 0.15, duration: 0.5 },
opacity: { duration: 0.25 },
filter: { duration: 0.2 },
}}
className="overflow-hidden"
style={{ willChange: "height, opacity, filter" }}
>
<div className="mt-4 rounded-2xl border bg-background p-3 pt-4">
<Text variant="body-medium">Enter your inputs</Text>
<FormRenderer
jsonSchema={buildInputSchema(output.agent.inputs)!}
handleChange={(v) => setInputValues(v.formData ?? {})}
uiSchema={{
"ui:submitButtonOptions": { norender: true },
}}
initialValues={inputValues}
formContext={{
showHandles: false,
size: "small",
}}
/>
<div className="-mt-8 flex gap-2">
<Button
variant="primary"
size="small"
className="w-fit"
onClick={handleRunWithInputs}
>
Run
</Button>
<Button
variant="secondary"
size="small"
className="w-fit"
onClick={() => {
setShowInputForm(false);
setInputValues({});
}}
>
Cancel
</Button>
</div>
</div>
</motion.div>
)}
</AnimatePresence>
</div>
);
}

Some files were not shown because too many files have changed in this diff Show More