mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-07 05:15:09 -05:00
Compare commits
6 Commits
dev
...
feat/copit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5efb80d47b | ||
|
|
b49d8e2cba | ||
|
|
452544530d | ||
|
|
32ee7e6cf8 | ||
|
|
670663c406 | ||
|
|
0dbe4cf51e |
16
.github/workflows/platform-frontend-ci.yml
vendored
16
.github/workflows/platform-frontend-ci.yml
vendored
@@ -27,20 +27,11 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
outputs:
|
||||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||||
components-changed: ${{ steps.filter.outputs.components }}
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Check for component changes
|
|
||||||
uses: dorny/paths-filter@v3
|
|
||||||
id: filter
|
|
||||||
with:
|
|
||||||
filters: |
|
|
||||||
components:
|
|
||||||
- 'autogpt_platform/frontend/src/components/**'
|
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
@@ -99,11 +90,8 @@ jobs:
|
|||||||
chromatic:
|
chromatic:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: setup
|
needs: setup
|
||||||
# Disabled: to re-enable, remove 'false &&' from the condition below
|
# Only run on dev branch pushes or PRs targeting dev
|
||||||
if: >-
|
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
||||||
false
|
|
||||||
&& (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
|
|
||||||
&& needs.setup.outputs.components-changed == 'true'
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
|
|||||||
1320
autogpt_platform/autogpt_libs/poetry.lock
generated
1320
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -11,15 +11,15 @@ python = ">=3.10,<4.0"
|
|||||||
colorama = "^0.4.6"
|
colorama = "^0.4.6"
|
||||||
cryptography = "^45.0"
|
cryptography = "^45.0"
|
||||||
expiringdict = "^1.2.2"
|
expiringdict = "^1.2.2"
|
||||||
fastapi = "^0.128.0"
|
fastapi = "^0.116.1"
|
||||||
google-cloud-logging = "^3.13.0"
|
google-cloud-logging = "^3.12.1"
|
||||||
launchdarkly-server-sdk = "^9.14.1"
|
launchdarkly-server-sdk = "^9.12.0"
|
||||||
pydantic = "^2.12.5"
|
pydantic = "^2.11.7"
|
||||||
pydantic-settings = "^2.12.0"
|
pydantic-settings = "^2.10.1"
|
||||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
supabase = "^2.27.2"
|
supabase = "^2.16.0"
|
||||||
uvicorn = "^0.40.0"
|
uvicorn = "^0.35.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pyright = "^1.1.404"
|
pyright = "^1.1.404"
|
||||||
|
|||||||
@@ -27,12 +27,20 @@ class ChatConfig(BaseSettings):
|
|||||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||||
|
|
||||||
# Streaming Configuration
|
# Streaming Configuration
|
||||||
|
# Note: When using Claude Agent SDK, context management is handled automatically
|
||||||
|
# via the SDK's built-in compaction. This is mainly used for the fallback path.
|
||||||
max_context_messages: int = Field(
|
max_context_messages: int = Field(
|
||||||
default=50, ge=1, le=200, description="Maximum context messages"
|
default=100,
|
||||||
|
ge=1,
|
||||||
|
le=500,
|
||||||
|
description="Max context messages (SDK handles compaction automatically)",
|
||||||
)
|
)
|
||||||
|
|
||||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
max_retries: int = Field(
|
||||||
|
default=3,
|
||||||
|
description="Max retries for fallback path (SDK handles retries internally)",
|
||||||
|
)
|
||||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||||
max_agent_schedules: int = Field(
|
max_agent_schedules: int = Field(
|
||||||
default=30, description="Maximum number of agent schedules"
|
default=30, description="Maximum number of agent schedules"
|
||||||
@@ -93,6 +101,12 @@ class ChatConfig(BaseSettings):
|
|||||||
description="Name of the prompt in Langfuse to fetch",
|
description="Name of the prompt in Langfuse to fetch",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Claude Agent SDK Configuration
|
||||||
|
use_claude_agent_sdk: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Use Claude Agent SDK for chat completions",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("api_key", mode="before")
|
@field_validator("api_key", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_api_key(cls, v):
|
def get_api_key(cls, v):
|
||||||
@@ -132,6 +146,17 @@ class ChatConfig(BaseSettings):
|
|||||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("use_claude_agent_sdk", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def get_use_claude_agent_sdk(cls, v):
|
||||||
|
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||||
|
# Check environment variable - default to True if not set
|
||||||
|
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||||
|
if env_val:
|
||||||
|
return env_val in ("true", "1", "yes", "on")
|
||||||
|
# Default to True (SDK enabled by default)
|
||||||
|
return True if v is None else v
|
||||||
|
|
||||||
# Prompt paths for different contexts
|
# Prompt paths for different contexts
|
||||||
PROMPT_PATHS: dict[str, str] = {
|
PROMPT_PATHS: dict[str, str] = {
|
||||||
"default": "prompts/chat_system.md",
|
"default": "prompts/chat_system.md",
|
||||||
|
|||||||
@@ -273,9 +273,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
|||||||
try:
|
try:
|
||||||
session = ChatSession.model_validate_json(raw_session)
|
session = ChatSession.model_validate_json(raw_session)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loading session {session_id} from cache: "
|
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
|
||||||
f"message_count={len(session.messages)}, "
|
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
|
||||||
f"roles={[m.role for m in session.messages]}"
|
|
||||||
)
|
)
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -317,11 +316,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
messages = prisma_session.Messages
|
messages = prisma_session.Messages
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"Loading session {session_id} from DB: "
|
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
|
||||||
f"has_messages={messages is not None}, "
|
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
|
||||||
f"message_count={len(messages) if messages else 0}, "
|
|
||||||
f"roles={[m.role for m in messages] if messages else []}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatSession.from_db(prisma_session, messages)
|
return ChatSession.from_db(prisma_session, messages)
|
||||||
@@ -372,10 +369,9 @@ async def _save_session_to_db(
|
|||||||
"function_call": msg.function_call,
|
"function_call": msg.function_call,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
|
||||||
f"roles={[m['role'] for m in messages_data]}, "
|
f"roles={[m['role'] for m in messages_data]}"
|
||||||
f"start_sequence={existing_message_count}"
|
|
||||||
)
|
)
|
||||||
await chat_db.add_chat_messages_batch(
|
await chat_db.add_chat_messages_batch(
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
@@ -415,7 +411,7 @@ async def get_chat_session(
|
|||||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||||
|
|
||||||
# Fall back to database
|
# Fall back to database
|
||||||
logger.info(f"Session {session_id} not in cache, checking database")
|
logger.debug(f"Session {session_id} not in cache, checking database")
|
||||||
session = await _get_session_from_db(session_id)
|
session = await _get_session_from_db(session_id)
|
||||||
|
|
||||||
if session is None:
|
if session is None:
|
||||||
@@ -432,7 +428,6 @@ async def get_chat_session(
|
|||||||
# Cache the session from DB
|
# Cache the session from DB
|
||||||
try:
|
try:
|
||||||
await _cache_session(session)
|
await _cache_session(session)
|
||||||
logger.info(f"Cached session {session_id} from database")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||||
|
|
||||||
@@ -603,13 +598,19 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
|||||||
logger.warning(f"Session {session_id} not found for title update")
|
logger.warning(f"Session {session_id} not found for title update")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Invalidate cache so next fetch gets updated title
|
# Update title in cache if it exists (instead of invalidating).
|
||||||
|
# This prevents race conditions where cache invalidation causes
|
||||||
|
# the frontend to see stale DB data while streaming is still in progress.
|
||||||
try:
|
try:
|
||||||
redis_key = _get_session_cache_key(session_id)
|
cached = await _get_session_from_cache(session_id)
|
||||||
async_redis = await get_redis_async()
|
if cached:
|
||||||
await async_redis.delete(redis_key)
|
cached.title = title
|
||||||
|
await _cache_session(cached)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
# Not critical - title will be correct on next full cache refresh
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to update title in cache for session {session_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Chat API routes for chat session management and streaming via SSE."""
|
"""Chat API routes for chat session management and streaming via SSE."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
@@ -16,8 +17,17 @@ from . import service as chat_service
|
|||||||
from . import stream_registry
|
from . import stream_registry
|
||||||
from .completion_handler import process_operation_failure, process_operation_success
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
from .model import (
|
||||||
|
ChatMessage,
|
||||||
|
ChatSession,
|
||||||
|
create_chat_session,
|
||||||
|
get_chat_session,
|
||||||
|
get_user_sessions,
|
||||||
|
upsert_chat_session,
|
||||||
|
)
|
||||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||||
|
from .sdk import service as sdk_service
|
||||||
|
from .tracking import track_user_message
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -209,6 +219,10 @@ async def get_session(
|
|||||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||||
session_id, user_id
|
session_id, user_id
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
||||||
|
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||||
|
)
|
||||||
if active_task:
|
if active_task:
|
||||||
# Filter out the in-progress assistant message from the session response.
|
# Filter out the in-progress assistant message from the session response.
|
||||||
# The client will receive the complete assistant response through the SSE
|
# The client will receive the complete assistant response through the SSE
|
||||||
@@ -265,10 +279,30 @@ async def stream_chat_post(
|
|||||||
containing the task_id for reconnection.
|
containing the task_id for reconnection.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
|
# Add user message to session BEFORE creating task to avoid race condition
|
||||||
|
# where GET_SESSION sees the task as "running" but the message isn't saved yet
|
||||||
|
if request.message:
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="user" if request.is_user_message else "assistant",
|
||||||
|
content=request.message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if request.is_user_message:
|
||||||
|
track_user_message(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
message_length=len(request.message),
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[STREAM] Saving user message to session {session_id}, "
|
||||||
|
f"msg_count={len(session.messages)}"
|
||||||
|
)
|
||||||
|
session = await upsert_chat_session(session)
|
||||||
|
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||||
|
|
||||||
# Create a task in the stream registry for reconnection support
|
# Create a task in the stream registry for reconnection support
|
||||||
task_id = str(uuid_module.uuid4())
|
task_id = str(uuid_module.uuid4())
|
||||||
operation_id = str(uuid_module.uuid4())
|
operation_id = str(uuid_module.uuid4())
|
||||||
@@ -283,24 +317,38 @@ async def stream_chat_post(
|
|||||||
|
|
||||||
# Background task that runs the AI generation independently of SSE connection
|
# Background task that runs the AI generation independently of SSE connection
|
||||||
async def run_ai_generation():
|
async def run_ai_generation():
|
||||||
|
chunk_count = 0
|
||||||
try:
|
try:
|
||||||
# Emit a start event with task_id for reconnection
|
# Emit a start event with task_id for reconnection
|
||||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||||
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
# Choose service based on configuration
|
||||||
|
use_sdk = config.use_claude_agent_sdk
|
||||||
|
stream_fn = (
|
||||||
|
sdk_service.stream_chat_completion_sdk
|
||||||
|
if use_sdk
|
||||||
|
else chat_service.stream_chat_completion
|
||||||
|
)
|
||||||
|
# Pass message=None since we already added it to the session above
|
||||||
|
async for chunk in stream_fn(
|
||||||
session_id,
|
session_id,
|
||||||
request.message,
|
None, # Message already in session
|
||||||
is_user_message=request.is_user_message,
|
is_user_message=request.is_user_message,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass session with message already added
|
||||||
context=request.context,
|
context=request.context,
|
||||||
):
|
):
|
||||||
|
chunk_count += 1
|
||||||
# Write to Redis (subscribers will receive via XREAD)
|
# Write to Redis (subscribers will receive via XREAD)
|
||||||
await stream_registry.publish_chunk(task_id, chunk)
|
await stream_registry.publish_chunk(task_id, chunk)
|
||||||
|
|
||||||
# Mark task as completed
|
logger.info(
|
||||||
await stream_registry.mark_task_completed(task_id, "completed")
|
f"[BG_TASK] AI generation completed for session {session_id}: {chunk_count} chunks, marking task {task_id} as completed"
|
||||||
|
)
|
||||||
|
# Mark task as completed (also publishes StreamFinish)
|
||||||
|
completed = await stream_registry.mark_task_completed(task_id, "completed")
|
||||||
|
logger.info(f"[BG_TASK] mark_task_completed returned: {completed}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error in background AI generation for session {session_id}: {e}"
|
f"Error in background AI generation for session {session_id}: {e}"
|
||||||
@@ -315,7 +363,7 @@ async def stream_chat_post(
|
|||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
subscriber_queue = None
|
subscriber_queue = None
|
||||||
try:
|
try:
|
||||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
# Subscribe to the task stream (replays + live updates)
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -323,6 +371,7 @@ async def stream_chat_post(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if subscriber_queue is None:
|
if subscriber_queue is None:
|
||||||
|
logger.warning(f"Failed to subscribe to task {task_id}")
|
||||||
yield StreamFinish().to_sse()
|
yield StreamFinish().to_sse()
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
return
|
return
|
||||||
@@ -341,11 +390,11 @@ async def stream_chat_post(
|
|||||||
yield StreamHeartbeat().to_sse()
|
yield StreamHeartbeat().to_sse()
|
||||||
|
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
pass # Client disconnected - background task continues
|
pass # Client disconnected - normal behavior
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||||
finally:
|
finally:
|
||||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
# Unsubscribe when client disconnects or stream ends
|
||||||
if subscriber_queue is not None:
|
if subscriber_queue is not None:
|
||||||
try:
|
try:
|
||||||
await stream_registry.unsubscribe_from_task(
|
await stream_registry.unsubscribe_from_task(
|
||||||
@@ -400,35 +449,21 @@ async def stream_chat_get(
|
|||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
chunk_count = 0
|
# Choose service based on configuration
|
||||||
first_chunk_type: str | None = None
|
use_sdk = config.use_claude_agent_sdk
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
stream_fn = (
|
||||||
|
sdk_service.stream_chat_completion_sdk
|
||||||
|
if use_sdk
|
||||||
|
else chat_service.stream_chat_completion
|
||||||
|
)
|
||||||
|
async for chunk in stream_fn(
|
||||||
session_id,
|
session_id,
|
||||||
message,
|
message,
|
||||||
is_user_message=is_user_message,
|
is_user_message=is_user_message,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
):
|
):
|
||||||
if chunk_count < 3:
|
|
||||||
logger.info(
|
|
||||||
"Chat stream chunk",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_type": str(chunk.type),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if not first_chunk_type:
|
|
||||||
first_chunk_type = str(chunk.type)
|
|
||||||
chunk_count += 1
|
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
logger.info(
|
|
||||||
"Chat stream completed",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"first_chunk_type": first_chunk_type,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# AI SDK protocol termination
|
# AI SDK protocol termination
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
@@ -550,8 +585,6 @@ async def stream_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
import asyncio
|
|
||||||
|
|
||||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""Claude Agent SDK integration for CoPilot.
|
||||||
|
|
||||||
|
This module provides the integration layer between the Claude Agent SDK
|
||||||
|
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||||
|
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .service import stream_chat_completion_sdk
|
||||||
|
from .tool_adapter import create_copilot_mcp_server
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"stream_chat_completion_sdk",
|
||||||
|
"create_copilot_mcp_server",
|
||||||
|
]
|
||||||
@@ -0,0 +1,348 @@
|
|||||||
|
"""Anthropic SDK fallback implementation.
|
||||||
|
|
||||||
|
This module provides the fallback streaming implementation using the Anthropic SDK
|
||||||
|
directly when the Claude Agent SDK is not available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from ..model import ChatMessage, ChatSession
|
||||||
|
from ..response_model import (
|
||||||
|
StreamBaseResponse,
|
||||||
|
StreamError,
|
||||||
|
StreamFinish,
|
||||||
|
StreamTextDelta,
|
||||||
|
StreamTextEnd,
|
||||||
|
StreamTextStart,
|
||||||
|
StreamToolInputAvailable,
|
||||||
|
StreamToolInputStart,
|
||||||
|
StreamToolOutputAvailable,
|
||||||
|
StreamUsage,
|
||||||
|
)
|
||||||
|
from .tool_adapter import get_tool_definitions, get_tool_handlers
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_with_anthropic(
|
||||||
|
session: ChatSession,
|
||||||
|
system_prompt: str,
|
||||||
|
text_block_id: str,
|
||||||
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
|
"""Stream using Anthropic SDK directly with tool calling support.
|
||||||
|
|
||||||
|
This function accumulates messages into the session for persistence.
|
||||||
|
The caller should NOT yield an additional StreamFinish - this function handles it.
|
||||||
|
"""
|
||||||
|
import anthropic
|
||||||
|
|
||||||
|
# Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
|
||||||
|
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
yield StreamError(
|
||||||
|
errorText="ANTHROPIC_API_KEY not configured for fallback",
|
||||||
|
code="config_error",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
return
|
||||||
|
|
||||||
|
client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||||
|
tool_definitions = get_tool_definitions()
|
||||||
|
tool_handlers = get_tool_handlers()
|
||||||
|
|
||||||
|
anthropic_tools = [
|
||||||
|
{
|
||||||
|
"name": t["name"],
|
||||||
|
"description": t["description"],
|
||||||
|
"input_schema": t["inputSchema"],
|
||||||
|
}
|
||||||
|
for t in tool_definitions
|
||||||
|
]
|
||||||
|
|
||||||
|
anthropic_messages = _convert_session_to_anthropic(session)
|
||||||
|
|
||||||
|
if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
|
||||||
|
anthropic_messages.append(
|
||||||
|
{"role": "user", "content": "Continue with the task."}
|
||||||
|
)
|
||||||
|
|
||||||
|
has_started_text = False
|
||||||
|
max_iterations = 10
|
||||||
|
accumulated_text = ""
|
||||||
|
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for _ in range(max_iterations):
|
||||||
|
try:
|
||||||
|
async with client.messages.stream(
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
max_tokens=4096,
|
||||||
|
system=system_prompt,
|
||||||
|
messages=cast(Any, anthropic_messages),
|
||||||
|
tools=cast(Any, anthropic_tools) if anthropic_tools else [],
|
||||||
|
) as stream:
|
||||||
|
async for event in stream:
|
||||||
|
if event.type == "content_block_start":
|
||||||
|
block = event.content_block
|
||||||
|
if hasattr(block, "type"):
|
||||||
|
if block.type == "text" and not has_started_text:
|
||||||
|
yield StreamTextStart(id=text_block_id)
|
||||||
|
has_started_text = True
|
||||||
|
elif block.type == "tool_use":
|
||||||
|
yield StreamToolInputStart(
|
||||||
|
toolCallId=block.id, toolName=block.name
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event.type == "content_block_delta":
|
||||||
|
delta = event.delta
|
||||||
|
if hasattr(delta, "type") and delta.type == "text_delta":
|
||||||
|
accumulated_text += delta.text
|
||||||
|
yield StreamTextDelta(id=text_block_id, delta=delta.text)
|
||||||
|
|
||||||
|
final_message = await stream.get_final_message()
|
||||||
|
|
||||||
|
if final_message.stop_reason == "tool_use":
|
||||||
|
if has_started_text:
|
||||||
|
yield StreamTextEnd(id=text_block_id)
|
||||||
|
has_started_text = False
|
||||||
|
text_block_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
tool_results = []
|
||||||
|
assistant_content: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for block in final_message.content:
|
||||||
|
if block.type == "text":
|
||||||
|
assistant_content.append(
|
||||||
|
{"type": "text", "text": block.text}
|
||||||
|
)
|
||||||
|
elif block.type == "tool_use":
|
||||||
|
assistant_content.append(
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": block.id,
|
||||||
|
"name": block.name,
|
||||||
|
"input": block.input,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track tool call for session persistence
|
||||||
|
accumulated_tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": block.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": block.name,
|
||||||
|
"arguments": json.dumps(
|
||||||
|
block.input
|
||||||
|
if isinstance(block.input, dict)
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
yield StreamToolInputAvailable(
|
||||||
|
toolCallId=block.id,
|
||||||
|
toolName=block.name,
|
||||||
|
input=(
|
||||||
|
block.input if isinstance(block.input, dict) else {}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
output, is_error = await _execute_tool(
|
||||||
|
block.name, block.input, tool_handlers
|
||||||
|
)
|
||||||
|
|
||||||
|
yield StreamToolOutputAvailable(
|
||||||
|
toolCallId=block.id,
|
||||||
|
toolName=block.name,
|
||||||
|
output=output,
|
||||||
|
success=not is_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save tool result to session
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="tool",
|
||||||
|
content=output,
|
||||||
|
tool_call_id=block.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_results.append(
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": block.id,
|
||||||
|
"content": output,
|
||||||
|
"is_error": is_error,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save assistant message with tool calls to session
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=accumulated_text or None,
|
||||||
|
tool_calls=(
|
||||||
|
accumulated_tool_calls
|
||||||
|
if accumulated_tool_calls
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Reset for next iteration
|
||||||
|
accumulated_text = ""
|
||||||
|
accumulated_tool_calls = []
|
||||||
|
|
||||||
|
anthropic_messages.append(
|
||||||
|
{"role": "assistant", "content": assistant_content}
|
||||||
|
)
|
||||||
|
anthropic_messages.append({"role": "user", "content": tool_results})
|
||||||
|
continue
|
||||||
|
|
||||||
|
else:
|
||||||
|
if has_started_text:
|
||||||
|
yield StreamTextEnd(id=text_block_id)
|
||||||
|
|
||||||
|
# Save final assistant response to session
|
||||||
|
if accumulated_text:
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(role="assistant", content=accumulated_text)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield StreamUsage(
|
||||||
|
promptTokens=final_message.usage.input_tokens,
|
||||||
|
completionTokens=final_message.usage.output_tokens,
|
||||||
|
totalTokens=final_message.usage.input_tokens
|
||||||
|
+ final_message.usage.output_tokens,
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
|
||||||
|
yield StreamError(
|
||||||
|
errorText="An error occurred. Please try again.",
|
||||||
|
code="anthropic_error",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
return
|
||||||
|
|
||||||
|
yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
|
||||||
|
yield StreamFinish()
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
|
||||||
|
"""Convert session messages to Anthropic format.
|
||||||
|
|
||||||
|
Handles merging consecutive same-role messages (Anthropic requires alternating roles).
|
||||||
|
"""
|
||||||
|
messages: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for msg in session.messages:
|
||||||
|
if msg.role == "user":
|
||||||
|
new_msg = {"role": "user", "content": msg.content or ""}
|
||||||
|
elif msg.role == "assistant":
|
||||||
|
content: list[dict[str, Any]] = []
|
||||||
|
if msg.content:
|
||||||
|
content.append({"type": "text", "text": msg.content})
|
||||||
|
if msg.tool_calls:
|
||||||
|
for tc in msg.tool_calls:
|
||||||
|
func = tc.get("function", {})
|
||||||
|
args = func.get("arguments", {})
|
||||||
|
if isinstance(args, str):
|
||||||
|
try:
|
||||||
|
args = json.loads(args)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
args = {}
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": tc.get("id", str(uuid.uuid4())),
|
||||||
|
"name": func.get("name", ""),
|
||||||
|
"input": args,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if content:
|
||||||
|
new_msg = {"role": "assistant", "content": content}
|
||||||
|
else:
|
||||||
|
continue # Skip empty assistant messages
|
||||||
|
elif msg.role == "tool":
|
||||||
|
new_msg = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": msg.tool_call_id or "",
|
||||||
|
"content": msg.content or "",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
messages.append(new_msg)
|
||||||
|
|
||||||
|
# Merge consecutive same-role messages (Anthropic requires alternating roles)
|
||||||
|
return _merge_consecutive_roles(messages)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Merge consecutive messages with the same role.
|
||||||
|
|
||||||
|
Anthropic API requires alternating user/assistant roles.
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
merged: list[dict[str, Any]] = []
|
||||||
|
for msg in messages:
|
||||||
|
if merged and merged[-1]["role"] == msg["role"]:
|
||||||
|
# Merge with previous message
|
||||||
|
prev_content = merged[-1]["content"]
|
||||||
|
new_content = msg["content"]
|
||||||
|
|
||||||
|
# Normalize both to list-of-blocks form
|
||||||
|
if isinstance(prev_content, str):
|
||||||
|
prev_content = [{"type": "text", "text": prev_content}]
|
||||||
|
if isinstance(new_content, str):
|
||||||
|
new_content = [{"type": "text", "text": new_content}]
|
||||||
|
|
||||||
|
# Ensure both are lists
|
||||||
|
if not isinstance(prev_content, list):
|
||||||
|
prev_content = [prev_content]
|
||||||
|
if not isinstance(new_content, list):
|
||||||
|
new_content = [new_content]
|
||||||
|
|
||||||
|
merged[-1]["content"] = prev_content + new_content
|
||||||
|
else:
|
||||||
|
merged.append(msg)
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_tool(
|
||||||
|
tool_name: str, tool_input: Any, handlers: dict[str, Any]
|
||||||
|
) -> tuple[str, bool]:
|
||||||
|
"""Execute a tool and return (output, is_error)."""
|
||||||
|
handler = handlers.get(tool_name)
|
||||||
|
if not handler:
|
||||||
|
return f"Unknown tool: {tool_name}", True
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await handler(tool_input)
|
||||||
|
# Safely extract output - handle empty or missing content
|
||||||
|
content = result.get("content") or []
|
||||||
|
if content and isinstance(content, list) and len(content) > 0:
|
||||||
|
first_item = content[0]
|
||||||
|
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
|
||||||
|
else:
|
||||||
|
output = ""
|
||||||
|
is_error = result.get("isError", False)
|
||||||
|
return output, is_error
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error: {str(e)}", True
|
||||||
@@ -0,0 +1,300 @@
|
|||||||
|
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||||
|
|
||||||
|
This module provides the adapter layer that converts streaming messages from
|
||||||
|
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
|
||||||
|
the frontend expects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
from backend.api.features.chat.response_model import (
|
||||||
|
StreamBaseResponse,
|
||||||
|
StreamError,
|
||||||
|
StreamFinish,
|
||||||
|
StreamHeartbeat,
|
||||||
|
StreamStart,
|
||||||
|
StreamTextDelta,
|
||||||
|
StreamTextEnd,
|
||||||
|
StreamTextStart,
|
||||||
|
StreamToolInputAvailable,
|
||||||
|
StreamToolInputStart,
|
||||||
|
StreamToolOutputAvailable,
|
||||||
|
StreamUsage,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SDKResponseAdapter:
|
||||||
|
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||||
|
|
||||||
|
This class maintains state during a streaming session to properly track
|
||||||
|
text blocks, tool calls, and message lifecycle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message_id: str | None = None):
|
||||||
|
"""Initialize the adapter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_id: Optional message ID. If not provided, one will be generated.
|
||||||
|
"""
|
||||||
|
self.message_id = message_id or str(uuid.uuid4())
|
||||||
|
self.text_block_id = str(uuid.uuid4())
|
||||||
|
self.has_started_text = False
|
||||||
|
self.has_ended_text = False
|
||||||
|
self.current_tool_calls: dict[str, dict[str, Any]] = {}
|
||||||
|
self.task_id: str | None = None
|
||||||
|
|
||||||
|
def set_task_id(self, task_id: str) -> None:
|
||||||
|
"""Set the task ID for reconnection support."""
|
||||||
|
self.task_id = task_id
|
||||||
|
|
||||||
|
def convert_message(self, sdk_message: Any) -> list[StreamBaseResponse]:
|
||||||
|
"""Convert a single SDK message to Vercel AI SDK format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sdk_message: A message from the Claude Agent SDK.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of StreamBaseResponse objects (may be empty or multiple).
|
||||||
|
"""
|
||||||
|
responses: list[StreamBaseResponse] = []
|
||||||
|
|
||||||
|
# Handle different SDK message types - use class name since SDK uses dataclasses
|
||||||
|
class_name = type(sdk_message).__name__
|
||||||
|
msg_subtype = getattr(sdk_message, "subtype", None)
|
||||||
|
|
||||||
|
if class_name == "SystemMessage":
|
||||||
|
if msg_subtype == "init":
|
||||||
|
# Session initialization - emit start
|
||||||
|
responses.append(
|
||||||
|
StreamStart(
|
||||||
|
messageId=self.message_id,
|
||||||
|
taskId=self.task_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif class_name == "AssistantMessage":
|
||||||
|
# Assistant message with content blocks
|
||||||
|
content = getattr(sdk_message, "content", [])
|
||||||
|
for block in content:
|
||||||
|
# Check block type by class name (SDK uses dataclasses) or dict type
|
||||||
|
block_class = type(block).__name__
|
||||||
|
block_type = block.get("type") if isinstance(block, dict) else None
|
||||||
|
|
||||||
|
if block_class == "TextBlock" or block_type == "text":
|
||||||
|
# Text content
|
||||||
|
text = getattr(block, "text", None) or (
|
||||||
|
block.get("text") if isinstance(block, dict) else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
if text:
|
||||||
|
# Start text block if needed (or restart after tool calls)
|
||||||
|
if not self.has_started_text or self.has_ended_text:
|
||||||
|
# Generate new text block ID for text after tools
|
||||||
|
if self.has_ended_text:
|
||||||
|
self.text_block_id = str(uuid.uuid4())
|
||||||
|
self.has_ended_text = False
|
||||||
|
responses.append(StreamTextStart(id=self.text_block_id))
|
||||||
|
self.has_started_text = True
|
||||||
|
|
||||||
|
# Emit text delta
|
||||||
|
responses.append(
|
||||||
|
StreamTextDelta(
|
||||||
|
id=self.text_block_id,
|
||||||
|
delta=text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif block_class == "ToolUseBlock" or block_type == "tool_use":
|
||||||
|
# Tool call
|
||||||
|
tool_id_raw = getattr(block, "id", None) or (
|
||||||
|
block.get("id") if isinstance(block, dict) else None
|
||||||
|
)
|
||||||
|
tool_id: str = (
|
||||||
|
str(tool_id_raw) if tool_id_raw else str(uuid.uuid4())
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_name_raw = getattr(block, "name", None) or (
|
||||||
|
block.get("name") if isinstance(block, dict) else None
|
||||||
|
)
|
||||||
|
tool_name: str = str(tool_name_raw) if tool_name_raw else "unknown"
|
||||||
|
|
||||||
|
tool_input = getattr(block, "input", None) or (
|
||||||
|
block.get("input") if isinstance(block, dict) else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# End text block if we were streaming text
|
||||||
|
if self.has_started_text and not self.has_ended_text:
|
||||||
|
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||||
|
self.has_ended_text = True
|
||||||
|
|
||||||
|
# Emit tool input start
|
||||||
|
responses.append(
|
||||||
|
StreamToolInputStart(
|
||||||
|
toolCallId=tool_id,
|
||||||
|
toolName=tool_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit tool input available with full input
|
||||||
|
responses.append(
|
||||||
|
StreamToolInputAvailable(
|
||||||
|
toolCallId=tool_id,
|
||||||
|
toolName=tool_name,
|
||||||
|
input=tool_input if isinstance(tool_input, dict) else {},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track the tool call
|
||||||
|
self.current_tool_calls[tool_id] = {
|
||||||
|
"name": tool_name,
|
||||||
|
"input": tool_input,
|
||||||
|
}
|
||||||
|
|
||||||
|
elif class_name in ("ToolResultMessage", "UserMessage"):
|
||||||
|
# Tool result - check for tool_result content
|
||||||
|
content = getattr(sdk_message, "content", [])
|
||||||
|
|
||||||
|
for block in content:
|
||||||
|
block_class = type(block).__name__
|
||||||
|
block_type = block.get("type") if isinstance(block, dict) else None
|
||||||
|
|
||||||
|
if block_class == "ToolResultBlock" or block_type == "tool_result":
|
||||||
|
tool_use_id = getattr(block, "tool_use_id", None) or (
|
||||||
|
block.get("tool_use_id") if isinstance(block, dict) else None
|
||||||
|
)
|
||||||
|
result_content = getattr(block, "content", None) or (
|
||||||
|
block.get("content") if isinstance(block, dict) else ""
|
||||||
|
)
|
||||||
|
is_error = getattr(block, "is_error", False) or (
|
||||||
|
block.get("is_error", False)
|
||||||
|
if isinstance(block, dict)
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool_use_id:
|
||||||
|
tool_info = self.current_tool_calls.get(tool_use_id, {})
|
||||||
|
tool_name = tool_info.get("name", "unknown")
|
||||||
|
|
||||||
|
# Format the output
|
||||||
|
if isinstance(result_content, list):
|
||||||
|
# Extract text from content blocks
|
||||||
|
output_text = ""
|
||||||
|
for item in result_content:
|
||||||
|
if (
|
||||||
|
isinstance(item, dict)
|
||||||
|
and item.get("type") == "text"
|
||||||
|
):
|
||||||
|
output_text += item.get("text", "")
|
||||||
|
elif hasattr(item, "text"):
|
||||||
|
output_text += getattr(item, "text", "")
|
||||||
|
output = output_text
|
||||||
|
elif isinstance(result_content, str):
|
||||||
|
output = result_content
|
||||||
|
else:
|
||||||
|
output = json.dumps(result_content)
|
||||||
|
|
||||||
|
responses.append(
|
||||||
|
StreamToolOutputAvailable(
|
||||||
|
toolCallId=tool_use_id,
|
||||||
|
toolName=tool_name,
|
||||||
|
output=output,
|
||||||
|
success=not is_error,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif class_name == "ResultMessage":
|
||||||
|
# Final result
|
||||||
|
if msg_subtype == "success":
|
||||||
|
# End text block if still open
|
||||||
|
if self.has_started_text and not self.has_ended_text:
|
||||||
|
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||||
|
self.has_ended_text = True
|
||||||
|
|
||||||
|
# Emit finish
|
||||||
|
responses.append(StreamFinish())
|
||||||
|
|
||||||
|
elif msg_subtype in ("error", "error_during_execution"):
|
||||||
|
error_msg = getattr(sdk_message, "error", "Unknown error")
|
||||||
|
responses.append(
|
||||||
|
StreamError(
|
||||||
|
errorText=str(error_msg),
|
||||||
|
code="sdk_error",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
responses.append(StreamFinish())
|
||||||
|
|
||||||
|
elif class_name == "ErrorMessage":
|
||||||
|
# Error message
|
||||||
|
error_msg = getattr(sdk_message, "message", None) or getattr(
|
||||||
|
sdk_message, "error", "Unknown error"
|
||||||
|
)
|
||||||
|
responses.append(
|
||||||
|
StreamError(
|
||||||
|
errorText=str(error_msg),
|
||||||
|
code="sdk_error",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
responses.append(StreamFinish())
|
||||||
|
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat:
|
||||||
|
"""Create a heartbeat response."""
|
||||||
|
return StreamHeartbeat(toolCallId=tool_call_id)
|
||||||
|
|
||||||
|
def create_usage(
|
||||||
|
self,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int,
|
||||||
|
) -> StreamUsage:
|
||||||
|
"""Create a usage statistics response."""
|
||||||
|
return StreamUsage(
|
||||||
|
promptTokens=prompt_tokens,
|
||||||
|
completionTokens=completion_tokens,
|
||||||
|
totalTokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def adapt_sdk_stream(
|
||||||
|
sdk_stream: AsyncGenerator[Any, None],
|
||||||
|
message_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
|
"""Adapt a Claude Agent SDK stream to Vercel AI SDK format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sdk_stream: The async generator from the Claude Agent SDK.
|
||||||
|
message_id: Optional message ID for the response.
|
||||||
|
task_id: Optional task ID for reconnection support.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StreamBaseResponse objects in Vercel AI SDK format.
|
||||||
|
"""
|
||||||
|
adapter = SDKResponseAdapter(message_id=message_id)
|
||||||
|
if task_id:
|
||||||
|
adapter.set_task_id(task_id)
|
||||||
|
|
||||||
|
# Emit start immediately
|
||||||
|
yield StreamStart(messageId=adapter.message_id, taskId=task_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for sdk_message in sdk_stream:
|
||||||
|
responses = adapter.convert_message(sdk_message)
|
||||||
|
for response in responses:
|
||||||
|
# Skip duplicate start messages
|
||||||
|
if isinstance(response, StreamStart):
|
||||||
|
continue
|
||||||
|
yield response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in SDK stream: {e}", exc_info=True)
|
||||||
|
yield StreamError(
|
||||||
|
errorText=f"Stream error: {str(e)}",
|
||||||
|
code="stream_error",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
@@ -0,0 +1,278 @@
|
|||||||
|
"""Security hooks for Claude Agent SDK integration.
|
||||||
|
|
||||||
|
This module provides security hooks that validate tool calls before execution,
|
||||||
|
ensuring multi-user isolation and preventing unauthorized operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Tools that are blocked entirely (CLI/system access)
|
||||||
|
BLOCKED_TOOLS = {
|
||||||
|
"Bash",
|
||||||
|
"bash",
|
||||||
|
"shell",
|
||||||
|
"exec",
|
||||||
|
"terminal",
|
||||||
|
"command",
|
||||||
|
"Read", # Block raw file read - use workspace tools instead
|
||||||
|
"Write", # Block raw file write - use workspace tools instead
|
||||||
|
"Edit", # Block raw file edit - use workspace tools instead
|
||||||
|
"Glob", # Block raw file glob - use workspace tools instead
|
||||||
|
"Grep", # Block raw file grep - use workspace tools instead
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dangerous patterns in tool inputs
|
||||||
|
DANGEROUS_PATTERNS = [
|
||||||
|
r"sudo",
|
||||||
|
r"rm\s+-rf",
|
||||||
|
r"dd\s+if=",
|
||||||
|
r"/etc/passwd",
|
||||||
|
r"/etc/shadow",
|
||||||
|
r"chmod\s+777",
|
||||||
|
r"curl\s+.*\|.*sh",
|
||||||
|
r"wget\s+.*\|.*sh",
|
||||||
|
r"eval\s*\(",
|
||||||
|
r"exec\s*\(",
|
||||||
|
r"__import__",
|
||||||
|
r"os\.system",
|
||||||
|
r"subprocess",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_tool_access(tool_name: str, tool_input: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Validate that a tool call is allowed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Empty dict to allow, or dict with hookSpecificOutput to deny
|
||||||
|
"""
|
||||||
|
# Block forbidden tools
|
||||||
|
if tool_name in BLOCKED_TOOLS:
|
||||||
|
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||||
|
return {
|
||||||
|
"hookSpecificOutput": {
|
||||||
|
"hookEventName": "PreToolUse",
|
||||||
|
"permissionDecision": "deny",
|
||||||
|
"permissionDecisionReason": (
|
||||||
|
f"Tool '{tool_name}' is not available. "
|
||||||
|
"Use the CoPilot-specific tools instead."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check for dangerous patterns in tool input
|
||||||
|
input_str = str(tool_input)
|
||||||
|
|
||||||
|
for pattern in DANGEROUS_PATTERNS:
|
||||||
|
if re.search(pattern, input_str, re.IGNORECASE):
|
||||||
|
logger.warning(
|
||||||
|
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"hookSpecificOutput": {
|
||||||
|
"hookEventName": "PreToolUse",
|
||||||
|
"permissionDecision": "deny",
|
||||||
|
"permissionDecisionReason": "Input contains blocked pattern",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_user_isolation(
|
||||||
|
tool_name: str, tool_input: dict[str, Any], user_id: str | None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Validate that tool calls respect user isolation."""
|
||||||
|
# For workspace file tools, ensure path doesn't escape
|
||||||
|
if "workspace" in tool_name.lower():
|
||||||
|
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||||
|
if path:
|
||||||
|
# Check for path traversal
|
||||||
|
if ".." in path or path.startswith("/"):
|
||||||
|
logger.warning(
|
||||||
|
f"Blocked path traversal attempt: {path} by user {user_id}"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"hookSpecificOutput": {
|
||||||
|
"hookEventName": "PreToolUse",
|
||||||
|
"permissionDecision": "deny",
|
||||||
|
"permissionDecisionReason": "Path traversal not allowed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def create_security_hooks(user_id: str | None) -> dict[str, Any]:
|
||||||
|
"""Create the security hooks configuration for Claude Agent SDK.
|
||||||
|
|
||||||
|
Includes security validation and observability hooks:
|
||||||
|
- PreToolUse: Security validation before tool execution
|
||||||
|
- PostToolUse: Log successful tool executions
|
||||||
|
- PostToolUseFailure: Log and handle failed tool executions
|
||||||
|
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Current user ID for isolation validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hooks configuration dict for ClaudeAgentOptions
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from claude_agent_sdk import HookMatcher
|
||||||
|
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||||
|
|
||||||
|
async def pre_tool_use_hook(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Combined pre-tool-use validation hook."""
|
||||||
|
_ = context # unused but required by signature
|
||||||
|
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||||
|
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||||
|
|
||||||
|
# Validate basic tool access
|
||||||
|
result = _validate_tool_access(tool_name, tool_input)
|
||||||
|
if result:
|
||||||
|
return cast(SyncHookJSONOutput, result)
|
||||||
|
|
||||||
|
# Validate user isolation
|
||||||
|
result = _validate_user_isolation(tool_name, tool_input, user_id)
|
||||||
|
if result:
|
||||||
|
return cast(SyncHookJSONOutput, result)
|
||||||
|
|
||||||
|
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
async def post_tool_use_hook(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Log successful tool executions for observability."""
|
||||||
|
_ = context
|
||||||
|
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||||
|
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
async def post_tool_failure_hook(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Log failed tool executions for debugging."""
|
||||||
|
_ = context
|
||||||
|
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||||
|
error = input_data.get("error", "Unknown error")
|
||||||
|
logger.warning(
|
||||||
|
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||||
|
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||||
|
)
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
async def pre_compact_hook(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Log when SDK triggers context compaction.
|
||||||
|
|
||||||
|
The SDK automatically compacts conversation history when it grows too large.
|
||||||
|
This hook provides visibility into when compaction happens.
|
||||||
|
"""
|
||||||
|
_ = context, tool_use_id
|
||||||
|
trigger = input_data.get("trigger", "auto")
|
||||||
|
logger.info(
|
||||||
|
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||||
|
)
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||||
|
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||||
|
"PostToolUseFailure": [
|
||||||
|
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
||||||
|
],
|
||||||
|
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||||
|
}
|
||||||
|
except ImportError:
|
||||||
|
# Fallback for when SDK isn't available - return empty hooks
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def create_strict_security_hooks(
|
||||||
|
user_id: str | None,
|
||||||
|
allowed_tools: list[str] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create strict security hooks that only allow specific tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Current user ID
|
||||||
|
allowed_tools: List of allowed tool names (defaults to CoPilot tools)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hooks configuration dict
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from claude_agent_sdk import HookMatcher
|
||||||
|
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||||
|
|
||||||
|
from .tool_adapter import RAW_TOOL_NAMES
|
||||||
|
|
||||||
|
tools_list = allowed_tools if allowed_tools is not None else RAW_TOOL_NAMES
|
||||||
|
allowed_set = set(tools_list)
|
||||||
|
|
||||||
|
async def strict_pre_tool_use(
|
||||||
|
input_data: HookInput,
|
||||||
|
tool_use_id: str | None,
|
||||||
|
context: HookContext,
|
||||||
|
) -> SyncHookJSONOutput:
|
||||||
|
"""Strict validation that only allows whitelisted tools."""
|
||||||
|
_ = context # unused but required by signature
|
||||||
|
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||||
|
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||||
|
|
||||||
|
# Remove MCP prefix if present
|
||||||
|
clean_name = tool_name.removeprefix("mcp__copilot__")
|
||||||
|
|
||||||
|
if clean_name not in allowed_set:
|
||||||
|
logger.warning(f"Blocked non-whitelisted tool: {tool_name}")
|
||||||
|
return cast(
|
||||||
|
SyncHookJSONOutput,
|
||||||
|
{
|
||||||
|
"hookSpecificOutput": {
|
||||||
|
"hookEventName": "PreToolUse",
|
||||||
|
"permissionDecision": "deny",
|
||||||
|
"permissionDecisionReason": (
|
||||||
|
f"Tool '{tool_name}' is not in the allowed list"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run standard validations
|
||||||
|
result = _validate_tool_access(tool_name, tool_input)
|
||||||
|
if result:
|
||||||
|
return cast(SyncHookJSONOutput, result)
|
||||||
|
|
||||||
|
result = _validate_user_isolation(tool_name, tool_input, user_id)
|
||||||
|
if result:
|
||||||
|
return cast(SyncHookJSONOutput, result)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[SDK Audit] Tool call: tool={tool_name}, "
|
||||||
|
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||||
|
)
|
||||||
|
return cast(SyncHookJSONOutput, {})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"PreToolUse": [
|
||||||
|
HookMatcher(matcher="*", hooks=[strict_pre_tool_use]),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
except ImportError:
|
||||||
|
return {}
|
||||||
@@ -0,0 +1,471 @@
|
|||||||
|
"""Claude Agent SDK service layer for CoPilot chat completions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from backend.data.understanding import (
|
||||||
|
format_understanding_for_prompt,
|
||||||
|
get_business_understanding,
|
||||||
|
)
|
||||||
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
|
from ..config import ChatConfig
|
||||||
|
from ..model import (
|
||||||
|
ChatMessage,
|
||||||
|
ChatSession,
|
||||||
|
get_chat_session,
|
||||||
|
update_session_title,
|
||||||
|
upsert_chat_session,
|
||||||
|
)
|
||||||
|
from ..response_model import (
|
||||||
|
StreamBaseResponse,
|
||||||
|
StreamError,
|
||||||
|
StreamFinish,
|
||||||
|
StreamStart,
|
||||||
|
StreamTextDelta,
|
||||||
|
StreamToolInputAvailable,
|
||||||
|
StreamToolOutputAvailable,
|
||||||
|
)
|
||||||
|
from ..tracking import track_user_message
|
||||||
|
from .anthropic_fallback import stream_with_anthropic
|
||||||
|
from .response_adapter import SDKResponseAdapter
|
||||||
|
from .security_hooks import create_security_hooks
|
||||||
|
from .tool_adapter import (
|
||||||
|
COPILOT_TOOL_NAMES,
|
||||||
|
create_copilot_mcp_server,
|
||||||
|
set_execution_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
config = ChatConfig()
|
||||||
|
|
||||||
|
# Set to hold background tasks to prevent garbage collection
|
||||||
|
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||||
|
|
||||||
|
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
|
||||||
|
|
||||||
|
Here is everything you know about the current user from previous interactions:
|
||||||
|
|
||||||
|
<users_information>
|
||||||
|
{users_information}
|
||||||
|
</users_information>
|
||||||
|
|
||||||
|
## YOUR CORE MANDATE
|
||||||
|
|
||||||
|
You are action-oriented. Your success is measured by:
|
||||||
|
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
|
||||||
|
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
|
||||||
|
- **Time Saved**: Focus on tangible efficiency gains
|
||||||
|
- **Quality Output**: Deliver results that meet or exceed expectations
|
||||||
|
|
||||||
|
## YOUR WORKFLOW
|
||||||
|
|
||||||
|
Adapt flexibly to the conversation context. Not every interaction requires all stages:
|
||||||
|
|
||||||
|
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
|
||||||
|
|
||||||
|
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
|
||||||
|
|
||||||
|
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
|
||||||
|
|
||||||
|
4. **Discover or Create Agents**:
|
||||||
|
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
|
||||||
|
- Search the marketplace with `find_agent` for pre-built automations
|
||||||
|
- Find reusable components with `find_block`
|
||||||
|
- Create custom solutions with `create_agent` if nothing suitable exists
|
||||||
|
- Modify existing library agents with `edit_agent`
|
||||||
|
|
||||||
|
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
|
||||||
|
|
||||||
|
6. **Show Results**: Display outputs using `agent_output`.
|
||||||
|
|
||||||
|
## BEHAVIORAL GUIDELINES
|
||||||
|
|
||||||
|
**Be Concise:**
|
||||||
|
- Target 2-5 short lines maximum
|
||||||
|
- Make every word count—no repetition or filler
|
||||||
|
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
|
||||||
|
- Avoid jargon (blocks, slugs, cron) unless the user asks
|
||||||
|
|
||||||
|
**Be Proactive:**
|
||||||
|
- Suggest next steps before being asked
|
||||||
|
- Anticipate needs based on conversation context and user information
|
||||||
|
- Look for opportunities to expand scope when relevant
|
||||||
|
- Reveal capabilities through action, not explanation
|
||||||
|
|
||||||
|
**Use Tools Effectively:**
|
||||||
|
- Select the right tool for each task
|
||||||
|
- **Always check `find_library_agent` before searching the marketplace**
|
||||||
|
- Use `add_understanding` to capture valuable business context
|
||||||
|
- When tool calls fail, try alternative approaches
|
||||||
|
|
||||||
|
## CRITICAL REMINDER
|
||||||
|
|
||||||
|
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_system_prompt(
|
||||||
|
user_id: str | None, has_conversation_history: bool = False
|
||||||
|
) -> tuple[str, Any]:
|
||||||
|
"""Build the system prompt with user's business understanding context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to fetch understanding for.
|
||||||
|
has_conversation_history: Whether there's existing conversation history.
|
||||||
|
If True, we don't tell the model to greet/introduce (since they're
|
||||||
|
already in a conversation).
|
||||||
|
"""
|
||||||
|
understanding = None
|
||||||
|
if user_id:
|
||||||
|
try:
|
||||||
|
understanding = await get_business_understanding(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch business understanding: {e}")
|
||||||
|
|
||||||
|
if understanding:
|
||||||
|
context = format_understanding_for_prompt(understanding)
|
||||||
|
elif has_conversation_history:
|
||||||
|
# Don't tell model to greet if there's conversation history
|
||||||
|
context = "No prior understanding saved yet. Continue the existing conversation naturally."
|
||||||
|
else:
|
||||||
|
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||||
|
|
||||||
|
return DEFAULT_SYSTEM_PROMPT.format(users_information=context), understanding
|
||||||
|
|
||||||
|
|
||||||
|
def _format_conversation_history(session: ChatSession) -> str:
|
||||||
|
"""Format conversation history as a prompt context.
|
||||||
|
|
||||||
|
The SDK handles context compaction automatically, but we apply
|
||||||
|
max_context_messages as a safety guard to limit initial prompt size.
|
||||||
|
"""
|
||||||
|
if not session.messages:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Get all messages except the last user message (which will be the prompt)
|
||||||
|
messages = session.messages[:-1] if session.messages else []
|
||||||
|
if not messages:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Apply max_context_messages limit as a safety guard
|
||||||
|
# (SDK handles compaction, but this prevents excessively large initial prompts)
|
||||||
|
max_messages = config.max_context_messages
|
||||||
|
if len(messages) > max_messages:
|
||||||
|
messages = messages[-max_messages:]
|
||||||
|
|
||||||
|
history_parts = ["<conversation_history>"]
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if msg.role == "user":
|
||||||
|
history_parts.append(f"User: {msg.content or ''}")
|
||||||
|
elif msg.role == "assistant":
|
||||||
|
# Pass full content - SDK handles compaction automatically
|
||||||
|
history_parts.append(f"Assistant: {msg.content or ''}")
|
||||||
|
if msg.tool_calls:
|
||||||
|
for tc in msg.tool_calls:
|
||||||
|
func = tc.get("function", {})
|
||||||
|
history_parts.append(
|
||||||
|
f" [Called tool: {func.get('name', 'unknown')}]"
|
||||||
|
)
|
||||||
|
elif msg.role == "tool":
|
||||||
|
# Pass full tool results - SDK handles compaction
|
||||||
|
history_parts.append(f" [Tool result: {msg.content or ''}]")
|
||||||
|
|
||||||
|
history_parts.append("</conversation_history>")
|
||||||
|
history_parts.append("")
|
||||||
|
history_parts.append(
|
||||||
|
"Continue this conversation. Respond to the user's latest message:"
|
||||||
|
)
|
||||||
|
history_parts.append("")
|
||||||
|
|
||||||
|
return "\n".join(history_parts)
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_session_title(
|
||||||
|
message: str,
|
||||||
|
user_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Generate a concise title for a chat session."""
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
try:
|
||||||
|
# Build extra_body for OpenRouter tracing
|
||||||
|
extra_body: dict[str, Any] = {
|
||||||
|
"posthogProperties": {"environment": settings.config.app_env.value},
|
||||||
|
}
|
||||||
|
if user_id:
|
||||||
|
extra_body["user"] = user_id[:128]
|
||||||
|
extra_body["posthogDistinctId"] = user_id
|
||||||
|
if session_id:
|
||||||
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
|
client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=config.title_model,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Generate a very short title (3-6 words) for a chat conversation based on the user's first message. Return ONLY the title, no quotes or punctuation.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": message[:500]},
|
||||||
|
],
|
||||||
|
max_tokens=20,
|
||||||
|
extra_body=extra_body,
|
||||||
|
)
|
||||||
|
title = response.choices[0].message.content
|
||||||
|
if title:
|
||||||
|
title = title.strip().strip("\"'")
|
||||||
|
return title[:47] + "..." if len(title) > 50 else title
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to generate session title: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_chat_completion_sdk(
|
||||||
|
session_id: str,
|
||||||
|
message: str | None = None,
|
||||||
|
tool_call_response: str | None = None, # noqa: ARG001
|
||||||
|
is_user_message: bool = True,
|
||||||
|
user_id: str | None = None,
|
||||||
|
retry_count: int = 0, # noqa: ARG001
|
||||||
|
session: ChatSession | None = None,
|
||||||
|
context: dict[str, str] | None = None, # noqa: ARG001
|
||||||
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
|
"""Stream chat completion using Claude Agent SDK.
|
||||||
|
|
||||||
|
Drop-in replacement for stream_chat_completion with improved reliability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
raise NotFoundError(
|
||||||
|
f"Session {session_id} not found. Please create a new session first."
|
||||||
|
)
|
||||||
|
|
||||||
|
if message:
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="user" if is_user_message else "assistant", content=message
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if is_user_message:
|
||||||
|
track_user_message(
|
||||||
|
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await upsert_chat_session(session)
|
||||||
|
|
||||||
|
# Generate title for new sessions (first user message)
|
||||||
|
if is_user_message and not session.title:
|
||||||
|
user_messages = [m for m in session.messages if m.role == "user"]
|
||||||
|
if len(user_messages) == 1:
|
||||||
|
first_message = user_messages[0].content or message or ""
|
||||||
|
if first_message:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
_update_title_async(session_id, first_message, user_id)
|
||||||
|
)
|
||||||
|
# Store reference to prevent garbage collection
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
|
# Check if there's conversation history (more than just the current message)
|
||||||
|
has_history = len(session.messages) > 1
|
||||||
|
system_prompt, _ = await _build_system_prompt(
|
||||||
|
user_id, has_conversation_history=has_history
|
||||||
|
)
|
||||||
|
set_execution_context(user_id, session, None)
|
||||||
|
|
||||||
|
message_id = str(uuid.uuid4())
|
||||||
|
text_block_id = str(uuid.uuid4())
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||||
|
|
||||||
|
# Track whether the stream completed normally via ResultMessage
|
||||||
|
stream_completed = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||||
|
|
||||||
|
# Create MCP server with CoPilot tools
|
||||||
|
mcp_server = create_copilot_mcp_server()
|
||||||
|
|
||||||
|
options = ClaudeAgentOptions(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
|
||||||
|
allowed_tools=COPILOT_TOOL_NAMES,
|
||||||
|
hooks=create_security_hooks(user_id), # type: ignore[arg-type]
|
||||||
|
continue_conversation=True, # Enable conversation continuation
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = SDKResponseAdapter(message_id=message_id)
|
||||||
|
adapter.set_task_id(task_id)
|
||||||
|
|
||||||
|
async with ClaudeSDKClient(options=options) as client:
|
||||||
|
# Build prompt with conversation history for context
|
||||||
|
# The SDK doesn't support replaying full conversation history,
|
||||||
|
# so we include it as context in the prompt
|
||||||
|
current_message = message or ""
|
||||||
|
if not current_message and session.messages:
|
||||||
|
last_user = [m for m in session.messages if m.role == "user"]
|
||||||
|
if last_user:
|
||||||
|
current_message = last_user[-1].content or ""
|
||||||
|
|
||||||
|
# Include conversation history if there are prior messages
|
||||||
|
if len(session.messages) > 1:
|
||||||
|
history_context = _format_conversation_history(session)
|
||||||
|
prompt = f"{history_context}{current_message}"
|
||||||
|
else:
|
||||||
|
prompt = current_message
|
||||||
|
|
||||||
|
# Guard against empty prompts
|
||||||
|
if not prompt.strip():
|
||||||
|
yield StreamError(
|
||||||
|
errorText="Message cannot be empty.",
|
||||||
|
code="empty_prompt",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
return
|
||||||
|
|
||||||
|
await client.query(prompt, session_id=session_id)
|
||||||
|
|
||||||
|
# Track assistant response to save to session
|
||||||
|
# We may need multiple assistant messages if text comes after tool results
|
||||||
|
assistant_response = ChatMessage(role="assistant", content="")
|
||||||
|
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||||
|
has_appended_assistant = False
|
||||||
|
has_tool_results = False # Track if we've received tool results
|
||||||
|
|
||||||
|
# Receive messages from the SDK
|
||||||
|
async for sdk_msg in client.receive_messages():
|
||||||
|
|
||||||
|
for response in adapter.convert_message(sdk_msg):
|
||||||
|
if isinstance(response, StreamStart):
|
||||||
|
continue
|
||||||
|
yield response
|
||||||
|
|
||||||
|
# Accumulate text deltas into assistant response
|
||||||
|
if isinstance(response, StreamTextDelta):
|
||||||
|
delta = response.delta or ""
|
||||||
|
# After tool results, create new assistant message for post-tool text
|
||||||
|
if has_tool_results and has_appended_assistant:
|
||||||
|
assistant_response = ChatMessage(
|
||||||
|
role="assistant", content=delta
|
||||||
|
)
|
||||||
|
accumulated_tool_calls = [] # Reset for new message
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
has_tool_results = False
|
||||||
|
else:
|
||||||
|
assistant_response.content = (
|
||||||
|
assistant_response.content or ""
|
||||||
|
) + delta
|
||||||
|
if not has_appended_assistant:
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
has_appended_assistant = True
|
||||||
|
|
||||||
|
# Track tool calls on the assistant message
|
||||||
|
elif isinstance(response, StreamToolInputAvailable):
|
||||||
|
accumulated_tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": response.toolCallId,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": response.toolName,
|
||||||
|
"arguments": json.dumps(response.input or {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Update assistant message with tool calls
|
||||||
|
assistant_response.tool_calls = accumulated_tool_calls
|
||||||
|
# Append assistant message if not already (tool-only response)
|
||||||
|
if not has_appended_assistant:
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
has_appended_assistant = True
|
||||||
|
|
||||||
|
elif isinstance(response, StreamToolOutputAvailable):
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="tool",
|
||||||
|
content=(
|
||||||
|
response.output
|
||||||
|
if isinstance(response.output, str)
|
||||||
|
else str(response.output)
|
||||||
|
),
|
||||||
|
tool_call_id=response.toolCallId,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
has_tool_results = True
|
||||||
|
|
||||||
|
elif isinstance(response, StreamFinish):
|
||||||
|
stream_completed = True
|
||||||
|
|
||||||
|
# Break out of the message loop if we received finish signal
|
||||||
|
if stream_completed:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Ensure assistant response is saved even if no text deltas
|
||||||
|
# (e.g., only tool calls were made)
|
||||||
|
if (
|
||||||
|
assistant_response.content or assistant_response.tool_calls
|
||||||
|
) and not has_appended_assistant:
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
|
||||||
|
)
|
||||||
|
async for response in stream_with_anthropic(
|
||||||
|
session, system_prompt, text_block_id
|
||||||
|
):
|
||||||
|
yield response
|
||||||
|
|
||||||
|
# Save the session with accumulated messages
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
logger.debug(
|
||||||
|
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
||||||
|
)
|
||||||
|
# Always yield StreamFinish to signal completion to the caller
|
||||||
|
# The adapter yields StreamFinish for the SSE stream, but we need to
|
||||||
|
# yield it here so the background task in routes.py knows to call mark_task_completed
|
||||||
|
yield StreamFinish()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
||||||
|
# Save session even on error to preserve any partial response
|
||||||
|
try:
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
except Exception as save_err:
|
||||||
|
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||||
|
# Sanitize error message to avoid exposing internal details
|
||||||
|
yield StreamError(
|
||||||
|
errorText="An error occurred. Please try again.",
|
||||||
|
code="sdk_error",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_title_async(
|
||||||
|
session_id: str, message: str, user_id: str | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Background task to update session title."""
|
||||||
|
try:
|
||||||
|
title = await _generate_session_title(
|
||||||
|
message, user_id=user_id, session_id=session_id
|
||||||
|
)
|
||||||
|
if title:
|
||||||
|
await update_session_title(session_id, title)
|
||||||
|
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||||
@@ -0,0 +1,213 @@
|
|||||||
|
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
||||||
|
|
||||||
|
This module provides the adapter layer that converts existing BaseTool implementations
|
||||||
|
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tools import TOOL_REGISTRY
|
||||||
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Context variables to pass user/session info to tool execution
|
||||||
|
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||||
|
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||||
|
"current_session", default=None
|
||||||
|
)
|
||||||
|
_current_tool_call_id: ContextVar[str | None] = ContextVar(
|
||||||
|
"current_tool_call_id", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_execution_context(
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
tool_call_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set the execution context for tool calls.
|
||||||
|
|
||||||
|
This must be called before streaming begins to ensure tools have access
|
||||||
|
to user_id and session information.
|
||||||
|
"""
|
||||||
|
_current_user_id.set(user_id)
|
||||||
|
_current_session.set(session)
|
||||||
|
_current_tool_call_id.set(tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
|
||||||
|
"""Get the current execution context."""
|
||||||
|
return (
|
||||||
|
_current_user_id.get(),
|
||||||
|
_current_session.get(),
|
||||||
|
_current_tool_call_id.get(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_tool_handler(base_tool: BaseTool):
|
||||||
|
"""Create an async handler function for a BaseTool.
|
||||||
|
|
||||||
|
This wraps the existing BaseTool._execute method to be compatible
|
||||||
|
with the Claude Agent SDK MCP tool format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||||
|
user_id, session, tool_call_id = get_execution_context()
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(
|
||||||
|
{
|
||||||
|
"error": "No session context available",
|
||||||
|
"type": "error",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Call the existing tool's execute method
|
||||||
|
result = await base_tool.execute(
|
||||||
|
user_id=user_id,
|
||||||
|
session=session,
|
||||||
|
tool_call_id=tool_call_id or "sdk-call",
|
||||||
|
**args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The result is a StreamToolOutputAvailable, extract the output
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": (
|
||||||
|
result.output
|
||||||
|
if isinstance(result.output, str)
|
||||||
|
else json.dumps(result.output)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"isError": not result.success,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(
|
||||||
|
{
|
||||||
|
"error": str(e),
|
||||||
|
"type": "error",
|
||||||
|
"message": f"Failed to execute {base_tool.name}",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
return tool_handler
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_definitions() -> list[dict[str, Any]]:
|
||||||
|
"""Get all tool definitions in MCP format.
|
||||||
|
|
||||||
|
Returns a list of tool definitions that can be used with
|
||||||
|
create_sdk_mcp_server or as raw tool definitions.
|
||||||
|
"""
|
||||||
|
tool_definitions = []
|
||||||
|
|
||||||
|
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||||
|
tool_def = {
|
||||||
|
"name": tool_name,
|
||||||
|
"description": base_tool.description,
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": base_tool.parameters.get("properties", {}),
|
||||||
|
"required": base_tool.parameters.get("required", []),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tool_definitions.append(tool_def)
|
||||||
|
|
||||||
|
return tool_definitions
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_handlers() -> dict[str, Any]:
|
||||||
|
"""Get all tool handlers mapped by name.
|
||||||
|
|
||||||
|
Returns a dictionary mapping tool names to their handler functions.
|
||||||
|
"""
|
||||||
|
handlers = {}
|
||||||
|
|
||||||
|
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||||
|
handlers[tool_name] = create_tool_handler(base_tool)
|
||||||
|
|
||||||
|
return handlers
|
||||||
|
|
||||||
|
|
||||||
|
# Create the MCP server configuration
|
||||||
|
def create_copilot_mcp_server():
|
||||||
|
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||||
|
|
||||||
|
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||||
|
|
||||||
|
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||||
|
package being available. This function returns the configuration that
|
||||||
|
can be used with the SDK.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||||
|
|
||||||
|
# Create decorated tool functions
|
||||||
|
sdk_tools = []
|
||||||
|
|
||||||
|
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||||
|
# Get the handler
|
||||||
|
handler = create_tool_handler(base_tool)
|
||||||
|
|
||||||
|
# Create the decorated tool
|
||||||
|
# The @tool decorator expects (name, description, schema)
|
||||||
|
decorated = tool(
|
||||||
|
tool_name,
|
||||||
|
base_tool.description,
|
||||||
|
base_tool.parameters.get("properties", {}),
|
||||||
|
)(handler)
|
||||||
|
|
||||||
|
sdk_tools.append(decorated)
|
||||||
|
|
||||||
|
# Create the MCP server
|
||||||
|
server = create_sdk_mcp_server(
|
||||||
|
name="copilot",
|
||||||
|
version="1.0.0",
|
||||||
|
tools=sdk_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
return server
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"claude-agent-sdk not available, returning tool definitions only"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"tools": get_tool_definitions(),
|
||||||
|
"handlers": get_tool_handlers(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# List of tool names for allowed_tools configuration
|
||||||
|
COPILOT_TOOL_NAMES = [f"mcp__copilot__{name}" for name in TOOL_REGISTRY.keys()]
|
||||||
|
|
||||||
|
# Also export the raw tool names for flexibility
|
||||||
|
RAW_TOOL_NAMES = list(TOOL_REGISTRY.keys())
|
||||||
@@ -555,6 +555,10 @@ async def get_active_task_for_session(
|
|||||||
if task_user_id and user_id != task_user_id:
|
if task_user_id and user_id != task_user_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
||||||
|
)
|
||||||
|
|
||||||
# Get the last message ID from Redis Stream
|
# Get the last message ID from Redis Stream
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
last_id = "0-0"
|
last_id = "0-0"
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any
|
|||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.library import model as library_model
|
from backend.api.features.library import model as library_model
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
|
from backend.data import graph as graph_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
CredentialsFieldInfo,
|
CredentialsFieldInfo,
|
||||||
@@ -43,8 +44,14 @@ async def fetch_graph_from_store_slug(
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Get the graph from store listing version
|
# Get the graph from store listing version
|
||||||
graph = await store_db.get_available_graph(
|
graph_meta = await store_db.get_available_graph(
|
||||||
store_agent.store_listing_version_id, hide_nodes=False
|
store_agent.store_listing_version_id
|
||||||
|
)
|
||||||
|
graph = await graph_db.get_graph(
|
||||||
|
graph_id=graph_meta.id,
|
||||||
|
version=graph_meta.version,
|
||||||
|
user_id=None, # Public access
|
||||||
|
include_subgraphs=True,
|
||||||
)
|
)
|
||||||
return graph, store_agent
|
return graph, store_agent
|
||||||
|
|
||||||
@@ -121,7 +128,7 @@ def build_missing_credentials_from_graph(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
field_key: _serialize_missing_credential(field_key, field_info)
|
field_key: _serialize_missing_credential(field_key, field_info)
|
||||||
for field_key, (field_info, _, _) in aggregated_fields.items()
|
for field_key, (field_info, _node_fields) in aggregated_fields.items()
|
||||||
if field_key not in matched_keys
|
if field_key not in matched_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -262,8 +269,7 @@ async def match_user_credentials_to_graph(
|
|||||||
# provider is in the set of acceptable providers.
|
# provider is in the set of acceptable providers.
|
||||||
for credential_field_name, (
|
for credential_field_name, (
|
||||||
credential_requirements,
|
credential_requirements,
|
||||||
_,
|
_node_fields,
|
||||||
_,
|
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider, type, and scopes
|
# Find first matching credential by provider, type, and scopes
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
|
|||||||
@@ -374,7 +374,7 @@ async def get_library_agent_by_graph_id(
|
|||||||
|
|
||||||
|
|
||||||
async def add_generated_agent_image(
|
async def add_generated_agent_image(
|
||||||
graph: graph_db.GraphBaseMeta,
|
graph: graph_db.BaseGraph,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
) -> Optional[prisma.models.LibraryAgent]:
|
) -> Optional[prisma.models.LibraryAgent]:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Literal, overload
|
from typing import Any, Literal
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
@@ -11,8 +11,8 @@ import prisma.types
|
|||||||
|
|
||||||
from backend.data.db import transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.graph import (
|
from backend.data.graph import (
|
||||||
|
GraphMeta,
|
||||||
GraphModel,
|
GraphModel,
|
||||||
GraphModelWithoutNodes,
|
|
||||||
get_graph,
|
get_graph,
|
||||||
get_graph_as_admin,
|
get_graph_as_admin,
|
||||||
get_sub_graphs,
|
get_sub_graphs,
|
||||||
@@ -334,22 +334,7 @@ async def get_store_agent_details(
|
|||||||
raise DatabaseError("Failed to fetch agent details") from e
|
raise DatabaseError("Failed to fetch agent details") from e
|
||||||
|
|
||||||
|
|
||||||
@overload
|
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
||||||
async def get_available_graph(
|
|
||||||
store_listing_version_id: str, hide_nodes: Literal[False]
|
|
||||||
) -> GraphModel: ...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
|
||||||
async def get_available_graph(
|
|
||||||
store_listing_version_id: str, hide_nodes: Literal[True] = True
|
|
||||||
) -> GraphModelWithoutNodes: ...
|
|
||||||
|
|
||||||
|
|
||||||
async def get_available_graph(
|
|
||||||
store_listing_version_id: str,
|
|
||||||
hide_nodes: bool = True,
|
|
||||||
) -> GraphModelWithoutNodes | GraphModel:
|
|
||||||
try:
|
try:
|
||||||
# Get avaialble, non-deleted store listing version
|
# Get avaialble, non-deleted store listing version
|
||||||
store_listing_version = (
|
store_listing_version = (
|
||||||
@@ -359,7 +344,7 @@ async def get_available_graph(
|
|||||||
"isAvailable": True,
|
"isAvailable": True,
|
||||||
"isDeleted": False,
|
"isDeleted": False,
|
||||||
},
|
},
|
||||||
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
|
include={"AgentGraph": {"include": {"Nodes": True}}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -369,9 +354,7 @@ async def get_available_graph(
|
|||||||
detail=f"Store listing version {store_listing_version_id} not found",
|
detail=f"Store listing version {store_listing_version_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
|
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
|
||||||
store_listing_version.AgentGraph
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting agent: {e}")
|
logger.error(f"Error getting agent: {e}")
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
|
|||||||
StyleType,
|
StyleType,
|
||||||
UpscaleOption,
|
UpscaleOption,
|
||||||
)
|
)
|
||||||
from backend.data.graph import GraphBaseMeta
|
from backend.data.graph import BaseGraph
|
||||||
from backend.data.model import CredentialsMetaInput, ProviderName
|
from backend.data.model import CredentialsMetaInput, ProviderName
|
||||||
from backend.integrations.credentials_store import ideogram_credentials
|
from backend.integrations.credentials_store import ideogram_credentials
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
|
|||||||
DIGITAL_ART = "digital art"
|
DIGITAL_ART = "digital art"
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||||
if settings.config.use_agent_image_generation_v2:
|
if settings.config.use_agent_image_generation_v2:
|
||||||
return await generate_agent_image_v2(graph=agent)
|
return await generate_agent_image_v2(graph=agent)
|
||||||
else:
|
else:
|
||||||
return await generate_agent_image_v1(agent=agent)
|
return await generate_agent_image_v1(agent=agent)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Ideogram model.
|
Generate an image for an agent using Ideogram model.
|
||||||
Returns:
|
Returns:
|
||||||
@@ -54,17 +54,14 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
description = f"{name} ({graph.description})" if graph.description else name
|
description = f"{name} ({graph.description})" if graph.description else name
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
"Create a visually striking retro-futuristic vector pop art illustration "
|
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
|
||||||
f'prominently featuring "{name}" in bold typography. The image clearly and '
|
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
|
||||||
f"literally depicts a {description}, along with recognizable objects directly "
|
f"along with recognizable objects directly associated with the primary function of a {name}. "
|
||||||
f"associated with the primary function of a {name}. "
|
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
|
||||||
f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
|
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
|
||||||
f"clearly conveying the purpose of a {name}. "
|
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
|
||||||
"Maintain vibrant, limited-palette colors, sharp vector lines, "
|
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
|
||||||
"geometric shapes, flat illustration techniques, and solid colors "
|
f"prioritizing clear visual storytelling and thematic clarity above all else."
|
||||||
"without gradients or shading. Preserve a retro-futuristic aesthetic "
|
|
||||||
"influenced by mid-century futurism and 1960s psychedelia, "
|
|
||||||
"prioritizing clear visual storytelling and thematic clarity above all else."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
custom_colors = [
|
custom_colors = [
|
||||||
@@ -102,12 +99,12 @@ async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
return io.BytesIO(response.content)
|
return io.BytesIO(response.content)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Flux model via Replicate API.
|
Generate an image for an agent using Flux model via Replicate API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent (GraphBaseMeta | AgentGraph): The agent to generate an image for
|
agent (Graph): The agent to generate an image for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
io.BytesIO: The generated image as bytes
|
io.BytesIO: The generated image as bytes
|
||||||
@@ -117,13 +114,7 @@ async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.Bytes
|
|||||||
raise ValueError("Missing Replicate API key in settings")
|
raise ValueError("Missing Replicate API key in settings")
|
||||||
|
|
||||||
# Construct prompt from agent details
|
# Construct prompt from agent details
|
||||||
prompt = (
|
prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
|
||||||
"Create a visually engaging app store thumbnail for the AI agent "
|
|
||||||
"that highlights what it does in a clear and captivating way:\n"
|
|
||||||
f"- **Name**: {agent.name}\n"
|
|
||||||
f"- **Description**: {agent.description}\n"
|
|
||||||
f"Focus on showcasing its core functionality with an appealing design."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set up Replicate client
|
# Set up Replicate client
|
||||||
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ async def get_agent(
|
|||||||
)
|
)
|
||||||
async def get_graph_meta_by_store_listing_version_id(
|
async def get_graph_meta_by_store_listing_version_id(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
) -> backend.data.graph.GraphModelWithoutNodes:
|
) -> backend.data.graph.GraphMeta:
|
||||||
"""
|
"""
|
||||||
Get Agent Graph from Store Listing Version ID.
|
Get Agent Graph from Store Listing Version ID.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
webset = await aexa.websets.get(id=input_data.external_id)
|
webset = aexa.websets.get(id=input_data.external_id)
|
||||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||||
|
|
||||||
yield "webset", webset_result
|
yield "webset", webset_result
|
||||||
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
count=input_data.search_count,
|
count=input_data.search_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
webset = await aexa.websets.create(
|
webset = aexa.websets.create(
|
||||||
params=CreateWebsetParameters(
|
params=CreateWebsetParameters(
|
||||||
search=search_params,
|
search=search_params,
|
||||||
external_id=input_data.external_id,
|
external_id=input_data.external_id,
|
||||||
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.list(
|
response = aexa.websets.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
sdk_webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
deleted_webset.status.value
|
deleted_webset.status.value
|
||||||
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
canceled_webset.status.value
|
canceled_webset.status.value
|
||||||
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
|||||||
entity["description"] = input_data.entity_description
|
entity["description"] = input_data.entity_description
|
||||||
payload["entity"] = entity
|
payload["entity"] = entity
|
||||||
|
|
||||||
sdk_preview = await aexa.websets.preview(params=payload)
|
sdk_preview = aexa.websets.preview(params=payload)
|
||||||
|
|
||||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||||
|
|
||||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Extract basic info
|
# Extract basic info
|
||||||
webset_id = webset.id
|
webset_id = webset.id
|
||||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
total_items = 0
|
total_items = 0
|
||||||
|
|
||||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
sample_items_data = [
|
sample_items_data = [
|
||||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset details
|
# Get webset details
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = await aexa.websets.enrichments.create(
|
sdk_enrichment = aexa.websets.enrichments.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_enrich = await aexa.websets.enrichments.get(
|
current_enrich = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=enrichment_id
|
webset_id=input_data.webset_id, id=enrichment_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
|
|
||||||
if current_status in ["completed", "failed", "cancelled"]:
|
if current_status in ["completed", "failed", "cancelled"]:
|
||||||
# Estimate items from webset searches
|
# Estimate items from webset searches
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
for search in webset.searches:
|
for search in webset.searches:
|
||||||
if search.progress:
|
if search.progress:
|
||||||
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = await aexa.websets.enrichments.get(
|
sdk_enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_enrichment = await aexa.websets.enrichments.delete(
|
deleted_enrichment = aexa.websets.enrichments.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
canceled_enrichment = aexa.websets.enrichments.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to estimate how many items were enriched before cancellation
|
# Try to estimate how many items were enriched before cancellation
|
||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=100
|
webset_id=input_data.webset_id, limit=100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
# Create mock SDK import object
|
# Create mock SDK import object
|
||||||
mock_import = MagicMock()
|
mock_import = MagicMock()
|
||||||
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
imports=MagicMock(create=AsyncMock(return_value=mock_import))
|
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_import = await aexa.websets.imports.create(
|
sdk_import = aexa.websets.imports.create(
|
||||||
params=payload, csv_data=input_data.csv_data
|
params=payload, csv_data=input_data.csv_data
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
||||||
|
|
||||||
import_obj = ImportModel.from_sdk(sdk_import)
|
import_obj = ImportModel.from_sdk(sdk_import)
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.imports.list(
|
response = aexa.websets.imports.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -474,9 +474,7 @@ class ExaDeleteImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_import = await aexa.websets.imports.delete(
|
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
||||||
import_id=input_data.import_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "import_id", deleted_import.id
|
yield "import_id", deleted_import.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -575,14 +573,14 @@ class ExaExportWebsetBlock(Block):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create async iterator for list_all
|
# Create mock iterator
|
||||||
async def async_item_iterator(*args, **kwargs):
|
mock_items = [mock_item1, mock_item2]
|
||||||
for item in [mock_item1, mock_item2]:
|
|
||||||
yield item
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
|
websets=MagicMock(
|
||||||
|
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -604,7 +602,7 @@ class ExaExportWebsetBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
async for sdk_item in item_iterator:
|
for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_item = await aexa.websets.items.get(
|
sdk_item = aexa.websets.items.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
response = None
|
response = None
|
||||||
|
|
||||||
while time.time() - start_time < input_data.wait_timeout:
|
while time.time() - start_time < input_data.wait_timeout:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
interval = min(interval * 1.2, 10)
|
interval = min(interval * 1.2, 10)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_item = await aexa.websets.items.delete(
|
deleted_item = aexa.websets.items.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
async for sdk_item in item_iterator:
|
for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
entity_type = "unknown"
|
entity_type = "unknown"
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Get sample items if requested
|
# Get sample items if requested
|
||||||
sample_items: List[WebsetItemModel] = []
|
sample_items: List[WebsetItemModel] = []
|
||||||
if input_data.sample_size > 0:
|
if input_data.sample_size > 0:
|
||||||
items_response = await aexa.websets.items.list(
|
items_response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
# Convert to our stable models
|
# Convert to our stable models
|
||||||
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get items starting from cursor
|
# Get items starting from cursor
|
||||||
response = await aexa.websets.items.list(
|
response = aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.since_cursor,
|
cursor=input_data.since_cursor,
|
||||||
limit=input_data.max_items,
|
limit=input_data.max_items,
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
# Create mock SDK monitor object
|
# Create mock SDK monitor object
|
||||||
mock_monitor = MagicMock()
|
mock_monitor = MagicMock()
|
||||||
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
|
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.create(params=payload)
|
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = await aexa.websets.monitors.update(
|
sdk_monitor = aexa.websets.monitors.update(
|
||||||
monitor_id=input_data.monitor_id, params=payload
|
monitor_id=input_data.monitor_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -522,9 +522,7 @@ class ExaDeleteMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_monitor = await aexa.websets.monitors.delete(
|
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
||||||
monitor_id=input_data.monitor_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "monitor_id", deleted_monitor.id
|
yield "monitor_id", deleted_monitor.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -581,7 +579,7 @@ class ExaListMonitorsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = await aexa.websets.monitors.list(
|
response = aexa.websets.monitors.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
WebsetTargetStatus.IDLE,
|
WebsetTargetStatus.IDLE,
|
||||||
WebsetTargetStatus.ANY_COMPLETE,
|
WebsetTargetStatus.ANY_COMPLETE,
|
||||||
]:
|
]:
|
||||||
final_webset = await aexa.websets.wait_until_idle(
|
final_webset = aexa.websets.wait_until_idle(
|
||||||
id=input_data.webset_id,
|
id=input_data.webset_id,
|
||||||
timeout=input_data.timeout,
|
timeout=input_data.timeout,
|
||||||
poll_interval=input_data.check_interval,
|
poll_interval=input_data.check_interval,
|
||||||
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
interval = input_data.check_interval
|
interval = input_data.check_interval
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current webset status
|
# Get current webset status
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
current_status = (
|
current_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
|
|
||||||
# Timeout reached
|
# Timeout reached
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
final_status = (
|
final_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current search status using SDK
|
# Get current search status using SDK
|
||||||
search = await aexa.websets.searches.get(
|
search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
search = await aexa.websets.searches.get(
|
search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current enrichment status using SDK
|
# Get current enrichment status using SDK
|
||||||
enrichment = await aexa.websets.enrichments.get(
|
enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
enrichment = await aexa.websets.enrichments.get(
|
enrichment = aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||||
"""Get sample enriched data and count."""
|
"""Get sample enriched data and count."""
|
||||||
# Get a few items to see enrichment results using SDK
|
# Get a few items to see enrichment results using SDK
|
||||||
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
|
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||||
|
|
||||||
sample_data: list[SampleEnrichmentModel] = []
|
sample_data: list[SampleEnrichmentModel] = []
|
||||||
enriched_count = 0
|
enriched_count = 0
|
||||||
|
|||||||
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
|
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.create(
|
sdk_search = aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
poll_start = time.time()
|
poll_start = time.time()
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_search = await aexa.websets.searches.get(
|
current_search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=search_id
|
webset_id=input_data.webset_id, id=search_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.get(
|
sdk_search = aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_search = await aexa.websets.searches.cancel(
|
canceled_search = aexa.websets.searches.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset to check existing searches
|
# Get webset to check existing searches
|
||||||
webset = await aexa.websets.get(id=input_data.webset_id)
|
webset = aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Look for existing search with same query
|
# Look for existing search with same query
|
||||||
existing_search = None
|
existing_search = None
|
||||||
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
if input_data.entity_type != SearchEntityType.AUTO:
|
if input_data.entity_type != SearchEntityType.AUTO:
|
||||||
payload["entity"] = {"type": input_data.entity_type.value}
|
payload["entity"] = {"type": input_data.entity_type.value}
|
||||||
|
|
||||||
sdk_search = await aexa.websets.searches.create(
|
sdk_search = aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -596,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
|||||||
|
|
||||||
def get_parallel_tool_calls_param(
|
def get_parallel_tool_calls_param(
|
||||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||||
) -> bool | openai.Omit:
|
):
|
||||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||||
return openai.omit
|
return openai.NOT_GIVEN
|
||||||
return parallel_tool_calls
|
return parallel_tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -246,9 +246,7 @@ class BlockSchema(BaseModel):
|
|||||||
f"is not of type {CredentialsMetaInput.__name__}"
|
f"is not of type {CredentialsMetaInput.__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
CredentialsMetaInput.validate_credentials_field_schema(
|
credentials_fields[field_name].validate_credentials_field_schema(cls)
|
||||||
cls.get_field_schema(field_name), field_name
|
|
||||||
)
|
|
||||||
|
|
||||||
elif field_name in credentials_fields:
|
elif field_name in credentials_fields:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast
|
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
|
||||||
|
|
||||||
from prisma.enums import SubmissionStatus
|
from prisma.enums import SubmissionStatus
|
||||||
from prisma.models import (
|
from prisma.models import (
|
||||||
@@ -20,7 +20,7 @@ from prisma.types import (
|
|||||||
AgentNodeLinkCreateInput,
|
AgentNodeLinkCreateInput,
|
||||||
StoreListingVersionWhereInput,
|
StoreListingVersionWhereInput,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, BeforeValidator, Field
|
from pydantic import BaseModel, BeforeValidator, Field, create_model
|
||||||
from pydantic.fields import computed_field
|
from pydantic.fields import computed_field
|
||||||
|
|
||||||
from backend.blocks.agent import AgentExecutorBlock
|
from backend.blocks.agent import AgentExecutorBlock
|
||||||
@@ -30,6 +30,7 @@ from backend.data.db import prisma as db
|
|||||||
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
|
CredentialsField,
|
||||||
CredentialsFieldInfo,
|
CredentialsFieldInfo,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
is_credentials_field_name,
|
is_credentials_field_name,
|
||||||
@@ -44,6 +45,7 @@ from .block import (
|
|||||||
AnyBlockSchema,
|
AnyBlockSchema,
|
||||||
Block,
|
Block,
|
||||||
BlockInput,
|
BlockInput,
|
||||||
|
BlockSchema,
|
||||||
BlockType,
|
BlockType,
|
||||||
EmptySchema,
|
EmptySchema,
|
||||||
get_block,
|
get_block,
|
||||||
@@ -111,12 +113,10 @@ class Link(BaseDbModel):
|
|||||||
|
|
||||||
class Node(BaseDbModel):
|
class Node(BaseDbModel):
|
||||||
block_id: str
|
block_id: str
|
||||||
input_default: BlockInput = Field( # dict[input_name, default_value]
|
input_default: BlockInput = {} # dict[input_name, default_value]
|
||||||
default_factory=dict
|
metadata: dict[str, Any] = {}
|
||||||
)
|
input_links: list[Link] = []
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
output_links: list[Link] = []
|
||||||
input_links: list[Link] = Field(default_factory=list)
|
|
||||||
output_links: list[Link] = Field(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def credentials_optional(self) -> bool:
|
def credentials_optional(self) -> bool:
|
||||||
@@ -221,33 +221,18 @@ class NodeModel(Node):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class GraphBaseMeta(BaseDbModel):
|
class BaseGraph(BaseDbModel):
|
||||||
"""
|
|
||||||
Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields.
|
|
||||||
"""
|
|
||||||
|
|
||||||
version: int = 1
|
version: int = 1
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
instructions: str | None = None
|
instructions: str | None = None
|
||||||
recommended_schedule_cron: str | None = None
|
recommended_schedule_cron: str | None = None
|
||||||
|
nodes: list[Node] = []
|
||||||
|
links: list[Link] = []
|
||||||
forked_from_id: str | None = None
|
forked_from_id: str | None = None
|
||||||
forked_from_version: int | None = None
|
forked_from_version: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseGraph(GraphBaseMeta):
|
|
||||||
"""
|
|
||||||
Graph with nodes, links, and computed I/O schema fields.
|
|
||||||
|
|
||||||
Used to represent sub-graphs within a `Graph`. Contains the full graph
|
|
||||||
structure including nodes and links, plus computed fields for schemas
|
|
||||||
and trigger info. Does NOT include user_id or created_at (see GraphModel).
|
|
||||||
"""
|
|
||||||
|
|
||||||
nodes: list[Node] = Field(default_factory=list)
|
|
||||||
links: list[Link] = Field(default_factory=list)
|
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def input_schema(self) -> dict[str, Any]:
|
def input_schema(self) -> dict[str, Any]:
|
||||||
@@ -376,79 +361,44 @@ class GraphTriggerInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Graph(BaseGraph):
|
class Graph(BaseGraph):
|
||||||
"""Creatable graph model used in API create/update endpoints."""
|
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
||||||
|
|
||||||
sub_graphs: list[BaseGraph] = Field(default_factory=list) # Flattened sub-graphs
|
|
||||||
|
|
||||||
|
|
||||||
class GraphMeta(GraphBaseMeta):
|
|
||||||
"""
|
|
||||||
Lightweight graph metadata model representing an existing graph from the database,
|
|
||||||
for use in listings and summaries.
|
|
||||||
|
|
||||||
Lacks `GraphModel`'s nodes, links, and expensive computed fields.
|
|
||||||
Use for list endpoints where full graph data is not needed and performance matters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str # type: ignore
|
|
||||||
version: int # type: ignore
|
|
||||||
user_id: str
|
|
||||||
created_at: datetime
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_db(cls, graph: "AgentGraph") -> Self:
|
|
||||||
return cls(
|
|
||||||
id=graph.id,
|
|
||||||
version=graph.version,
|
|
||||||
is_active=graph.isActive,
|
|
||||||
name=graph.name or "",
|
|
||||||
description=graph.description or "",
|
|
||||||
instructions=graph.instructions,
|
|
||||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
|
||||||
forked_from_id=graph.forkedFromId,
|
|
||||||
forked_from_version=graph.forkedFromVersion,
|
|
||||||
user_id=graph.userId,
|
|
||||||
created_at=graph.createdAt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphModel(Graph, GraphMeta):
|
|
||||||
"""
|
|
||||||
Full graph model representing an existing graph from the database.
|
|
||||||
|
|
||||||
This is the primary model for working with persisted graphs. Includes all
|
|
||||||
graph data (nodes, links, sub_graphs) plus user ownership and timestamps.
|
|
||||||
Provides computed fields (input_schema, output_schema, etc.) used during
|
|
||||||
set-up (frontend) and execution (backend).
|
|
||||||
|
|
||||||
Inherits from:
|
|
||||||
- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas
|
|
||||||
- `GraphMeta`: provides user_id, created_at for database records
|
|
||||||
"""
|
|
||||||
|
|
||||||
nodes: list[NodeModel] = Field(default_factory=list) # type: ignore
|
|
||||||
|
|
||||||
@property
|
|
||||||
def starting_nodes(self) -> list[NodeModel]:
|
|
||||||
outbound_nodes = {link.sink_id for link in self.links}
|
|
||||||
input_nodes = {
|
|
||||||
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
|
||||||
}
|
|
||||||
return [
|
|
||||||
node
|
|
||||||
for node in self.nodes
|
|
||||||
if node.id not in outbound_nodes or node.id in input_nodes
|
|
||||||
]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
|
||||||
return cast(NodeModel, super().webhook_input_node)
|
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def credentials_input_schema(self) -> dict[str, Any]:
|
def credentials_input_schema(self) -> dict[str, Any]:
|
||||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
schema = self._credentials_input_schema.jsonschema()
|
||||||
|
|
||||||
|
# Determine which credential fields are required based on credentials_optional metadata
|
||||||
|
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||||
|
required_fields = []
|
||||||
|
|
||||||
|
# Build a map of node_id -> node for quick lookup
|
||||||
|
all_nodes = {node.id: node for node in self.nodes}
|
||||||
|
for sub_graph in self.sub_graphs:
|
||||||
|
for node in sub_graph.nodes:
|
||||||
|
all_nodes[node.id] = node
|
||||||
|
|
||||||
|
for field_key, (
|
||||||
|
_field_info,
|
||||||
|
node_field_pairs,
|
||||||
|
) in graph_credentials_inputs.items():
|
||||||
|
# A field is required if ANY node using it has credentials_optional=False
|
||||||
|
is_required = False
|
||||||
|
for node_id, _field_name in node_field_pairs:
|
||||||
|
node = all_nodes.get(node_id)
|
||||||
|
if node and not node.credentials_optional:
|
||||||
|
is_required = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if is_required:
|
||||||
|
required_fields.append(field_key)
|
||||||
|
|
||||||
|
schema["required"] = required_fields
|
||||||
|
return schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _credentials_input_schema(self) -> type[BlockSchema]:
|
||||||
|
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||||
f"{graph_credentials_inputs}"
|
f"{graph_credentials_inputs}"
|
||||||
@@ -456,8 +406,8 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
|
|
||||||
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
||||||
graph_cred_fields = list(graph_credentials_inputs.values())
|
graph_cred_fields = list(graph_credentials_inputs.values())
|
||||||
for i, (field, keys, _) in enumerate(graph_cred_fields):
|
for i, (field, keys) in enumerate(graph_cred_fields):
|
||||||
for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
|
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||||
if field.provider != other_field.provider:
|
if field.provider != other_field.provider:
|
||||||
continue
|
continue
|
||||||
if ProviderName.HTTP in field.provider:
|
if ProviderName.HTTP in field.provider:
|
||||||
@@ -473,78 +423,31 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
f"keys: {keys} <> {other_keys}."
|
f"keys: {keys} <> {other_keys}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build JSON schema directly to avoid expensive create_model + validation overhead
|
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
|
||||||
properties = {}
|
agg_field_key: (
|
||||||
required_fields = []
|
CredentialsMetaInput[
|
||||||
|
Literal[tuple(field_info.provider)], # type: ignore
|
||||||
for agg_field_key, (
|
Literal[tuple(field_info.supported_types)], # type: ignore
|
||||||
field_info,
|
],
|
||||||
_,
|
CredentialsField(
|
||||||
is_required,
|
required_scopes=set(field_info.required_scopes or []),
|
||||||
) in graph_credentials_inputs.items():
|
discriminator=field_info.discriminator,
|
||||||
providers = list(field_info.provider)
|
discriminator_mapping=field_info.discriminator_mapping,
|
||||||
cred_types = list(field_info.supported_types)
|
discriminator_values=field_info.discriminator_values,
|
||||||
|
),
|
||||||
field_schema: dict[str, Any] = {
|
|
||||||
"credentials_provider": providers,
|
|
||||||
"credentials_types": cred_types,
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"id": {"title": "Id", "type": "string"},
|
|
||||||
"title": {
|
|
||||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
|
||||||
"default": None,
|
|
||||||
"title": "Title",
|
|
||||||
},
|
|
||||||
"provider": {
|
|
||||||
"title": "Provider",
|
|
||||||
"type": "string",
|
|
||||||
**(
|
|
||||||
{"enum": providers}
|
|
||||||
if len(providers) > 1
|
|
||||||
else {"const": providers[0]}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"type": {
|
|
||||||
"title": "Type",
|
|
||||||
"type": "string",
|
|
||||||
**(
|
|
||||||
{"enum": cred_types}
|
|
||||||
if len(cred_types) > 1
|
|
||||||
else {"const": cred_types[0]}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["id", "provider", "type"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add other (optional) field info items
|
|
||||||
field_schema.update(
|
|
||||||
field_info.model_dump(
|
|
||||||
by_alias=True,
|
|
||||||
exclude_defaults=True,
|
|
||||||
exclude={"provider", "supported_types"}, # already included above
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||||
# Ensure field schema is well-formed
|
|
||||||
CredentialsMetaInput.validate_credentials_field_schema(
|
|
||||||
field_schema, agg_field_key
|
|
||||||
)
|
|
||||||
|
|
||||||
properties[agg_field_key] = field_schema
|
|
||||||
if is_required:
|
|
||||||
required_fields.append(agg_field_key)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": properties,
|
|
||||||
"required": required_fields,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return create_model(
|
||||||
|
self.name.replace(" ", "") + "CredentialsInputSchema",
|
||||||
|
__base__=BlockSchema,
|
||||||
|
**fields, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
def aggregate_credentials_inputs(
|
def aggregate_credentials_inputs(
|
||||||
self,
|
self,
|
||||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
dict[aggregated_field_key, tuple(
|
dict[aggregated_field_key, tuple(
|
||||||
@@ -552,19 +455,13 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
(now includes discriminator_values from matching nodes)
|
(now includes discriminator_values from matching nodes)
|
||||||
set[(node_id, field_name)]: Node credentials fields that are
|
set[(node_id, field_name)]: Node credentials fields that are
|
||||||
compatible with this aggregated field spec
|
compatible with this aggregated field spec
|
||||||
bool: True if the field is required (any node has credentials_optional=False)
|
|
||||||
)]
|
)]
|
||||||
"""
|
"""
|
||||||
# First collect all credential field data with input defaults
|
# First collect all credential field data with input defaults
|
||||||
# Track (field_info, (node_id, field_name), is_required) for each credential field
|
node_credential_data = []
|
||||||
node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
|
|
||||||
node_required_map: dict[str, bool] = {} # node_id -> is_required
|
|
||||||
|
|
||||||
for graph in [self] + self.sub_graphs:
|
for graph in [self] + self.sub_graphs:
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
# Track if this node requires credentials (credentials_optional=False means required)
|
|
||||||
node_required_map[node.id] = not node.credentials_optional
|
|
||||||
|
|
||||||
for (
|
for (
|
||||||
field_name,
|
field_name,
|
||||||
field_info,
|
field_info,
|
||||||
@@ -588,21 +485,37 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Combine credential field info (this will merge discriminator_values automatically)
|
# Combine credential field info (this will merge discriminator_values automatically)
|
||||||
combined = CredentialsFieldInfo.combine(*node_credential_data)
|
return CredentialsFieldInfo.combine(*node_credential_data)
|
||||||
|
|
||||||
# Add is_required flag to each aggregated field
|
|
||||||
# A field is required if ANY node using it has credentials_optional=False
|
class GraphModel(Graph):
|
||||||
return {
|
user_id: str
|
||||||
key: (
|
nodes: list[NodeModel] = [] # type: ignore
|
||||||
field_info,
|
|
||||||
node_field_pairs,
|
created_at: datetime
|
||||||
any(
|
|
||||||
node_required_map.get(node_id, True)
|
@property
|
||||||
for node_id, _ in node_field_pairs
|
def starting_nodes(self) -> list[NodeModel]:
|
||||||
),
|
outbound_nodes = {link.sink_id for link in self.links}
|
||||||
)
|
input_nodes = {
|
||||||
for key, (field_info, node_field_pairs) in combined.items()
|
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
||||||
}
|
}
|
||||||
|
return [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.id not in outbound_nodes or node.id in input_nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||||
|
return cast(NodeModel, super().webhook_input_node)
|
||||||
|
|
||||||
|
def meta(self) -> "GraphMeta":
|
||||||
|
"""
|
||||||
|
Returns a GraphMeta object with metadata about the graph.
|
||||||
|
This is used to return metadata about the graph without exposing nodes and links.
|
||||||
|
"""
|
||||||
|
return GraphMeta.from_graph(self)
|
||||||
|
|
||||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||||
"""
|
"""
|
||||||
@@ -886,14 +799,13 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
if is_static_output_block(link.source_id):
|
if is_static_output_block(link.source_id):
|
||||||
link.is_static = True # Each value block output should be static.
|
link.is_static = True # Each value block output should be static.
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def from_db( # type: ignore[reportIncompatibleMethodOverride]
|
def from_db(
|
||||||
cls,
|
|
||||||
graph: AgentGraph,
|
graph: AgentGraph,
|
||||||
for_export: bool = False,
|
for_export: bool = False,
|
||||||
sub_graphs: list[AgentGraph] | None = None,
|
sub_graphs: list[AgentGraph] | None = None,
|
||||||
) -> Self:
|
) -> "GraphModel":
|
||||||
return cls(
|
return GraphModel(
|
||||||
id=graph.id,
|
id=graph.id,
|
||||||
user_id=graph.userId if not for_export else "",
|
user_id=graph.userId if not for_export else "",
|
||||||
version=graph.version,
|
version=graph.version,
|
||||||
@@ -919,28 +831,17 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def hide_nodes(self) -> "GraphModelWithoutNodes":
|
|
||||||
"""
|
|
||||||
Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden
|
|
||||||
(excluded from serialization). They are still present in the model instance
|
|
||||||
so all computed fields (e.g. `credentials_input_schema`) still work.
|
|
||||||
"""
|
|
||||||
return GraphModelWithoutNodes.model_validate(self, from_attributes=True)
|
|
||||||
|
|
||||||
|
class GraphMeta(Graph):
|
||||||
|
user_id: str
|
||||||
|
|
||||||
class GraphModelWithoutNodes(GraphModel):
|
# Easy work-around to prevent exposing nodes and links in the API response
|
||||||
"""
|
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
|
||||||
GraphModel variant that excludes nodes, links, and sub-graphs from serialization.
|
links: list[Link] = Field(default=[], exclude=True)
|
||||||
|
|
||||||
Used in contexts like the store where exposing internal graph structure
|
@staticmethod
|
||||||
is not desired. Inherits all computed fields from GraphModel but marks
|
def from_graph(graph: GraphModel) -> "GraphMeta":
|
||||||
nodes and links as excluded from JSON output.
|
return GraphMeta(**graph.model_dump())
|
||||||
"""
|
|
||||||
|
|
||||||
nodes: list[NodeModel] = Field(default_factory=list, exclude=True)
|
|
||||||
links: list[Link] = Field(default_factory=list, exclude=True)
|
|
||||||
|
|
||||||
sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphsPaginated(BaseModel):
|
class GraphsPaginated(BaseModel):
|
||||||
@@ -1011,11 +912,21 @@ async def list_graphs_paginated(
|
|||||||
where=where_clause,
|
where=where_clause,
|
||||||
distinct=["id"],
|
distinct=["id"],
|
||||||
order={"version": "desc"},
|
order={"version": "desc"},
|
||||||
|
include=AGENT_GRAPH_INCLUDE,
|
||||||
skip=offset,
|
skip=offset,
|
||||||
take=page_size,
|
take=page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_models = [GraphMeta.from_db(graph) for graph in graphs]
|
graph_models: list[GraphMeta] = []
|
||||||
|
for graph in graphs:
|
||||||
|
try:
|
||||||
|
graph_meta = GraphModel.from_db(graph).meta()
|
||||||
|
# Trigger serialization to validate that the graph is well formed
|
||||||
|
graph_meta.model_dump()
|
||||||
|
graph_models.append(graph_meta)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing graph {graph.id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
return GraphsPaginated(
|
return GraphsPaginated(
|
||||||
graphs=graph_models,
|
graphs=graph_models,
|
||||||
|
|||||||
@@ -163,6 +163,7 @@ class User(BaseModel):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from prisma.models import User as PrismaUser
|
from prisma.models import User as PrismaUser
|
||||||
|
|
||||||
|
from backend.data.block import BlockSchema
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -507,13 +508,15 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
||||||
return get_args(cls.model_fields["type"].annotation)
|
return get_args(cls.model_fields["type"].annotation)
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def validate_credentials_field_schema(
|
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
|
||||||
field_schema: dict[str, Any], field_name: str
|
|
||||||
):
|
|
||||||
"""Validates the schema of a credentials input field"""
|
"""Validates the schema of a credentials input field"""
|
||||||
|
field_name = next(
|
||||||
|
name for name, type in model.get_credentials_fields().items() if type is cls
|
||||||
|
)
|
||||||
|
field_schema = model.jsonschema()["properties"][field_name]
|
||||||
try:
|
try:
|
||||||
field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if "Field required [type=missing" not in str(e):
|
if "Field required [type=missing" not in str(e):
|
||||||
raise
|
raise
|
||||||
@@ -523,11 +526,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
f"{field_schema}"
|
f"{field_schema}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
providers = field_info.provider
|
providers = cls.allowed_providers()
|
||||||
if (
|
if (
|
||||||
providers is not None
|
providers is not None
|
||||||
and len(providers) > 1
|
and len(providers) > 1
|
||||||
and not field_info.discriminator
|
and not schema_extra.discriminator
|
||||||
):
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Multi-provider CredentialsField '{field_name}' "
|
f"Multi-provider CredentialsField '{field_name}' "
|
||||||
|
|||||||
@@ -373,7 +373,7 @@ def make_node_credentials_input_map(
|
|||||||
# Get aggregated credentials fields for the graph
|
# Get aggregated credentials fields for the graph
|
||||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||||
|
|
||||||
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
|
||||||
# Best-effort map: skip missing items
|
# Best-effort map: skip missing items
|
||||||
if graph_input_name not in graph_credentials_input:
|
if graph_input_name not in graph_credentials_input:
|
||||||
continue
|
continue
|
||||||
|
|||||||
6846
autogpt_platform/backend/poetry.lock
generated
6846
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@ aio-pika = "^9.5.5"
|
|||||||
aiohttp = "^3.10.0"
|
aiohttp = "^3.10.0"
|
||||||
aiodns = "^3.5.0"
|
aiodns = "^3.5.0"
|
||||||
anthropic = "^0.59.0"
|
anthropic = "^0.59.0"
|
||||||
|
claude-agent-sdk = "^0.1.0"
|
||||||
apscheduler = "^3.11.1"
|
apscheduler = "^3.11.1"
|
||||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||||
@@ -21,7 +22,7 @@ cryptography = "^45.0"
|
|||||||
discord-py = "^2.5.2"
|
discord-py = "^2.5.2"
|
||||||
e2b-code-interpreter = "^1.5.2"
|
e2b-code-interpreter = "^1.5.2"
|
||||||
elevenlabs = "^1.50.0"
|
elevenlabs = "^1.50.0"
|
||||||
fastapi = "^0.128.0"
|
fastapi = "^0.116.1"
|
||||||
feedparser = "^6.0.11"
|
feedparser = "^6.0.11"
|
||||||
flake8 = "^7.3.0"
|
flake8 = "^7.3.0"
|
||||||
google-api-python-client = "^2.177.0"
|
google-api-python-client = "^2.177.0"
|
||||||
@@ -35,7 +36,7 @@ jinja2 = "^3.1.6"
|
|||||||
jsonref = "^1.1.0"
|
jsonref = "^1.1.0"
|
||||||
jsonschema = "^4.25.0"
|
jsonschema = "^4.25.0"
|
||||||
langfuse = "^3.11.0"
|
langfuse = "^3.11.0"
|
||||||
launchdarkly-server-sdk = "^9.14.1"
|
launchdarkly-server-sdk = "^9.12.0"
|
||||||
mem0ai = "^0.1.115"
|
mem0ai = "^0.1.115"
|
||||||
moviepy = "^2.1.2"
|
moviepy = "^2.1.2"
|
||||||
ollama = "^0.5.1"
|
ollama = "^0.5.1"
|
||||||
@@ -52,8 +53,8 @@ prometheus-client = "^0.22.1"
|
|||||||
prometheus-fastapi-instrumentator = "^7.0.0"
|
prometheus-fastapi-instrumentator = "^7.0.0"
|
||||||
psutil = "^7.0.0"
|
psutil = "^7.0.0"
|
||||||
psycopg2-binary = "^2.9.10"
|
psycopg2-binary = "^2.9.10"
|
||||||
pydantic = { extras = ["email"], version = "^2.12.5" }
|
pydantic = { extras = ["email"], version = "^2.11.7" }
|
||||||
pydantic-settings = "^2.12.0"
|
pydantic-settings = "^2.10.1"
|
||||||
pytest = "^8.4.1"
|
pytest = "^8.4.1"
|
||||||
pytest-asyncio = "^1.1.0"
|
pytest-asyncio = "^1.1.0"
|
||||||
python-dotenv = "^1.1.1"
|
python-dotenv = "^1.1.1"
|
||||||
@@ -65,11 +66,11 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
|
|||||||
sqlalchemy = "^2.0.40"
|
sqlalchemy = "^2.0.40"
|
||||||
strenum = "^0.4.9"
|
strenum = "^0.4.9"
|
||||||
stripe = "^11.5.0"
|
stripe = "^11.5.0"
|
||||||
supabase = "2.27.2"
|
supabase = "2.17.0"
|
||||||
tenacity = "^9.1.2"
|
tenacity = "^9.1.2"
|
||||||
todoist-api-python = "^2.1.7"
|
todoist-api-python = "^2.1.7"
|
||||||
tweepy = "^4.16.0"
|
tweepy = "^4.16.0"
|
||||||
uvicorn = { extras = ["standard"], version = "^0.40.0" }
|
uvicorn = { extras = ["standard"], version = "^0.35.0" }
|
||||||
websockets = "^15.0"
|
websockets = "^15.0"
|
||||||
youtube-transcript-api = "^1.2.1"
|
youtube-transcript-api = "^1.2.1"
|
||||||
yt-dlp = "2025.12.08"
|
yt-dlp = "2025.12.08"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"credentials_input_schema": {
|
"credentials_input_schema": {
|
||||||
"properties": {},
|
"properties": {},
|
||||||
"required": [],
|
"required": [],
|
||||||
|
"title": "TestGraphCredentialsInputSchema",
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
"description": "A test graph",
|
"description": "A test graph",
|
||||||
|
|||||||
@@ -1,14 +1,34 @@
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"created_at": "2025-09-04T13:37:00",
|
"credentials_input_schema": {
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
"title": "TestGraphCredentialsInputSchema",
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"description": "A test graph",
|
"description": "A test graph",
|
||||||
"forked_from_id": null,
|
"forked_from_id": null,
|
||||||
"forked_from_version": null,
|
"forked_from_version": null,
|
||||||
|
"has_external_trigger": false,
|
||||||
|
"has_human_in_the_loop": false,
|
||||||
|
"has_sensitive_action": false,
|
||||||
"id": "graph-123",
|
"id": "graph-123",
|
||||||
|
"input_schema": {
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"instructions": null,
|
"instructions": null,
|
||||||
"is_active": true,
|
"is_active": true,
|
||||||
"name": "Test Graph",
|
"name": "Test Graph",
|
||||||
|
"output_schema": {
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"recommended_schedule_cron": null,
|
"recommended_schedule_cron": null,
|
||||||
|
"sub_graphs": [],
|
||||||
|
"trigger_setup_info": null,
|
||||||
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||||
"version": 1
|
"version": 1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
||||||
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||||
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
|
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import { getSchemaDefaultCredentials } from "../../helpers";
|
import { getSchemaDefaultCredentials } from "../../helpers";
|
||||||
@@ -9,7 +9,7 @@ type Credential = CredentialsMetaInput | undefined;
|
|||||||
type Credentials = Record<string, Credential>;
|
type Credentials = Record<string, Credential>;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
agent: GraphModel | null;
|
agent: GraphMeta | null;
|
||||||
siblingInputs?: Record<string, any>;
|
siblingInputs?: Record<string, any>;
|
||||||
onCredentialsChange: (
|
onCredentialsChange: (
|
||||||
credentials: Record<string, CredentialsMetaInput>,
|
credentials: Record<string, CredentialsMetaInput>,
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
||||||
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||||
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";
|
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";
|
||||||
|
|
||||||
export function getCredentialFields(
|
export function getCredentialFields(
|
||||||
agent: GraphModel | null,
|
agent: GraphMeta | null,
|
||||||
): AgentCredentialsFields {
|
): AgentCredentialsFields {
|
||||||
if (!agent) return {};
|
if (!agent) return {};
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ import type {
|
|||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
} from "@/lib/autogpt-server-api/types";
|
} from "@/lib/autogpt-server-api/types";
|
||||||
import type { InputValues } from "./types";
|
import type { InputValues } from "./types";
|
||||||
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||||
|
|
||||||
export function computeInitialAgentInputs(
|
export function computeInitialAgentInputs(
|
||||||
agent: GraphModel | null,
|
agent: GraphMeta | null,
|
||||||
existingInputs?: InputValues | null,
|
existingInputs?: InputValues | null,
|
||||||
): InputValues {
|
): InputValues {
|
||||||
const properties = agent?.input_schema?.properties || {};
|
const properties = agent?.input_schema?.properties || {};
|
||||||
@@ -29,7 +29,7 @@ export function computeInitialAgentInputs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
type IsRunDisabledParams = {
|
type IsRunDisabledParams = {
|
||||||
agent: GraphModel | null;
|
agent: GraphMeta | null;
|
||||||
isRunning: boolean;
|
isRunning: boolean;
|
||||||
agentInputs: InputValues | null | undefined;
|
agentInputs: InputValues | null | undefined;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -30,8 +30,6 @@ import {
|
|||||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||||
import jaro from "jaro-winkler";
|
import jaro from "jaro-winkler";
|
||||||
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
|
||||||
import { okData } from "@/app/api/helpers";
|
|
||||||
|
|
||||||
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
|
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
|
||||||
uiKey?: string;
|
uiKey?: string;
|
||||||
@@ -109,8 +107,6 @@ export function BlocksControl({
|
|||||||
.filter((b) => b.uiType !== BlockUIType.AGENT)
|
.filter((b) => b.uiType !== BlockUIType.AGENT)
|
||||||
.sort((a, b) => a.name.localeCompare(b.name));
|
.sort((a, b) => a.name.localeCompare(b.name));
|
||||||
|
|
||||||
// Agent blocks are created from GraphMeta which doesn't include schemas.
|
|
||||||
// Schemas will be fetched on-demand when the block is actually added.
|
|
||||||
const agentBlockList = flows
|
const agentBlockList = flows
|
||||||
.map((flow): _Block => {
|
.map((flow): _Block => {
|
||||||
return {
|
return {
|
||||||
@@ -120,9 +116,8 @@ export function BlocksControl({
|
|||||||
`Ver.${flow.version}` +
|
`Ver.${flow.version}` +
|
||||||
(flow.description ? ` | ${flow.description}` : ""),
|
(flow.description ? ` | ${flow.description}` : ""),
|
||||||
categories: [{ category: "AGENT", description: "" }],
|
categories: [{ category: "AGENT", description: "" }],
|
||||||
// Empty schemas - will be populated when block is added
|
inputSchema: flow.input_schema,
|
||||||
inputSchema: { type: "object", properties: {} },
|
outputSchema: flow.output_schema,
|
||||||
outputSchema: { type: "object", properties: {} },
|
|
||||||
staticOutput: false,
|
staticOutput: false,
|
||||||
uiType: BlockUIType.AGENT,
|
uiType: BlockUIType.AGENT,
|
||||||
costs: [],
|
costs: [],
|
||||||
@@ -130,7 +125,8 @@ export function BlocksControl({
|
|||||||
hardcodedValues: {
|
hardcodedValues: {
|
||||||
graph_id: flow.id,
|
graph_id: flow.id,
|
||||||
graph_version: flow.version,
|
graph_version: flow.version,
|
||||||
// Schemas will be fetched on-demand when block is added
|
input_schema: flow.input_schema,
|
||||||
|
output_schema: flow.output_schema,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
@@ -186,37 +182,6 @@ export function BlocksControl({
|
|||||||
setSelectedCategory(null);
|
setSelectedCategory(null);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
// Handler to add a block, fetching graph data on-demand for agent blocks
|
|
||||||
const handleAddBlock = useCallback(
|
|
||||||
async (block: _Block & { notAvailable: string | null }) => {
|
|
||||||
if (block.notAvailable) return;
|
|
||||||
|
|
||||||
// For agent blocks, fetch the full graph to get schemas
|
|
||||||
if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) {
|
|
||||||
const graphID = block.hardcodedValues.graph_id as string;
|
|
||||||
const graphVersion = block.hardcodedValues.graph_version as number;
|
|
||||||
const graphData = okData(
|
|
||||||
await getV1GetSpecificGraph(graphID, { version: graphVersion }),
|
|
||||||
);
|
|
||||||
|
|
||||||
if (graphData) {
|
|
||||||
addBlock(block.id, block.name, {
|
|
||||||
...block.hardcodedValues,
|
|
||||||
input_schema: graphData.input_schema,
|
|
||||||
output_schema: graphData.output_schema,
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
// Fallback: add without schemas (will be incomplete)
|
|
||||||
console.error("Failed to fetch graph data for agent block");
|
|
||||||
addBlock(block.id, block.name, block.hardcodedValues || {});
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
addBlock(block.id, block.name, block.hardcodedValues || {});
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[addBlock],
|
|
||||||
);
|
|
||||||
|
|
||||||
// Extract unique categories from blocks
|
// Extract unique categories from blocks
|
||||||
const categories = useMemo(() => {
|
const categories = useMemo(() => {
|
||||||
return Array.from(
|
return Array.from(
|
||||||
@@ -338,7 +303,10 @@ export function BlocksControl({
|
|||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
}}
|
}}
|
||||||
onClick={() => handleAddBlock(block)}
|
onClick={() =>
|
||||||
|
!block.notAvailable &&
|
||||||
|
addBlock(block.id, block.name, block?.hardcodedValues || {})
|
||||||
|
}
|
||||||
title={block.notAvailable ?? undefined}
|
title={block.notAvailable ?? undefined}
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
|
|||||||
@@ -29,17 +29,13 @@ import "@xyflow/react/dist/style.css";
|
|||||||
import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
|
import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
|
||||||
import "./flow.css";
|
import "./flow.css";
|
||||||
import {
|
import {
|
||||||
BlockIORootSchema,
|
|
||||||
BlockUIType,
|
BlockUIType,
|
||||||
formatEdgeID,
|
formatEdgeID,
|
||||||
GraphExecutionID,
|
GraphExecutionID,
|
||||||
GraphID,
|
GraphID,
|
||||||
GraphMeta,
|
GraphMeta,
|
||||||
LibraryAgent,
|
LibraryAgent,
|
||||||
SpecialBlockID,
|
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
|
||||||
import { okData } from "@/app/api/helpers";
|
|
||||||
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||||
import { Key, storage } from "@/services/storage/local-storage";
|
import { Key, storage } from "@/services/storage/local-storage";
|
||||||
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
|
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
|
||||||
@@ -691,94 +687,8 @@ const FlowEditor: React.FC<{
|
|||||||
[getNode, updateNode, nodes],
|
[getNode, updateNode, nodes],
|
||||||
);
|
);
|
||||||
|
|
||||||
/* Shared helper to create and add a node */
|
|
||||||
const createAndAddNode = useCallback(
|
|
||||||
async (
|
|
||||||
blockID: string,
|
|
||||||
blockName: string,
|
|
||||||
hardcodedValues: Record<string, any>,
|
|
||||||
position: { x: number; y: number },
|
|
||||||
): Promise<CustomNode | null> => {
|
|
||||||
const nodeSchema = availableBlocks.find((node) => node.id === blockID);
|
|
||||||
if (!nodeSchema) {
|
|
||||||
console.error(`Schema not found for block ID: ${blockID}`);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// For agent blocks, fetch the full graph to get schemas
|
|
||||||
let inputSchema: BlockIORootSchema = nodeSchema.inputSchema;
|
|
||||||
let outputSchema: BlockIORootSchema = nodeSchema.outputSchema;
|
|
||||||
let finalHardcodedValues = hardcodedValues;
|
|
||||||
|
|
||||||
if (blockID === SpecialBlockID.AGENT) {
|
|
||||||
const graphID = hardcodedValues.graph_id as string;
|
|
||||||
const graphVersion = hardcodedValues.graph_version as number;
|
|
||||||
const graphData = okData(
|
|
||||||
await getV1GetSpecificGraph(graphID, { version: graphVersion }),
|
|
||||||
);
|
|
||||||
|
|
||||||
if (graphData) {
|
|
||||||
inputSchema = graphData.input_schema as BlockIORootSchema;
|
|
||||||
outputSchema = graphData.output_schema as BlockIORootSchema;
|
|
||||||
finalHardcodedValues = {
|
|
||||||
...hardcodedValues,
|
|
||||||
input_schema: graphData.input_schema,
|
|
||||||
output_schema: graphData.output_schema,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
console.error("Failed to fetch graph data for agent block");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const newNode: CustomNode = {
|
|
||||||
id: nodeId.toString(),
|
|
||||||
type: "custom",
|
|
||||||
position,
|
|
||||||
data: {
|
|
||||||
blockType: blockName,
|
|
||||||
blockCosts: nodeSchema.costs || [],
|
|
||||||
title: `${blockName} ${nodeId}`,
|
|
||||||
description: nodeSchema.description,
|
|
||||||
categories: nodeSchema.categories,
|
|
||||||
inputSchema: inputSchema,
|
|
||||||
outputSchema: outputSchema,
|
|
||||||
hardcodedValues: finalHardcodedValues,
|
|
||||||
connections: [],
|
|
||||||
isOutputOpen: false,
|
|
||||||
block_id: blockID,
|
|
||||||
isOutputStatic: nodeSchema.staticOutput,
|
|
||||||
uiType: nodeSchema.uiType,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
addNodes(newNode);
|
|
||||||
setNodeId((prevId) => prevId + 1);
|
|
||||||
clearNodesStatusAndOutput();
|
|
||||||
|
|
||||||
history.push({
|
|
||||||
type: "ADD_NODE",
|
|
||||||
payload: { node: { ...newNode, ...newNode.data } },
|
|
||||||
undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
|
|
||||||
redo: () => addNodes(newNode),
|
|
||||||
});
|
|
||||||
|
|
||||||
return newNode;
|
|
||||||
},
|
|
||||||
[
|
|
||||||
availableBlocks,
|
|
||||||
nodeId,
|
|
||||||
addNodes,
|
|
||||||
deleteElements,
|
|
||||||
clearNodesStatusAndOutput,
|
|
||||||
],
|
|
||||||
);
|
|
||||||
|
|
||||||
const addNode = useCallback(
|
const addNode = useCallback(
|
||||||
async (
|
(blockId: string, nodeType: string, hardcodedValues: any = {}) => {
|
||||||
blockId: string,
|
|
||||||
nodeType: string,
|
|
||||||
hardcodedValues: Record<string, any> = {},
|
|
||||||
) => {
|
|
||||||
const nodeSchema = availableBlocks.find((node) => node.id === blockId);
|
const nodeSchema = availableBlocks.find((node) => node.id === blockId);
|
||||||
if (!nodeSchema) {
|
if (!nodeSchema) {
|
||||||
console.error(`Schema not found for block ID: ${blockId}`);
|
console.error(`Schema not found for block ID: ${blockId}`);
|
||||||
@@ -797,42 +707,73 @@ const FlowEditor: React.FC<{
|
|||||||
// Alternative: We could also use D3 force, Intersection for this (React flow Pro examples)
|
// Alternative: We could also use D3 force, Intersection for this (React flow Pro examples)
|
||||||
|
|
||||||
const { x, y } = getViewport();
|
const { x, y } = getViewport();
|
||||||
const position =
|
const viewportCoordinates =
|
||||||
nodeDimensions && Object.keys(nodeDimensions).length > 0
|
nodeDimensions && Object.keys(nodeDimensions).length > 0
|
||||||
? findNewlyAddedBlockCoordinates(
|
? // we will get all the dimension of nodes, then store
|
||||||
|
findNewlyAddedBlockCoordinates(
|
||||||
nodeDimensions,
|
nodeDimensions,
|
||||||
nodeSchema.uiType == BlockUIType.NOTE ? 300 : 500,
|
nodeSchema.uiType == BlockUIType.NOTE ? 300 : 500,
|
||||||
60,
|
60,
|
||||||
1.0,
|
1.0,
|
||||||
)
|
)
|
||||||
: {
|
: // we will get all the dimension of nodes, then store
|
||||||
|
{
|
||||||
x: window.innerWidth / 2 - x,
|
x: window.innerWidth / 2 - x,
|
||||||
y: window.innerHeight / 2 - y,
|
y: window.innerHeight / 2 - y,
|
||||||
};
|
};
|
||||||
|
|
||||||
const newNode = await createAndAddNode(
|
const newNode: CustomNode = {
|
||||||
blockId,
|
id: nodeId.toString(),
|
||||||
nodeType,
|
type: "custom",
|
||||||
hardcodedValues,
|
position: viewportCoordinates, // Set the position to the calculated viewport center
|
||||||
position,
|
data: {
|
||||||
);
|
blockType: nodeType,
|
||||||
if (!newNode) return;
|
blockCosts: nodeSchema.costs,
|
||||||
|
title: `${nodeType} ${nodeId}`,
|
||||||
|
description: nodeSchema.description,
|
||||||
|
categories: nodeSchema.categories,
|
||||||
|
inputSchema: nodeSchema.inputSchema,
|
||||||
|
outputSchema: nodeSchema.outputSchema,
|
||||||
|
hardcodedValues: hardcodedValues,
|
||||||
|
connections: [],
|
||||||
|
isOutputOpen: false,
|
||||||
|
block_id: blockId,
|
||||||
|
isOutputStatic: nodeSchema.staticOutput,
|
||||||
|
uiType: nodeSchema.uiType,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
addNodes(newNode);
|
||||||
|
setNodeId((prevId) => prevId + 1);
|
||||||
|
clearNodesStatusAndOutput(); // Clear status and output when a new node is added
|
||||||
|
|
||||||
setViewport(
|
setViewport(
|
||||||
{
|
{
|
||||||
x: -position.x * 0.8 + (window.innerWidth - 0.0) / 2,
|
// Rough estimate of the dimension of the node is: 500x400px.
|
||||||
y: -position.y * 0.8 + (window.innerHeight - 400) / 2,
|
// Though we skip shifting the X, considering the block menu side-bar.
|
||||||
|
x: -viewportCoordinates.x * 0.8 + (window.innerWidth - 0.0) / 2,
|
||||||
|
y: -viewportCoordinates.y * 0.8 + (window.innerHeight - 400) / 2,
|
||||||
zoom: 0.8,
|
zoom: 0.8,
|
||||||
},
|
},
|
||||||
{ duration: 500 },
|
{ duration: 500 },
|
||||||
);
|
);
|
||||||
|
|
||||||
|
history.push({
|
||||||
|
type: "ADD_NODE",
|
||||||
|
payload: { node: { ...newNode, ...newNode.data } },
|
||||||
|
undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
|
||||||
|
redo: () => addNodes(newNode),
|
||||||
|
});
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
|
nodeId,
|
||||||
getViewport,
|
getViewport,
|
||||||
setViewport,
|
setViewport,
|
||||||
availableBlocks,
|
availableBlocks,
|
||||||
|
addNodes,
|
||||||
nodeDimensions,
|
nodeDimensions,
|
||||||
createAndAddNode,
|
deleteElements,
|
||||||
|
clearNodesStatusAndOutput,
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -979,7 +920,7 @@ const FlowEditor: React.FC<{
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const onDrop = useCallback(
|
const onDrop = useCallback(
|
||||||
async (event: React.DragEvent) => {
|
(event: React.DragEvent) => {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
const blockData = event.dataTransfer.getData("application/reactflow");
|
const blockData = event.dataTransfer.getData("application/reactflow");
|
||||||
@@ -994,17 +935,62 @@ const FlowEditor: React.FC<{
|
|||||||
y: event.clientY,
|
y: event.clientY,
|
||||||
});
|
});
|
||||||
|
|
||||||
await createAndAddNode(
|
// Find the block schema
|
||||||
blockId,
|
const nodeSchema = availableBlocks.find((node) => node.id === blockId);
|
||||||
blockName,
|
if (!nodeSchema) {
|
||||||
hardcodedValues || {},
|
console.error(`Schema not found for block ID: ${blockId}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the new node at the drop position
|
||||||
|
const newNode: CustomNode = {
|
||||||
|
id: nodeId.toString(),
|
||||||
|
type: "custom",
|
||||||
position,
|
position,
|
||||||
);
|
data: {
|
||||||
|
blockType: blockName,
|
||||||
|
blockCosts: nodeSchema.costs || [],
|
||||||
|
title: `${blockName} ${nodeId}`,
|
||||||
|
description: nodeSchema.description,
|
||||||
|
categories: nodeSchema.categories,
|
||||||
|
inputSchema: nodeSchema.inputSchema,
|
||||||
|
outputSchema: nodeSchema.outputSchema,
|
||||||
|
hardcodedValues: hardcodedValues,
|
||||||
|
connections: [],
|
||||||
|
isOutputOpen: false,
|
||||||
|
block_id: blockId,
|
||||||
|
uiType: nodeSchema.uiType,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
history.push({
|
||||||
|
type: "ADD_NODE",
|
||||||
|
payload: { node: { ...newNode, ...newNode.data } },
|
||||||
|
undo: () => {
|
||||||
|
deleteElements({ nodes: [{ id: newNode.id } as any], edges: [] });
|
||||||
|
},
|
||||||
|
redo: () => {
|
||||||
|
addNodes([newNode]);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
addNodes([newNode]);
|
||||||
|
clearNodesStatusAndOutput();
|
||||||
|
|
||||||
|
setNodeId((prevId) => prevId + 1);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Failed to drop block:", error);
|
console.error("Failed to drop block:", error);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[screenToFlowPosition, createAndAddNode],
|
[
|
||||||
|
nodeId,
|
||||||
|
availableBlocks,
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
addNodes,
|
||||||
|
screenToFlowPosition,
|
||||||
|
deleteElements,
|
||||||
|
clearNodesStatusAndOutput,
|
||||||
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
const buildContextValue: BuilderContextType = useMemo(
|
const buildContextValue: BuilderContextType = useMemo(
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ import { AgentRunDraftView } from "@/app/(platform)/library/agents/[id]/componen
|
|||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||||
import type {
|
import type {
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
Graph,
|
GraphMeta,
|
||||||
} from "@/lib/autogpt-server-api/types";
|
} from "@/lib/autogpt-server-api/types";
|
||||||
|
|
||||||
interface RunInputDialogProps {
|
interface RunInputDialogProps {
|
||||||
isOpen: boolean;
|
isOpen: boolean;
|
||||||
doClose: () => void;
|
doClose: () => void;
|
||||||
graph: Graph;
|
graph: GraphMeta;
|
||||||
doRun?: (
|
doRun?: (
|
||||||
inputs: Record<string, any>,
|
inputs: Record<string, any>,
|
||||||
credentialsInputs: Record<string, CredentialsMetaInput>,
|
credentialsInputs: Record<string, CredentialsMetaInput>,
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ import { CustomNodeData } from "@/app/(platform)/build/components/legacy-builder
|
|||||||
import {
|
import {
|
||||||
BlockUIType,
|
BlockUIType,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
Graph,
|
GraphMeta,
|
||||||
} from "@/lib/autogpt-server-api/types";
|
} from "@/lib/autogpt-server-api/types";
|
||||||
import RunnerOutputUI, { OutputNodeInfo } from "./RunnerOutputUI";
|
import RunnerOutputUI, { OutputNodeInfo } from "./RunnerOutputUI";
|
||||||
import { RunnerInputDialog } from "./RunnerInputUI";
|
import { RunnerInputDialog } from "./RunnerInputUI";
|
||||||
|
|
||||||
interface RunnerUIWrapperProps {
|
interface RunnerUIWrapperProps {
|
||||||
graph: Graph;
|
graph: GraphMeta;
|
||||||
nodes: Node<CustomNodeData>[];
|
nodes: Node<CustomNodeData>[];
|
||||||
graphExecutionError?: string | null;
|
graphExecutionError?: string | null;
|
||||||
saveAndRun: (
|
saveAndRun: (
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { GraphInputSchema } from "@/lib/autogpt-server-api";
|
import { GraphInputSchema } from "@/lib/autogpt-server-api";
|
||||||
import { GraphLike, IncompatibilityInfo } from "./types";
|
import { GraphMetaLike, IncompatibilityInfo } from "./types";
|
||||||
|
|
||||||
// Helper type for schema properties - the generated types are too loose
|
// Helper type for schema properties - the generated types are too loose
|
||||||
type SchemaProperties = Record<string, GraphInputSchema["properties"][string]>;
|
type SchemaProperties = Record<string, GraphInputSchema["properties"][string]>;
|
||||||
@@ -36,7 +36,7 @@ export function getSchemaRequired(schema: unknown): SchemaRequired {
|
|||||||
*/
|
*/
|
||||||
export function createUpdatedAgentNodeInputs(
|
export function createUpdatedAgentNodeInputs(
|
||||||
currentInputs: Record<string, unknown>,
|
currentInputs: Record<string, unknown>,
|
||||||
latestSubGraphVersion: GraphLike,
|
latestSubGraphVersion: GraphMetaLike,
|
||||||
): Record<string, unknown> {
|
): Record<string, unknown> {
|
||||||
return {
|
return {
|
||||||
...currentInputs,
|
...currentInputs,
|
||||||
|
|||||||
@@ -1,11 +1,7 @@
|
|||||||
import type {
|
import type { GraphMeta as LegacyGraphMeta } from "@/lib/autogpt-server-api";
|
||||||
Graph as LegacyGraph,
|
|
||||||
GraphMeta as LegacyGraphMeta,
|
|
||||||
} from "@/lib/autogpt-server-api";
|
|
||||||
import type { GraphModel as GeneratedGraph } from "@/app/api/__generated__/models/graphModel";
|
|
||||||
import type { GraphMeta as GeneratedGraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
import type { GraphMeta as GeneratedGraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
||||||
|
|
||||||
export type SubAgentUpdateInfo<T extends GraphLike = GraphLike> = {
|
export type SubAgentUpdateInfo<T extends GraphMetaLike = GraphMetaLike> = {
|
||||||
hasUpdate: boolean;
|
hasUpdate: boolean;
|
||||||
currentVersion: number;
|
currentVersion: number;
|
||||||
latestVersion: number;
|
latestVersion: number;
|
||||||
@@ -14,10 +10,7 @@ export type SubAgentUpdateInfo<T extends GraphLike = GraphLike> = {
|
|||||||
incompatibilities: IncompatibilityInfo | null;
|
incompatibilities: IncompatibilityInfo | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Union type for Graph (with schemas) that works with both legacy and new builder
|
// Union type for GraphMeta that works with both legacy and new builder
|
||||||
export type GraphLike = LegacyGraph | GeneratedGraph;
|
|
||||||
|
|
||||||
// Union type for GraphMeta (without schemas) for version detection
|
|
||||||
export type GraphMetaLike = LegacyGraphMeta | GeneratedGraphMeta;
|
export type GraphMetaLike = LegacyGraphMeta | GeneratedGraphMeta;
|
||||||
|
|
||||||
export type IncompatibilityInfo = {
|
export type IncompatibilityInfo = {
|
||||||
|
|||||||
@@ -1,11 +1,5 @@
|
|||||||
import { useMemo } from "react";
|
import { useMemo } from "react";
|
||||||
import type {
|
import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api";
|
||||||
GraphInputSchema,
|
|
||||||
GraphOutputSchema,
|
|
||||||
} from "@/lib/autogpt-server-api";
|
|
||||||
import type { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
|
||||||
import { useGetV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
|
||||||
import { okData } from "@/app/api/helpers";
|
|
||||||
import { getEffectiveType } from "@/lib/utils";
|
import { getEffectiveType } from "@/lib/utils";
|
||||||
import { EdgeLike, getSchemaProperties, getSchemaRequired } from "./helpers";
|
import { EdgeLike, getSchemaProperties, getSchemaRequired } from "./helpers";
|
||||||
import {
|
import {
|
||||||
@@ -17,38 +11,26 @@ import {
|
|||||||
/**
|
/**
|
||||||
* Checks if a newer version of a sub-agent is available and determines compatibility
|
* Checks if a newer version of a sub-agent is available and determines compatibility
|
||||||
*/
|
*/
|
||||||
export function useSubAgentUpdate(
|
export function useSubAgentUpdate<T extends GraphMetaLike>(
|
||||||
nodeID: string,
|
nodeID: string,
|
||||||
graphID: string | undefined,
|
graphID: string | undefined,
|
||||||
graphVersion: number | undefined,
|
graphVersion: number | undefined,
|
||||||
currentInputSchema: GraphInputSchema | undefined,
|
currentInputSchema: GraphInputSchema | undefined,
|
||||||
currentOutputSchema: GraphOutputSchema | undefined,
|
currentOutputSchema: GraphOutputSchema | undefined,
|
||||||
connections: EdgeLike[],
|
connections: EdgeLike[],
|
||||||
availableGraphs: GraphMetaLike[],
|
availableGraphs: T[],
|
||||||
): SubAgentUpdateInfo<GraphModel> {
|
): SubAgentUpdateInfo<T> {
|
||||||
// Find the latest version of the same graph
|
// Find the latest version of the same graph
|
||||||
const latestGraphInfo = useMemo(() => {
|
const latestGraph = useMemo(() => {
|
||||||
if (!graphID) return null;
|
if (!graphID) return null;
|
||||||
return availableGraphs.find((graph) => graph.id === graphID) || null;
|
return availableGraphs.find((graph) => graph.id === graphID) || null;
|
||||||
}, [graphID, availableGraphs]);
|
}, [graphID, availableGraphs]);
|
||||||
|
|
||||||
// Check if there's a newer version available
|
// Check if there's an update available
|
||||||
const hasUpdate = useMemo(() => {
|
const hasUpdate = useMemo(() => {
|
||||||
if (!latestGraphInfo || graphVersion === undefined) return false;
|
if (!latestGraph || graphVersion === undefined) return false;
|
||||||
return latestGraphInfo.version! > graphVersion;
|
return latestGraph.version! > graphVersion;
|
||||||
}, [latestGraphInfo, graphVersion]);
|
}, [latestGraph, graphVersion]);
|
||||||
|
|
||||||
// Fetch full graph IF an update is detected
|
|
||||||
const { data: latestGraph } = useGetV1GetSpecificGraph(
|
|
||||||
graphID ?? "",
|
|
||||||
{ version: latestGraphInfo?.version },
|
|
||||||
{
|
|
||||||
query: {
|
|
||||||
enabled: hasUpdate && !!graphID && !!latestGraphInfo?.version,
|
|
||||||
select: okData,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
// Get connected input and output handles for this specific node
|
// Get connected input and output handles for this specific node
|
||||||
const connectedHandles = useMemo(() => {
|
const connectedHandles = useMemo(() => {
|
||||||
@@ -170,8 +152,8 @@ export function useSubAgentUpdate(
|
|||||||
return {
|
return {
|
||||||
hasUpdate,
|
hasUpdate,
|
||||||
currentVersion: graphVersion || 0,
|
currentVersion: graphVersion || 0,
|
||||||
latestVersion: latestGraphInfo?.version || 0,
|
latestVersion: latestGraph?.version || 0,
|
||||||
latestGraph: latestGraph || null,
|
latestGraph,
|
||||||
isCompatible: compatibilityResult.isCompatible,
|
isCompatible: compatibilityResult.isCompatible,
|
||||||
incompatibilities: compatibilityResult.incompatibilities,
|
incompatibilities: compatibilityResult.incompatibilities,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ interface GraphStore {
|
|||||||
outputSchema: Record<string, any> | null,
|
outputSchema: Record<string, any> | null,
|
||||||
) => void;
|
) => void;
|
||||||
|
|
||||||
// Available graphs; used for sub-graph updated version detection
|
// Available graphs; used for sub-graph updates
|
||||||
availableSubGraphs: GraphMeta[];
|
availableSubGraphs: GraphMeta[];
|
||||||
setAvailableSubGraphs: (graphs: GraphMeta[]) => void;
|
setAvailableSubGraphs: (graphs: GraphMeta[]) => void;
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import React, {
|
|||||||
import {
|
import {
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
CredentialsType,
|
CredentialsType,
|
||||||
Graph,
|
|
||||||
GraphExecutionID,
|
GraphExecutionID,
|
||||||
|
GraphMeta,
|
||||||
LibraryAgentPreset,
|
LibraryAgentPreset,
|
||||||
LibraryAgentPresetID,
|
LibraryAgentPresetID,
|
||||||
LibraryAgentPresetUpdatable,
|
LibraryAgentPresetUpdatable,
|
||||||
@@ -69,7 +69,7 @@ export function AgentRunDraftView({
|
|||||||
className,
|
className,
|
||||||
recommendedScheduleCron,
|
recommendedScheduleCron,
|
||||||
}: {
|
}: {
|
||||||
graph: Graph;
|
graph: GraphMeta;
|
||||||
agentActions?: ButtonAction[];
|
agentActions?: ButtonAction[];
|
||||||
recommendedScheduleCron?: string | null;
|
recommendedScheduleCron?: string | null;
|
||||||
doRun?: (
|
doRun?: (
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
import React, { useCallback, useMemo } from "react";
|
import React, { useCallback, useMemo } from "react";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Graph,
|
|
||||||
GraphExecutionID,
|
GraphExecutionID,
|
||||||
|
GraphMeta,
|
||||||
Schedule,
|
Schedule,
|
||||||
ScheduleID,
|
ScheduleID,
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
@@ -35,7 +35,7 @@ export function AgentScheduleDetailsView({
|
|||||||
onForcedRun,
|
onForcedRun,
|
||||||
doDeleteSchedule,
|
doDeleteSchedule,
|
||||||
}: {
|
}: {
|
||||||
graph: Graph;
|
graph: GraphMeta;
|
||||||
schedule: Schedule;
|
schedule: Schedule;
|
||||||
agentActions: ButtonAction[];
|
agentActions: ButtonAction[];
|
||||||
onForcedRun: (runID: GraphExecutionID) => void;
|
onForcedRun: (runID: GraphExecutionID) => void;
|
||||||
|
|||||||
@@ -5629,9 +5629,7 @@
|
|||||||
"description": "Successful Response",
|
"description": "Successful Response",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": { "$ref": "#/components/schemas/GraphMeta" }
|
||||||
"$ref": "#/components/schemas/GraphModelWithoutNodes"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -6497,6 +6495,18 @@
|
|||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Recommended Schedule Cron"
|
"title": "Recommended Schedule Cron"
|
||||||
},
|
},
|
||||||
|
"nodes": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Node" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Nodes",
|
||||||
|
"default": []
|
||||||
|
},
|
||||||
|
"links": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Links",
|
||||||
|
"default": []
|
||||||
|
},
|
||||||
"forked_from_id": {
|
"forked_from_id": {
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Forked From Id"
|
"title": "Forked From Id"
|
||||||
@@ -6504,22 +6514,11 @@
|
|||||||
"forked_from_version": {
|
"forked_from_version": {
|
||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
},
|
|
||||||
"nodes": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Node" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Nodes"
|
|
||||||
},
|
|
||||||
"links": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Links"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["name", "description"],
|
"required": ["name", "description"],
|
||||||
"title": "BaseGraph",
|
"title": "BaseGraph"
|
||||||
"description": "Graph with nodes, links, and computed I/O schema fields.\n\nUsed to represent sub-graphs within a `Graph`. Contains the full graph\nstructure including nodes and links, plus computed fields for schemas\nand trigger info. Does NOT include user_id or created_at (see GraphModel)."
|
|
||||||
},
|
},
|
||||||
"BaseGraph-Output": {
|
"BaseGraph-Output": {
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -6540,6 +6539,18 @@
|
|||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Recommended Schedule Cron"
|
"title": "Recommended Schedule Cron"
|
||||||
},
|
},
|
||||||
|
"nodes": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Node" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Nodes",
|
||||||
|
"default": []
|
||||||
|
},
|
||||||
|
"links": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Links",
|
||||||
|
"default": []
|
||||||
|
},
|
||||||
"forked_from_id": {
|
"forked_from_id": {
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Forked From Id"
|
"title": "Forked From Id"
|
||||||
@@ -6548,16 +6559,6 @@
|
|||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
},
|
},
|
||||||
"nodes": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Node" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Nodes"
|
|
||||||
},
|
|
||||||
"links": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Links"
|
|
||||||
},
|
|
||||||
"input_schema": {
|
"input_schema": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -6604,8 +6605,7 @@
|
|||||||
"has_sensitive_action",
|
"has_sensitive_action",
|
||||||
"trigger_setup_info"
|
"trigger_setup_info"
|
||||||
],
|
],
|
||||||
"title": "BaseGraph",
|
"title": "BaseGraph"
|
||||||
"description": "Graph with nodes, links, and computed I/O schema fields.\n\nUsed to represent sub-graphs within a `Graph`. Contains the full graph\nstructure including nodes and links, plus computed fields for schemas\nand trigger info. Does NOT include user_id or created_at (see GraphModel)."
|
|
||||||
},
|
},
|
||||||
"BlockCategoryResponse": {
|
"BlockCategoryResponse": {
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -7399,6 +7399,18 @@
|
|||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Recommended Schedule Cron"
|
"title": "Recommended Schedule Cron"
|
||||||
},
|
},
|
||||||
|
"nodes": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Node" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Nodes",
|
||||||
|
"default": []
|
||||||
|
},
|
||||||
|
"links": {
|
||||||
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Links",
|
||||||
|
"default": []
|
||||||
|
},
|
||||||
"forked_from_id": {
|
"forked_from_id": {
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Forked From Id"
|
"title": "Forked From Id"
|
||||||
@@ -7407,26 +7419,16 @@
|
|||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
},
|
},
|
||||||
"nodes": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Node" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Nodes"
|
|
||||||
},
|
|
||||||
"links": {
|
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Links"
|
|
||||||
},
|
|
||||||
"sub_graphs": {
|
"sub_graphs": {
|
||||||
"items": { "$ref": "#/components/schemas/BaseGraph-Input" },
|
"items": { "$ref": "#/components/schemas/BaseGraph-Input" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Sub Graphs"
|
"title": "Sub Graphs",
|
||||||
|
"default": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["name", "description"],
|
"required": ["name", "description"],
|
||||||
"title": "Graph",
|
"title": "Graph"
|
||||||
"description": "Creatable graph model used in API create/update endpoints."
|
|
||||||
},
|
},
|
||||||
"GraphExecution": {
|
"GraphExecution": {
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -7778,7 +7780,7 @@
|
|||||||
"GraphMeta": {
|
"GraphMeta": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"id": { "type": "string", "title": "Id" },
|
"id": { "type": "string", "title": "Id" },
|
||||||
"version": { "type": "integer", "title": "Version" },
|
"version": { "type": "integer", "title": "Version", "default": 1 },
|
||||||
"is_active": {
|
"is_active": {
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"title": "Is Active",
|
"title": "Is Active",
|
||||||
@@ -7802,24 +7804,68 @@
|
|||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
},
|
},
|
||||||
|
"sub_graphs": {
|
||||||
|
"items": { "$ref": "#/components/schemas/BaseGraph-Output" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Sub Graphs",
|
||||||
|
"default": []
|
||||||
|
},
|
||||||
"user_id": { "type": "string", "title": "User Id" },
|
"user_id": { "type": "string", "title": "User Id" },
|
||||||
"created_at": {
|
"input_schema": {
|
||||||
"type": "string",
|
"additionalProperties": true,
|
||||||
"format": "date-time",
|
"type": "object",
|
||||||
"title": "Created At"
|
"title": "Input Schema",
|
||||||
|
"readOnly": true
|
||||||
|
},
|
||||||
|
"output_schema": {
|
||||||
|
"additionalProperties": true,
|
||||||
|
"type": "object",
|
||||||
|
"title": "Output Schema",
|
||||||
|
"readOnly": true
|
||||||
|
},
|
||||||
|
"has_external_trigger": {
|
||||||
|
"type": "boolean",
|
||||||
|
"title": "Has External Trigger",
|
||||||
|
"readOnly": true
|
||||||
|
},
|
||||||
|
"has_human_in_the_loop": {
|
||||||
|
"type": "boolean",
|
||||||
|
"title": "Has Human In The Loop",
|
||||||
|
"readOnly": true
|
||||||
|
},
|
||||||
|
"has_sensitive_action": {
|
||||||
|
"type": "boolean",
|
||||||
|
"title": "Has Sensitive Action",
|
||||||
|
"readOnly": true
|
||||||
|
},
|
||||||
|
"trigger_setup_info": {
|
||||||
|
"anyOf": [
|
||||||
|
{ "$ref": "#/components/schemas/GraphTriggerInfo" },
|
||||||
|
{ "type": "null" }
|
||||||
|
],
|
||||||
|
"readOnly": true
|
||||||
|
},
|
||||||
|
"credentials_input_schema": {
|
||||||
|
"additionalProperties": true,
|
||||||
|
"type": "object",
|
||||||
|
"title": "Credentials Input Schema",
|
||||||
|
"readOnly": true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
"id",
|
|
||||||
"version",
|
|
||||||
"name",
|
"name",
|
||||||
"description",
|
"description",
|
||||||
"user_id",
|
"user_id",
|
||||||
"created_at"
|
"input_schema",
|
||||||
|
"output_schema",
|
||||||
|
"has_external_trigger",
|
||||||
|
"has_human_in_the_loop",
|
||||||
|
"has_sensitive_action",
|
||||||
|
"trigger_setup_info",
|
||||||
|
"credentials_input_schema"
|
||||||
],
|
],
|
||||||
"title": "GraphMeta",
|
"title": "GraphMeta"
|
||||||
"description": "Lightweight graph metadata model representing an existing graph from the database,\nfor use in listings and summaries.\n\nLacks `GraphModel`'s nodes, links, and expensive computed fields.\nUse for list endpoints where full graph data is not needed and performance matters."
|
|
||||||
},
|
},
|
||||||
"GraphModel": {
|
"GraphModel": {
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -7840,111 +7886,17 @@
|
|||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Recommended Schedule Cron"
|
"title": "Recommended Schedule Cron"
|
||||||
},
|
},
|
||||||
"forked_from_id": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "Forked From Id"
|
|
||||||
},
|
|
||||||
"forked_from_version": {
|
|
||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
|
||||||
"title": "Forked From Version"
|
|
||||||
},
|
|
||||||
"user_id": { "type": "string", "title": "User Id" },
|
|
||||||
"created_at": {
|
|
||||||
"type": "string",
|
|
||||||
"format": "date-time",
|
|
||||||
"title": "Created At"
|
|
||||||
},
|
|
||||||
"nodes": {
|
"nodes": {
|
||||||
"items": { "$ref": "#/components/schemas/NodeModel" },
|
"items": { "$ref": "#/components/schemas/NodeModel" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Nodes"
|
"title": "Nodes",
|
||||||
|
"default": []
|
||||||
},
|
},
|
||||||
"links": {
|
"links": {
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Links"
|
"title": "Links",
|
||||||
},
|
"default": []
|
||||||
"sub_graphs": {
|
|
||||||
"items": { "$ref": "#/components/schemas/BaseGraph-Output" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Sub Graphs"
|
|
||||||
},
|
|
||||||
"input_schema": {
|
|
||||||
"additionalProperties": true,
|
|
||||||
"type": "object",
|
|
||||||
"title": "Input Schema",
|
|
||||||
"readOnly": true
|
|
||||||
},
|
|
||||||
"output_schema": {
|
|
||||||
"additionalProperties": true,
|
|
||||||
"type": "object",
|
|
||||||
"title": "Output Schema",
|
|
||||||
"readOnly": true
|
|
||||||
},
|
|
||||||
"has_external_trigger": {
|
|
||||||
"type": "boolean",
|
|
||||||
"title": "Has External Trigger",
|
|
||||||
"readOnly": true
|
|
||||||
},
|
|
||||||
"has_human_in_the_loop": {
|
|
||||||
"type": "boolean",
|
|
||||||
"title": "Has Human In The Loop",
|
|
||||||
"readOnly": true
|
|
||||||
},
|
|
||||||
"has_sensitive_action": {
|
|
||||||
"type": "boolean",
|
|
||||||
"title": "Has Sensitive Action",
|
|
||||||
"readOnly": true
|
|
||||||
},
|
|
||||||
"trigger_setup_info": {
|
|
||||||
"anyOf": [
|
|
||||||
{ "$ref": "#/components/schemas/GraphTriggerInfo" },
|
|
||||||
{ "type": "null" }
|
|
||||||
],
|
|
||||||
"readOnly": true
|
|
||||||
},
|
|
||||||
"credentials_input_schema": {
|
|
||||||
"additionalProperties": true,
|
|
||||||
"type": "object",
|
|
||||||
"title": "Credentials Input Schema",
|
|
||||||
"readOnly": true
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": [
|
|
||||||
"name",
|
|
||||||
"description",
|
|
||||||
"user_id",
|
|
||||||
"created_at",
|
|
||||||
"input_schema",
|
|
||||||
"output_schema",
|
|
||||||
"has_external_trigger",
|
|
||||||
"has_human_in_the_loop",
|
|
||||||
"has_sensitive_action",
|
|
||||||
"trigger_setup_info",
|
|
||||||
"credentials_input_schema"
|
|
||||||
],
|
|
||||||
"title": "GraphModel",
|
|
||||||
"description": "Full graph model representing an existing graph from the database.\n\nThis is the primary model for working with persisted graphs. Includes all\ngraph data (nodes, links, sub_graphs) plus user ownership and timestamps.\nProvides computed fields (input_schema, output_schema, etc.) used during\nset-up (frontend) and execution (backend).\n\nInherits from:\n- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas\n- `GraphMeta`: provides user_id, created_at for database records"
|
|
||||||
},
|
|
||||||
"GraphModelWithoutNodes": {
|
|
||||||
"properties": {
|
|
||||||
"id": { "type": "string", "title": "Id" },
|
|
||||||
"version": { "type": "integer", "title": "Version", "default": 1 },
|
|
||||||
"is_active": {
|
|
||||||
"type": "boolean",
|
|
||||||
"title": "Is Active",
|
|
||||||
"default": true
|
|
||||||
},
|
|
||||||
"name": { "type": "string", "title": "Name" },
|
|
||||||
"description": { "type": "string", "title": "Description" },
|
|
||||||
"instructions": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "Instructions"
|
|
||||||
},
|
|
||||||
"recommended_schedule_cron": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "Recommended Schedule Cron"
|
|
||||||
},
|
},
|
||||||
"forked_from_id": {
|
"forked_from_id": {
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
@@ -7954,6 +7906,12 @@
|
|||||||
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
"anyOf": [{ "type": "integer" }, { "type": "null" }],
|
||||||
"title": "Forked From Version"
|
"title": "Forked From Version"
|
||||||
},
|
},
|
||||||
|
"sub_graphs": {
|
||||||
|
"items": { "$ref": "#/components/schemas/BaseGraph-Output" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Sub Graphs",
|
||||||
|
"default": []
|
||||||
|
},
|
||||||
"user_id": { "type": "string", "title": "User Id" },
|
"user_id": { "type": "string", "title": "User Id" },
|
||||||
"created_at": {
|
"created_at": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -8015,8 +7973,7 @@
|
|||||||
"trigger_setup_info",
|
"trigger_setup_info",
|
||||||
"credentials_input_schema"
|
"credentials_input_schema"
|
||||||
],
|
],
|
||||||
"title": "GraphModelWithoutNodes",
|
"title": "GraphModel"
|
||||||
"description": "GraphModel variant that excludes nodes, links, and sub-graphs from serialization.\n\nUsed in contexts like the store where exposing internal graph structure\nis not desired. Inherits all computed fields from GraphModel but marks\nnodes and links as excluded from JSON output."
|
|
||||||
},
|
},
|
||||||
"GraphSettings": {
|
"GraphSettings": {
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -8656,22 +8613,26 @@
|
|||||||
"input_default": {
|
"input_default": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"title": "Input Default"
|
"title": "Input Default",
|
||||||
|
"default": {}
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"title": "Metadata"
|
"title": "Metadata",
|
||||||
|
"default": {}
|
||||||
},
|
},
|
||||||
"input_links": {
|
"input_links": {
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Input Links"
|
"title": "Input Links",
|
||||||
|
"default": []
|
||||||
},
|
},
|
||||||
"output_links": {
|
"output_links": {
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Output Links"
|
"title": "Output Links",
|
||||||
|
"default": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -8751,22 +8712,26 @@
|
|||||||
"input_default": {
|
"input_default": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"title": "Input Default"
|
"title": "Input Default",
|
||||||
|
"default": {}
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"additionalProperties": true,
|
"additionalProperties": true,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"title": "Metadata"
|
"title": "Metadata",
|
||||||
|
"default": {}
|
||||||
},
|
},
|
||||||
"input_links": {
|
"input_links": {
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Input Links"
|
"title": "Input Links",
|
||||||
|
"default": []
|
||||||
},
|
},
|
||||||
"output_links": {
|
"output_links": {
|
||||||
"items": { "$ref": "#/components/schemas/Link" },
|
"items": { "$ref": "#/components/schemas/Link" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Output Links"
|
"title": "Output Links",
|
||||||
|
"default": []
|
||||||
},
|
},
|
||||||
"graph_id": { "type": "string", "title": "Graph Id" },
|
"graph_id": { "type": "string", "title": "Graph Id" },
|
||||||
"graph_version": { "type": "integer", "title": "Graph Version" },
|
"graph_version": { "type": "integer", "title": "Graph Version" },
|
||||||
@@ -12307,9 +12272,7 @@
|
|||||||
"title": "Location"
|
"title": "Location"
|
||||||
},
|
},
|
||||||
"msg": { "type": "string", "title": "Message" },
|
"msg": { "type": "string", "title": "Message" },
|
||||||
"type": { "type": "string", "title": "Error Type" },
|
"type": { "type": "string", "title": "Error Type" }
|
||||||
"input": { "title": "Input" },
|
|
||||||
"ctx": { "type": "object", "title": "Context" }
|
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["loc", "msg", "type"],
|
"required": ["loc", "msg", "type"],
|
||||||
|
|||||||
@@ -102,6 +102,18 @@ export function ChatMessage({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function handleClarificationAnswers(answers: Record<string, string>) {
|
||||||
|
if (onSendMessage) {
|
||||||
|
const contextMessage = Object.entries(answers)
|
||||||
|
.map(([keyword, answer]) => `${keyword}: ${answer}`)
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
onSendMessage(
|
||||||
|
`I have the answers to your questions:\n\n${contextMessage}\n\nPlease proceed with creating the agent.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const handleCopy = useCallback(
|
const handleCopy = useCallback(
|
||||||
async function handleCopy() {
|
async function handleCopy() {
|
||||||
if (message.type !== "message") return;
|
if (message.type !== "message") return;
|
||||||
@@ -150,22 +162,6 @@ export function ChatMessage({
|
|||||||
.slice(index + 1)
|
.slice(index + 1)
|
||||||
.some((m) => m.type === "message" && m.role === "user");
|
.some((m) => m.type === "message" && m.role === "user");
|
||||||
|
|
||||||
const handleClarificationAnswers = (answers: Record<string, string>) => {
|
|
||||||
if (onSendMessage) {
|
|
||||||
// Iterate over questions (preserves original order) instead of answers
|
|
||||||
const contextMessage = message.questions
|
|
||||||
.map((q) => {
|
|
||||||
const answer = answers[q.keyword] || "";
|
|
||||||
return `> ${q.question}\n\n${answer}`;
|
|
||||||
})
|
|
||||||
.join("\n\n");
|
|
||||||
|
|
||||||
onSendMessage(
|
|
||||||
`**Here are my answers:**\n\n${contextMessage}\n\nPlease proceed with creating the agent.`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ClarificationQuestionsWidget
|
<ClarificationQuestionsWidget
|
||||||
questions={message.questions}
|
questions={message.questions}
|
||||||
|
|||||||
@@ -19,13 +19,13 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
|||||||
if (timerRef.current === null) {
|
if (timerRef.current === null) {
|
||||||
timerRef.current = setTimeout(() => {
|
timerRef.current = setTimeout(() => {
|
||||||
setShowSlowLoader(true);
|
setShowSlowLoader(true);
|
||||||
}, 8000);
|
}, 3000);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (coffeeTimerRef.current === null) {
|
if (coffeeTimerRef.current === null) {
|
||||||
coffeeTimerRef.current = setTimeout(() => {
|
coffeeTimerRef.current = setTimeout(() => {
|
||||||
setShowCoffeeMessage(true);
|
setShowCoffeeMessage(true);
|
||||||
}, 10000);
|
}, 8000);
|
||||||
}
|
}
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
|
|||||||
@@ -362,14 +362,25 @@ export type GraphMeta = {
|
|||||||
user_id: UserID;
|
user_id: UserID;
|
||||||
version: number;
|
version: number;
|
||||||
is_active: boolean;
|
is_active: boolean;
|
||||||
created_at: Date;
|
|
||||||
name: string;
|
name: string;
|
||||||
description: string;
|
description: string;
|
||||||
instructions?: string | null;
|
instructions?: string | null;
|
||||||
recommended_schedule_cron: string | null;
|
recommended_schedule_cron: string | null;
|
||||||
forked_from_id?: GraphID | null;
|
forked_from_id?: GraphID | null;
|
||||||
forked_from_version?: number | null;
|
forked_from_version?: number | null;
|
||||||
};
|
input_schema: GraphInputSchema;
|
||||||
|
output_schema: GraphOutputSchema;
|
||||||
|
credentials_input_schema: CredentialsInputSchema;
|
||||||
|
} & (
|
||||||
|
| {
|
||||||
|
has_external_trigger: true;
|
||||||
|
trigger_setup_info: GraphTriggerInfo;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
has_external_trigger: false;
|
||||||
|
trigger_setup_info: null;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
export type GraphID = Brand<string, "GraphID">;
|
export type GraphID = Brand<string, "GraphID">;
|
||||||
|
|
||||||
@@ -436,22 +447,11 @@ export type GraphTriggerInfo = {
|
|||||||
|
|
||||||
/* Mirror of backend/data/graph.py:Graph */
|
/* Mirror of backend/data/graph.py:Graph */
|
||||||
export type Graph = GraphMeta & {
|
export type Graph = GraphMeta & {
|
||||||
|
created_at: Date;
|
||||||
nodes: Node[];
|
nodes: Node[];
|
||||||
links: Link[];
|
links: Link[];
|
||||||
sub_graphs: Omit<Graph, "sub_graphs">[]; // Flattened sub-graphs
|
sub_graphs: Omit<Graph, "sub_graphs">[]; // Flattened sub-graphs
|
||||||
input_schema: GraphInputSchema;
|
};
|
||||||
output_schema: GraphOutputSchema;
|
|
||||||
credentials_input_schema: CredentialsInputSchema;
|
|
||||||
} & (
|
|
||||||
| {
|
|
||||||
has_external_trigger: true;
|
|
||||||
trigger_setup_info: GraphTriggerInfo;
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
has_external_trigger: false;
|
|
||||||
trigger_setup_info: null;
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
export type GraphUpdateable = Omit<
|
export type GraphUpdateable = Omit<
|
||||||
Graph,
|
Graph,
|
||||||
|
|||||||
Reference in New Issue
Block a user