mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-28 08:28:00 -05:00
Compare commits
96 Commits
testing-cl
...
user-works
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef5120ad45 | ||
|
|
788713dfb5 | ||
|
|
704e2959ba | ||
|
|
cc95f1dbd6 | ||
|
|
7357c26c00 | ||
|
|
5d82a42b49 | ||
|
|
d2fca0adbd | ||
|
|
0acf868b18 | ||
|
|
d81e97b881 | ||
|
|
ef3dfb8af4 | ||
|
|
eb22180e6d | ||
|
|
f3d8d953f5 | ||
|
|
2b0afc348e | ||
|
|
e3a389ba00 | ||
|
|
f7c59b00d8 | ||
|
|
1feed23475 | ||
|
|
3a295b3192 | ||
|
|
4872eb3ccd | ||
|
|
f5d7b3f618 | ||
|
|
de9ef2366e | ||
|
|
0953983944 | ||
|
|
c4d83505c0 | ||
|
|
28e4da5b13 | ||
|
|
72b1542a43 | ||
|
|
5bbc3d55f0 | ||
|
|
87814bcdcb | ||
|
|
85c229dd6c | ||
|
|
270586751b | ||
|
|
fa1afd6a6d | ||
|
|
af2bcd900a | ||
|
|
8f7204484d | ||
|
|
5af4c60b8e | ||
|
|
49b67ccd94 | ||
|
|
efb2e2792d | ||
|
|
d51d811497 | ||
|
|
83f93d00f4 | ||
|
|
c132b6dfa5 | ||
|
|
7eb7b7186f | ||
|
|
4b58eac877 | ||
|
|
bae6be915f | ||
|
|
8f16d583a4 | ||
|
|
0b8c671a27 | ||
|
|
cb074b0076 | ||
|
|
f29dd34f51 | ||
|
|
581dc337f2 | ||
|
|
f8b041fd63 | ||
|
|
56248ae7b7 | ||
|
|
bec0157f9e | ||
|
|
57f44e166a | ||
|
|
2c678f2658 | ||
|
|
669e33d709 | ||
|
|
953e7a5afb | ||
|
|
e9c55ed5a3 | ||
|
|
0058cd3ba6 | ||
|
|
ce3b8fa8d8 | ||
|
|
ce67b7eca4 | ||
|
|
0e34c7e5c4 | ||
|
|
5c5dd160dd | ||
|
|
759248b7fe | ||
|
|
ca5758cce6 | ||
|
|
d40df5a8c8 | ||
|
|
0db228ed43 | ||
|
|
590f434d0a | ||
|
|
8f171a0537 | ||
|
|
c814a43465 | ||
|
|
5923041fe8 | ||
|
|
936a2d70db | ||
|
|
80c54b7f46 | ||
|
|
28caf01ca7 | ||
|
|
ea035224bc | ||
|
|
62813a1ea6 | ||
|
|
67405f7eb9 | ||
|
|
171ff6e776 | ||
|
|
349b1f9c79 | ||
|
|
277b0537e9 | ||
|
|
071b3bb5cd | ||
|
|
2134d777be | ||
|
|
962824c8af | ||
|
|
3e9d5d0d50 | ||
|
|
fac10c422b | ||
|
|
91c7896859 | ||
|
|
bab436231a | ||
|
|
859f3f8c06 | ||
|
|
d5c0f5b2df | ||
|
|
fbc2da36e6 | ||
|
|
75ecc4de92 | ||
|
|
f0c2503608 | ||
|
|
cfb7dc5aca | ||
|
|
9a6e17ff52 | ||
|
|
fb58827c61 | ||
|
|
595f3508c1 | ||
|
|
7892590b12 | ||
|
|
82d7134fc6 | ||
|
|
90466908a8 | ||
|
|
f9f984a8f4 | ||
|
|
fc87ed4e34 |
38
.github/workflows/platform-frontend-ci.yml
vendored
38
.github/workflows/platform-frontend-ci.yml
vendored
@@ -128,7 +128,7 @@ jobs:
|
|||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
exitOnceUploaded: true
|
exitOnceUploaded: true
|
||||||
|
|
||||||
test:
|
e2e_test:
|
||||||
runs-on: big-boi
|
runs-on: big-boi
|
||||||
needs: setup
|
needs: setup
|
||||||
strategy:
|
strategy:
|
||||||
@@ -258,3 +258,39 @@ jobs:
|
|||||||
- name: Print Final Docker Compose logs
|
- name: Print Final Docker Compose logs
|
||||||
if: always()
|
if: always()
|
||||||
run: docker compose -f ../docker-compose.yml logs
|
run: docker compose -f ../docker-compose.yml logs
|
||||||
|
|
||||||
|
integration_test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: setup
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Set up Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: "22.18.0"
|
||||||
|
|
||||||
|
- name: Enable corepack
|
||||||
|
run: corepack enable
|
||||||
|
|
||||||
|
- name: Restore dependencies cache
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.pnpm-store
|
||||||
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||||
|
${{ runner.os }}-pnpm-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
|
- name: Generate API client
|
||||||
|
run: pnpm generate:api
|
||||||
|
|
||||||
|
- name: Run Integration Tests
|
||||||
|
run: pnpm test:unit
|
||||||
|
|||||||
@@ -194,6 +194,50 @@ ex: do the inputs and outputs tie well together?
|
|||||||
|
|
||||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||||
|
|
||||||
|
**Handling files in blocks with `store_media_file()`:**
|
||||||
|
|
||||||
|
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||||
|
|
||||||
|
| Format | Use When | Returns |
|
||||||
|
|--------|----------|---------|
|
||||||
|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||||
|
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||||
|
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
```python
|
||||||
|
# INPUT: Need to process file locally with ffmpeg
|
||||||
|
local_path = await store_media_file(
|
||||||
|
file=input_data.video,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||||
|
|
||||||
|
# INPUT: Need to send to external API like Replicate
|
||||||
|
image_b64 = await store_media_file(
|
||||||
|
file=input_data.image,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_external_api",
|
||||||
|
)
|
||||||
|
# image_b64 = "..." - send to API
|
||||||
|
|
||||||
|
# OUTPUT: Returning result from block
|
||||||
|
result_url = await store_media_file(
|
||||||
|
file=generated_image_url,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", result_url
|
||||||
|
# In CoPilot: result_url = "workspace://abc123"
|
||||||
|
# In graphs: result_url = "data:image/png;base64,..."
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key points:**
|
||||||
|
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||||
|
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||||
|
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||||
|
|
||||||
**Modifying the API:**
|
**Modifying the API:**
|
||||||
|
|
||||||
1. Update route in `/backend/backend/server/routers/`
|
1. Update route in `/backend/backend/server/routers/`
|
||||||
|
|||||||
@@ -178,5 +178,10 @@ AYRSHARE_JWT_KEY=
|
|||||||
SMARTLEAD_API_KEY=
|
SMARTLEAD_API_KEY=
|
||||||
ZEROBOUNCE_API_KEY=
|
ZEROBOUNCE_API_KEY=
|
||||||
|
|
||||||
|
# PostHog Analytics
|
||||||
|
# Get API key from https://posthog.com - Project Settings > Project API Key
|
||||||
|
POSTHOG_API_KEY=
|
||||||
|
POSTHOG_HOST=https://eu.i.posthog.com
|
||||||
|
|
||||||
# Other Services
|
# Other Services
|
||||||
AUTOMOD_API_KEY=
|
AUTOMOD_API_KEY=
|
||||||
|
|||||||
@@ -86,6 +86,8 @@ async def execute_graph_block(
|
|||||||
obj = backend.data.block.get_block(block_id)
|
obj = backend.data.block.get_block(block_id)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||||
|
if obj.disabled:
|
||||||
|
raise HTTPException(status_code=403, detail=f"Block #{block_id} is disabled.")
|
||||||
|
|
||||||
output = defaultdict(list)
|
output = defaultdict(list)
|
||||||
async for name, data in obj.execute(data):
|
async for name, data in obj.execute(data):
|
||||||
|
|||||||
@@ -33,9 +33,15 @@ class ChatConfig(BaseSettings):
|
|||||||
|
|
||||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||||
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
|
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||||
max_agent_schedules: int = Field(
|
max_agent_schedules: int = Field(
|
||||||
default=3, description="Maximum number of agent schedules"
|
default=30, description="Maximum number of agent schedules"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Long-running operation configuration
|
||||||
|
long_running_operation_ttl: int = Field(
|
||||||
|
default=600,
|
||||||
|
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Langfuse Prompt Management Configuration
|
# Langfuse Prompt Management Configuration
|
||||||
|
|||||||
@@ -247,3 +247,45 @@ async def get_chat_session_message_count(session_id: str) -> int:
|
|||||||
"""Get the number of messages in a chat session."""
|
"""Get the number of messages in a chat session."""
|
||||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
async def update_tool_message_content(
|
||||||
|
session_id: str,
|
||||||
|
tool_call_id: str,
|
||||||
|
new_content: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Update the content of a tool message in chat history.
|
||||||
|
|
||||||
|
Used by background tasks to update pending operation messages with final results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The chat session ID.
|
||||||
|
tool_call_id: The tool call ID to find the message.
|
||||||
|
new_content: The new content to set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a message was updated, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await PrismaChatMessage.prisma().update_many(
|
||||||
|
where={
|
||||||
|
"sessionId": session_id,
|
||||||
|
"toolCallId": tool_call_id,
|
||||||
|
},
|
||||||
|
data={
|
||||||
|
"content": new_content,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if result == 0:
|
||||||
|
logger.warning(
|
||||||
|
f"No message found to update for session {session_id}, "
|
||||||
|
f"tool_call_id {tool_call_id}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to update tool message for session {session_id}, "
|
||||||
|
f"tool_call_id {tool_call_id}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|||||||
@@ -295,6 +295,21 @@ async def cache_chat_session(session: ChatSession) -> None:
|
|||||||
await _cache_session(session)
|
await _cache_session(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def invalidate_session_cache(session_id: str) -> None:
|
||||||
|
"""Invalidate a chat session from Redis cache.
|
||||||
|
|
||||||
|
Used by background tasks to ensure fresh data is loaded on next access.
|
||||||
|
This is best-effort - Redis failures are logged but don't fail the operation.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
redis_key = _get_session_cache_key(session_id)
|
||||||
|
async_redis = await get_redis_async()
|
||||||
|
await async_redis.delete(redis_key)
|
||||||
|
except Exception as e:
|
||||||
|
# Best-effort: log but don't fail - cache will expire naturally
|
||||||
|
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||||
"""Get a chat session from the database."""
|
"""Get a chat session from the database."""
|
||||||
prisma_session = await chat_db.get_chat_session(session_id)
|
prisma_session = await chat_db.get_chat_session(session_id)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class ResponseType(str, Enum):
|
|||||||
# Other
|
# Other
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
USAGE = "usage"
|
USAGE = "usage"
|
||||||
|
HEARTBEAT = "heartbeat"
|
||||||
|
|
||||||
|
|
||||||
class StreamBaseResponse(BaseModel):
|
class StreamBaseResponse(BaseModel):
|
||||||
@@ -142,3 +143,20 @@ class StreamError(StreamBaseResponse):
|
|||||||
details: dict[str, Any] | None = Field(
|
details: dict[str, Any] | None = Field(
|
||||||
default=None, description="Additional error details"
|
default=None, description="Additional error details"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamHeartbeat(StreamBaseResponse):
|
||||||
|
"""Heartbeat to keep SSE connection alive during long-running operations.
|
||||||
|
|
||||||
|
Uses SSE comment format (: comment) which is ignored by clients but keeps
|
||||||
|
the connection alive through proxies and load balancers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.HEARTBEAT
|
||||||
|
toolCallId: str | None = Field(
|
||||||
|
default=None, description="Tool call ID if heartbeat is for a specific tool"
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE comment format to keep connection alive."""
|
||||||
|
return ": heartbeat\n\n"
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,10 @@
|
|||||||
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tracking import track_tool_called
|
||||||
|
|
||||||
from .add_understanding import AddUnderstandingTool
|
from .add_understanding import AddUnderstandingTool
|
||||||
from .agent_output import AgentOutputTool
|
from .agent_output import AgentOutputTool
|
||||||
@@ -16,10 +18,18 @@ from .get_doc_page import GetDocPageTool
|
|||||||
from .run_agent import RunAgentTool
|
from .run_agent import RunAgentTool
|
||||||
from .run_block import RunBlockTool
|
from .run_block import RunBlockTool
|
||||||
from .search_docs import SearchDocsTool
|
from .search_docs import SearchDocsTool
|
||||||
|
from .workspace_tools import (
|
||||||
|
DeleteWorkspaceFileTool,
|
||||||
|
ListWorkspaceFilesTool,
|
||||||
|
ReadWorkspaceFileTool,
|
||||||
|
WriteWorkspaceFileTool,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Single source of truth for all tools
|
# Single source of truth for all tools
|
||||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||||
"add_understanding": AddUnderstandingTool(),
|
"add_understanding": AddUnderstandingTool(),
|
||||||
@@ -33,6 +43,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
|||||||
"view_agent_output": AgentOutputTool(),
|
"view_agent_output": AgentOutputTool(),
|
||||||
"search_docs": SearchDocsTool(),
|
"search_docs": SearchDocsTool(),
|
||||||
"get_doc_page": GetDocPageTool(),
|
"get_doc_page": GetDocPageTool(),
|
||||||
|
# Workspace tools for CoPilot file operations
|
||||||
|
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||||
|
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||||
|
"write_workspace_file": WriteWorkspaceFileTool(),
|
||||||
|
"delete_workspace_file": DeleteWorkspaceFileTool(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Export individual tool instances for backwards compatibility
|
# Export individual tool instances for backwards compatibility
|
||||||
@@ -45,6 +60,11 @@ tools: list[ChatCompletionToolParam] = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool(tool_name: str) -> BaseTool | None:
|
||||||
|
"""Get a tool instance by name."""
|
||||||
|
return TOOL_REGISTRY.get(tool_name)
|
||||||
|
|
||||||
|
|
||||||
async def execute_tool(
|
async def execute_tool(
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
parameters: dict[str, Any],
|
parameters: dict[str, Any],
|
||||||
@@ -53,7 +73,20 @@ async def execute_tool(
|
|||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
) -> "StreamToolOutputAvailable":
|
) -> "StreamToolOutputAvailable":
|
||||||
"""Execute a tool by name."""
|
"""Execute a tool by name."""
|
||||||
tool = TOOL_REGISTRY.get(tool_name)
|
tool = get_tool(tool_name)
|
||||||
if not tool:
|
if not tool:
|
||||||
raise ValueError(f"Tool {tool_name} not found")
|
raise ValueError(f"Tool {tool_name} not found")
|
||||||
|
|
||||||
|
# Track tool call in PostHog
|
||||||
|
logger.info(
|
||||||
|
f"Tracking tool call: tool={tool_name}, user={user_id}, "
|
||||||
|
f"session={session.session_id}, call_id={tool_call_id}"
|
||||||
|
)
|
||||||
|
track_tool_called(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session.session_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
)
|
||||||
|
|
||||||
return await tool.execute(user_id, session, tool_call_id, **parameters)
|
return await tool.execute(user_id, session, tool_call_id, **parameters)
|
||||||
|
|||||||
@@ -3,8 +3,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.understanding import (
|
from backend.data.understanding import (
|
||||||
BusinessUnderstandingInput,
|
BusinessUnderstandingInput,
|
||||||
@@ -61,7 +59,6 @@ and automations for the user's specific needs."""
|
|||||||
"""Requires authentication to store user-specific data."""
|
"""Requires authentication to store user-specific data."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@observe(as_type="tool", name="add_understanding")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
|
|||||||
@@ -1,29 +1,28 @@
|
|||||||
"""Agent generator package - Creates agents from natural language."""
|
"""Agent generator package - Creates agents from natural language."""
|
||||||
|
|
||||||
from .core import (
|
from .core import (
|
||||||
apply_agent_patch,
|
AgentGeneratorNotConfiguredError,
|
||||||
decompose_goal,
|
decompose_goal,
|
||||||
generate_agent,
|
generate_agent,
|
||||||
generate_agent_patch,
|
generate_agent_patch,
|
||||||
get_agent_as_json,
|
get_agent_as_json,
|
||||||
|
json_to_graph,
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
)
|
)
|
||||||
from .fixer import apply_all_fixes
|
from .service import health_check as check_external_service_health
|
||||||
from .utils import get_blocks_info
|
from .service import is_external_service_configured
|
||||||
from .validator import validate_agent
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Core functions
|
# Core functions
|
||||||
"decompose_goal",
|
"decompose_goal",
|
||||||
"generate_agent",
|
"generate_agent",
|
||||||
"generate_agent_patch",
|
"generate_agent_patch",
|
||||||
"apply_agent_patch",
|
|
||||||
"save_agent_to_library",
|
"save_agent_to_library",
|
||||||
"get_agent_as_json",
|
"get_agent_as_json",
|
||||||
# Fixer
|
"json_to_graph",
|
||||||
"apply_all_fixes",
|
# Exceptions
|
||||||
# Validator
|
"AgentGeneratorNotConfiguredError",
|
||||||
"validate_agent",
|
# Service
|
||||||
# Utils
|
"is_external_service_configured",
|
||||||
"get_blocks_info",
|
"check_external_service_health",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
"""OpenRouter client configuration for agent generation."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
|
|
||||||
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
|
|
||||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
|
|
||||||
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
|
|
||||||
|
|
||||||
# OpenRouter client (OpenAI-compatible API)
|
|
||||||
_client: AsyncOpenAI | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_client() -> AsyncOpenAI:
|
|
||||||
"""Get or create the OpenRouter client."""
|
|
||||||
global _client
|
|
||||||
if _client is None:
|
|
||||||
if not OPENROUTER_API_KEY:
|
|
||||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
|
||||||
_client = AsyncOpenAI(
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
api_key=OPENROUTER_API_KEY,
|
|
||||||
)
|
|
||||||
return _client
|
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
"""Core agent generation functions."""
|
"""Core agent generation functions."""
|
||||||
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -9,13 +7,35 @@ from typing import Any
|
|||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.data.graph import Graph, Link, Node, create_graph
|
from backend.data.graph import Graph, Link, Node, create_graph
|
||||||
|
|
||||||
from .client import AGENT_GENERATOR_MODEL, get_client
|
from .service import (
|
||||||
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
|
decompose_goal_external,
|
||||||
from .utils import get_block_summaries, parse_json_from_llm
|
generate_agent_external,
|
||||||
|
generate_agent_patch_external,
|
||||||
|
is_external_service_configured,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentGeneratorNotConfiguredError(Exception):
|
||||||
|
"""Raised when the external Agent Generator service is not configured."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _check_service_configured() -> None:
|
||||||
|
"""Check if the external Agent Generator service is configured.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the service is not configured.
|
||||||
|
"""
|
||||||
|
if not is_external_service_configured():
|
||||||
|
raise AgentGeneratorNotConfiguredError(
|
||||||
|
"Agent Generator service is not configured. "
|
||||||
|
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||||
"""Break down a goal into steps or return clarifying questions.
|
"""Break down a goal into steps or return clarifying questions.
|
||||||
|
|
||||||
@@ -28,40 +48,13 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
|
|||||||
- {"type": "clarifying_questions", "questions": [...]}
|
- {"type": "clarifying_questions", "questions": [...]}
|
||||||
- {"type": "instructions", "steps": [...]}
|
- {"type": "instructions", "steps": [...]}
|
||||||
Or None on error
|
Or None on error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
client = get_client()
|
_check_service_configured()
|
||||||
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
|
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||||
|
return await decompose_goal_external(description, context)
|
||||||
full_description = description
|
|
||||||
if context:
|
|
||||||
full_description = f"{description}\n\nAdditional context:\n{context}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.chat.completions.create(
|
|
||||||
model=AGENT_GENERATOR_MODEL,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": prompt},
|
|
||||||
{"role": "user", "content": full_description},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
if content is None:
|
|
||||||
logger.error("LLM returned empty content for decomposition")
|
|
||||||
return None
|
|
||||||
|
|
||||||
result = parse_json_from_llm(content)
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
logger.error(f"Failed to parse decomposition response: {content[:200]}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error decomposing goal: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
@@ -72,31 +65,14 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict or None on error
|
Agent JSON dict or None on error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
client = get_client()
|
_check_service_configured()
|
||||||
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
|
logger.info("Calling external Agent Generator service for generate_agent")
|
||||||
|
result = await generate_agent_external(instructions)
|
||||||
try:
|
if result:
|
||||||
response = await client.chat.completions.create(
|
|
||||||
model=AGENT_GENERATOR_MODEL,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": prompt},
|
|
||||||
{"role": "user", "content": json.dumps(instructions, indent=2)},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
if content is None:
|
|
||||||
logger.error("LLM returned empty content for agent generation")
|
|
||||||
return None
|
|
||||||
|
|
||||||
result = parse_json_from_llm(content)
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
logger.error(f"Failed to parse agent JSON: {content[:200]}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Ensure required fields
|
# Ensure required fields
|
||||||
if "id" not in result:
|
if "id" not in result:
|
||||||
result["id"] = str(uuid.uuid4())
|
result["id"] = str(uuid.uuid4())
|
||||||
@@ -104,13 +80,8 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
|||||||
result["version"] = 1
|
result["version"] = 1
|
||||||
if "is_active" not in result:
|
if "is_active" not in result:
|
||||||
result["is_active"] = True
|
result["is_active"] = True
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating agent: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||||
"""Convert agent JSON dict to Graph model.
|
"""Convert agent JSON dict to Graph model.
|
||||||
@@ -284,108 +255,23 @@ async def get_agent_as_json(
|
|||||||
async def generate_agent_patch(
|
async def generate_agent_patch(
|
||||||
update_request: str, current_agent: dict[str, Any]
|
update_request: str, current_agent: dict[str, Any]
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Generate a patch to update an existing agent.
|
"""Update an existing agent using natural language.
|
||||||
|
|
||||||
|
The external Agent Generator service handles:
|
||||||
|
- Generating the patch
|
||||||
|
- Applying the patch
|
||||||
|
- Fixing and validating the result
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Patch dict or clarifying questions, or None on error
|
Updated agent JSON, clarifying questions dict, or None on error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
client = get_client()
|
_check_service_configured()
|
||||||
prompt = PATCH_PROMPT.format(
|
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||||
current_agent=json.dumps(current_agent, indent=2),
|
return await generate_agent_patch_external(update_request, current_agent)
|
||||||
block_summaries=get_block_summaries(),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.chat.completions.create(
|
|
||||||
model=AGENT_GENERATOR_MODEL,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": prompt},
|
|
||||||
{"role": "user", "content": update_request},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
if content is None:
|
|
||||||
logger.error("LLM returned empty content for patch generation")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return parse_json_from_llm(content)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating patch: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def apply_agent_patch(
|
|
||||||
current_agent: dict[str, Any], patch: dict[str, Any]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Apply a patch to an existing agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
current_agent: Current agent JSON
|
|
||||||
patch: Patch dict with operations
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated agent JSON
|
|
||||||
"""
|
|
||||||
agent = copy.deepcopy(current_agent)
|
|
||||||
patches = patch.get("patches", [])
|
|
||||||
|
|
||||||
for p in patches:
|
|
||||||
patch_type = p.get("type")
|
|
||||||
|
|
||||||
if patch_type == "modify":
|
|
||||||
node_id = p.get("node_id")
|
|
||||||
changes = p.get("changes", {})
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
if node["id"] == node_id:
|
|
||||||
_deep_update(node, changes)
|
|
||||||
logger.debug(f"Modified node {node_id}")
|
|
||||||
break
|
|
||||||
|
|
||||||
elif patch_type == "add":
|
|
||||||
new_nodes = p.get("new_nodes", [])
|
|
||||||
new_links = p.get("new_links", [])
|
|
||||||
|
|
||||||
agent["nodes"] = agent.get("nodes", []) + new_nodes
|
|
||||||
agent["links"] = agent.get("links", []) + new_links
|
|
||||||
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
|
|
||||||
|
|
||||||
elif patch_type == "remove":
|
|
||||||
node_ids_to_remove = set(p.get("node_ids", []))
|
|
||||||
link_ids_to_remove = set(p.get("link_ids", []))
|
|
||||||
|
|
||||||
# Remove nodes
|
|
||||||
agent["nodes"] = [
|
|
||||||
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
|
|
||||||
]
|
|
||||||
|
|
||||||
# Remove links (both explicit and those referencing removed nodes)
|
|
||||||
agent["links"] = [
|
|
||||||
link
|
|
||||||
for link in agent.get("links", [])
|
|
||||||
if link["id"] not in link_ids_to_remove
|
|
||||||
and link["source_id"] not in node_ids_to_remove
|
|
||||||
and link["sink_id"] not in node_ids_to_remove
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
|
|
||||||
)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def _deep_update(target: dict, source: dict) -> None:
|
|
||||||
"""Recursively update a dict with another dict."""
|
|
||||||
for key, value in source.items():
|
|
||||||
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
|
||||||
_deep_update(target[key], value)
|
|
||||||
else:
|
|
||||||
target[key] = value
|
|
||||||
|
|||||||
@@ -1,606 +0,0 @@
|
|||||||
"""Agent fixer - Fixes common LLM generation errors."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from .utils import (
|
|
||||||
ADDTODICTIONARY_BLOCK_ID,
|
|
||||||
ADDTOLIST_BLOCK_ID,
|
|
||||||
CODE_EXECUTION_BLOCK_ID,
|
|
||||||
CONDITION_BLOCK_ID,
|
|
||||||
CREATEDICT_BLOCK_ID,
|
|
||||||
CREATELIST_BLOCK_ID,
|
|
||||||
DATA_SAMPLING_BLOCK_ID,
|
|
||||||
DOUBLE_CURLY_BRACES_BLOCK_IDS,
|
|
||||||
GET_CURRENT_DATE_BLOCK_ID,
|
|
||||||
STORE_VALUE_BLOCK_ID,
|
|
||||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
|
||||||
get_blocks_info,
|
|
||||||
is_valid_uuid,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix invalid UUIDs in agent and link IDs."""
|
|
||||||
# Fix agent ID
|
|
||||||
if not is_valid_uuid(agent.get("id", "")):
|
|
||||||
agent["id"] = str(uuid.uuid4())
|
|
||||||
logger.debug(f"Fixed agent ID: {agent['id']}")
|
|
||||||
|
|
||||||
# Fix node IDs
|
|
||||||
id_mapping = {} # Old ID -> New ID
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
if not is_valid_uuid(node.get("id", "")):
|
|
||||||
old_id = node.get("id", "")
|
|
||||||
new_id = str(uuid.uuid4())
|
|
||||||
id_mapping[old_id] = new_id
|
|
||||||
node["id"] = new_id
|
|
||||||
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
|
|
||||||
|
|
||||||
# Fix link IDs and update references
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
if not is_valid_uuid(link.get("id", "")):
|
|
||||||
link["id"] = str(uuid.uuid4())
|
|
||||||
logger.debug(f"Fixed link ID: {link['id']}")
|
|
||||||
|
|
||||||
# Update source/sink IDs if they were remapped
|
|
||||||
if link.get("source_id") in id_mapping:
|
|
||||||
link["source_id"] = id_mapping[link["source_id"]]
|
|
||||||
if link.get("sink_id") in id_mapping:
|
|
||||||
link["sink_id"] = id_mapping[link["sink_id"]]
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix single curly braces to double in template blocks."""
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
|
|
||||||
continue
|
|
||||||
|
|
||||||
input_data = node.get("input_default", {})
|
|
||||||
for key in ("prompt", "format"):
|
|
||||||
if key in input_data and isinstance(input_data[key], str):
|
|
||||||
original = input_data[key]
|
|
||||||
# Fix simple variable references: {var} -> {{var}}
|
|
||||||
fixed = re.sub(
|
|
||||||
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
|
|
||||||
r"{{\1}}",
|
|
||||||
original,
|
|
||||||
)
|
|
||||||
if fixed != original:
|
|
||||||
input_data[key] = fixed
|
|
||||||
logger.debug(f"Fixed curly braces in {key}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
|
|
||||||
# Find all ConditionBlock nodes
|
|
||||||
condition_node_ids = {
|
|
||||||
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
|
|
||||||
}
|
|
||||||
|
|
||||||
if not condition_node_ids:
|
|
||||||
return agent
|
|
||||||
|
|
||||||
new_nodes = []
|
|
||||||
new_links = []
|
|
||||||
processed_conditions = set()
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
sink_id = link.get("sink_id")
|
|
||||||
sink_name = link.get("sink_name")
|
|
||||||
|
|
||||||
# Check if this link goes to a ConditionBlock's value2
|
|
||||||
if sink_id in condition_node_ids and sink_name == "value2":
|
|
||||||
source_node = next(
|
|
||||||
(n for n in nodes if n["id"] == link.get("source_id")), None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Skip if source is already a StoreValueBlock
|
|
||||||
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip if we already processed this condition
|
|
||||||
if sink_id in processed_conditions:
|
|
||||||
continue
|
|
||||||
|
|
||||||
processed_conditions.add(sink_id)
|
|
||||||
|
|
||||||
# Create StoreValueBlock
|
|
||||||
store_node_id = str(uuid.uuid4())
|
|
||||||
store_node = {
|
|
||||||
"id": store_node_id,
|
|
||||||
"block_id": STORE_VALUE_BLOCK_ID,
|
|
||||||
"input_default": {"data": None},
|
|
||||||
"metadata": {"position": {"x": 0, "y": -100}},
|
|
||||||
}
|
|
||||||
new_nodes.append(store_node)
|
|
||||||
|
|
||||||
# Create link: original source -> StoreValueBlock
|
|
||||||
new_links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": link["source_id"],
|
|
||||||
"source_name": link["source_name"],
|
|
||||||
"sink_id": store_node_id,
|
|
||||||
"sink_name": "input",
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update original link: StoreValueBlock -> ConditionBlock
|
|
||||||
link["source_id"] = store_node_id
|
|
||||||
link["source_name"] = "output"
|
|
||||||
|
|
||||||
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
|
|
||||||
|
|
||||||
if new_nodes:
|
|
||||||
agent["nodes"] = nodes + new_nodes
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
|
|
||||||
|
|
||||||
When an AddToList block is found:
|
|
||||||
1. Checks if there's a CreateListBlock before it
|
|
||||||
2. Removes CreateListBlock if linked directly to AddToList
|
|
||||||
3. Adds an empty AddToList block before the original
|
|
||||||
4. Ensures the original has a self-referencing link
|
|
||||||
"""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
new_nodes = []
|
|
||||||
original_addtolist_ids = set()
|
|
||||||
nodes_to_remove = set()
|
|
||||||
links_to_remove = []
|
|
||||||
|
|
||||||
# First pass: identify CreateListBlock nodes to remove
|
|
||||||
for link in links:
|
|
||||||
source_node = next(
|
|
||||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
|
||||||
)
|
|
||||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
|
||||||
|
|
||||||
if (
|
|
||||||
source_node
|
|
||||||
and sink_node
|
|
||||||
and source_node.get("block_id") == CREATELIST_BLOCK_ID
|
|
||||||
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
|
|
||||||
):
|
|
||||||
nodes_to_remove.add(source_node.get("id"))
|
|
||||||
links_to_remove.append(link)
|
|
||||||
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
|
|
||||||
|
|
||||||
# Second pass: process AddToList blocks
|
|
||||||
filtered_nodes = []
|
|
||||||
for node in nodes:
|
|
||||||
if node.get("id") in nodes_to_remove:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
|
|
||||||
original_addtolist_ids.add(node.get("id"))
|
|
||||||
node_id = node.get("id")
|
|
||||||
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
|
|
||||||
|
|
||||||
# Check if already has prerequisite
|
|
||||||
has_prereq = any(
|
|
||||||
link.get("sink_id") == node_id
|
|
||||||
and link.get("sink_name") == "list"
|
|
||||||
and link.get("source_name") == "updated_list"
|
|
||||||
for link in links
|
|
||||||
)
|
|
||||||
|
|
||||||
if not has_prereq:
|
|
||||||
# Remove links to "list" input (except self-reference)
|
|
||||||
for link in links:
|
|
||||||
if (
|
|
||||||
link.get("sink_id") == node_id
|
|
||||||
and link.get("sink_name") == "list"
|
|
||||||
and link.get("source_id") != node_id
|
|
||||||
and link not in links_to_remove
|
|
||||||
):
|
|
||||||
links_to_remove.append(link)
|
|
||||||
|
|
||||||
# Create prerequisite AddToList block
|
|
||||||
prereq_id = str(uuid.uuid4())
|
|
||||||
prereq_node = {
|
|
||||||
"id": prereq_id,
|
|
||||||
"block_id": ADDTOLIST_BLOCK_ID,
|
|
||||||
"input_default": {"list": [], "entry": None, "entries": []},
|
|
||||||
"metadata": {
|
|
||||||
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
new_nodes.append(prereq_node)
|
|
||||||
|
|
||||||
# Link prerequisite to original
|
|
||||||
links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": prereq_id,
|
|
||||||
"source_name": "updated_list",
|
|
||||||
"sink_id": node_id,
|
|
||||||
"sink_name": "list",
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
logger.debug(f"Added prerequisite AddToList block for {node_id}")
|
|
||||||
|
|
||||||
filtered_nodes.append(node)
|
|
||||||
|
|
||||||
# Remove marked links
|
|
||||||
filtered_links = [link for link in links if link not in links_to_remove]
|
|
||||||
|
|
||||||
# Add self-referencing links for original AddToList blocks
|
|
||||||
for node in filtered_nodes + new_nodes:
|
|
||||||
if (
|
|
||||||
node.get("block_id") == ADDTOLIST_BLOCK_ID
|
|
||||||
and node.get("id") in original_addtolist_ids
|
|
||||||
):
|
|
||||||
node_id = node.get("id")
|
|
||||||
has_self_ref = any(
|
|
||||||
link["source_id"] == node_id
|
|
||||||
and link["sink_id"] == node_id
|
|
||||||
and link["source_name"] == "updated_list"
|
|
||||||
and link["sink_name"] == "list"
|
|
||||||
for link in filtered_links
|
|
||||||
)
|
|
||||||
if not has_self_ref:
|
|
||||||
filtered_links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": node_id,
|
|
||||||
"source_name": "updated_list",
|
|
||||||
"sink_id": node_id,
|
|
||||||
"sink_name": "list",
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
logger.debug(f"Added self-reference for AddToList {node_id}")
|
|
||||||
|
|
||||||
agent["nodes"] = filtered_nodes + new_nodes
|
|
||||||
agent["links"] = filtered_links
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
nodes_to_remove = set()
|
|
||||||
links_to_remove = []
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
source_node = next(
|
|
||||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
|
||||||
)
|
|
||||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
|
||||||
|
|
||||||
if (
|
|
||||||
source_node
|
|
||||||
and sink_node
|
|
||||||
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
|
|
||||||
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
|
|
||||||
):
|
|
||||||
nodes_to_remove.add(source_node.get("id"))
|
|
||||||
links_to_remove.append(link)
|
|
||||||
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
|
|
||||||
|
|
||||||
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
|
|
||||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
source_node = next(
|
|
||||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
source_node
|
|
||||||
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
|
|
||||||
and link.get("source_name") == "response"
|
|
||||||
):
|
|
||||||
link["source_name"] = "stdout_logs"
|
|
||||||
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
links_to_remove = []
|
|
||||||
|
|
||||||
for node in nodes:
|
|
||||||
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
|
|
||||||
node_id = node.get("id")
|
|
||||||
input_default = node.get("input_default", {})
|
|
||||||
|
|
||||||
# Remove links to sample_size
|
|
||||||
for link in links:
|
|
||||||
if (
|
|
||||||
link.get("sink_id") == node_id
|
|
||||||
and link.get("sink_name") == "sample_size"
|
|
||||||
):
|
|
||||||
links_to_remove.append(link)
|
|
||||||
|
|
||||||
# Set default
|
|
||||||
input_default["sample_size"] = 1
|
|
||||||
node["input_default"] = input_default
|
|
||||||
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
|
|
||||||
|
|
||||||
if links_to_remove:
|
|
||||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
node_lookup = {n.get("id"): n for n in nodes}
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
source_id = link.get("source_id")
|
|
||||||
sink_id = link.get("sink_id")
|
|
||||||
|
|
||||||
source_node = node_lookup.get(source_id)
|
|
||||||
sink_node = node_lookup.get(sink_id)
|
|
||||||
|
|
||||||
if not source_node or not sink_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_pos = source_node.get("metadata", {}).get("position", {})
|
|
||||||
sink_pos = sink_node.get("metadata", {}).get("position", {})
|
|
||||||
|
|
||||||
source_x = source_pos.get("x", 0)
|
|
||||||
sink_x = sink_pos.get("x", 0)
|
|
||||||
|
|
||||||
if abs(sink_x - source_x) < 800:
|
|
||||||
new_x = source_x + 800
|
|
||||||
if "metadata" not in sink_node:
|
|
||||||
sink_node["metadata"] = {}
|
|
||||||
if "position" not in sink_node["metadata"]:
|
|
||||||
sink_node["metadata"]["position"] = {}
|
|
||||||
sink_node["metadata"]["position"]["x"] = new_x
|
|
||||||
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
|
|
||||||
input_default = node.get("input_default", {})
|
|
||||||
if "offset" in input_default:
|
|
||||||
offset = input_default["offset"]
|
|
||||||
if isinstance(offset, (int, float)) and offset < 0:
|
|
||||||
input_default["offset"] = abs(offset)
|
|
||||||
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_ai_model_parameter(
|
|
||||||
agent: dict[str, Any],
|
|
||||||
blocks_info: list[dict[str, Any]],
|
|
||||||
default_model: str = "gpt-4o",
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Add default model parameter to AI blocks if missing."""
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
block_id = node.get("block_id")
|
|
||||||
block = block_map.get(block_id)
|
|
||||||
|
|
||||||
if not block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if block has AI category
|
|
||||||
categories = block.get("categories", [])
|
|
||||||
is_ai_block = any(
|
|
||||||
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_ai_block:
|
|
||||||
input_default = node.get("input_default", {})
|
|
||||||
if "model" not in input_default:
|
|
||||||
input_default["model"] = default_model
|
|
||||||
node["input_default"] = input_default
|
|
||||||
logger.debug(
|
|
||||||
f"Added model '{default_model}' to AI block {node.get('id')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_link_static_properties(
|
|
||||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Fix is_static property based on source block's staticOutput."""
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
|
||||||
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
source_node = node_lookup.get(link.get("source_id"))
|
|
||||||
if not source_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_block = block_map.get(source_node.get("block_id"))
|
|
||||||
if not source_block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
static_output = source_block.get("staticOutput", False)
|
|
||||||
if link.get("is_static") != static_output:
|
|
||||||
link["is_static"] = static_output
|
|
||||||
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_data_type_mismatch(
|
|
||||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
node_lookup = {n.get("id"): n for n in nodes}
|
|
||||||
|
|
||||||
def get_property_type(schema: dict, name: str) -> str | None:
|
|
||||||
if "_#_" in name:
|
|
||||||
parent, child = name.split("_#_", 1)
|
|
||||||
parent_schema = schema.get(parent, {})
|
|
||||||
if "properties" in parent_schema:
|
|
||||||
return parent_schema["properties"].get(child, {}).get("type")
|
|
||||||
return None
|
|
||||||
return schema.get(name, {}).get("type")
|
|
||||||
|
|
||||||
def are_types_compatible(src: str, sink: str) -> bool:
|
|
||||||
if {src, sink} <= {"integer", "number"}:
|
|
||||||
return True
|
|
||||||
return src == sink
|
|
||||||
|
|
||||||
type_mapping = {
|
|
||||||
"string": "string",
|
|
||||||
"text": "string",
|
|
||||||
"integer": "number",
|
|
||||||
"number": "number",
|
|
||||||
"float": "number",
|
|
||||||
"boolean": "boolean",
|
|
||||||
"bool": "boolean",
|
|
||||||
"array": "list",
|
|
||||||
"list": "list",
|
|
||||||
"object": "dictionary",
|
|
||||||
"dict": "dictionary",
|
|
||||||
"dictionary": "dictionary",
|
|
||||||
}
|
|
||||||
|
|
||||||
new_links = []
|
|
||||||
nodes_to_add = []
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
source_node = node_lookup.get(link.get("source_id"))
|
|
||||||
sink_node = node_lookup.get(link.get("sink_id"))
|
|
||||||
|
|
||||||
if not source_node or not sink_node:
|
|
||||||
new_links.append(link)
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_block = block_map.get(source_node.get("block_id"))
|
|
||||||
sink_block = block_map.get(sink_node.get("block_id"))
|
|
||||||
|
|
||||||
if not source_block or not sink_block:
|
|
||||||
new_links.append(link)
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
|
||||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
|
||||||
|
|
||||||
source_type = get_property_type(source_outputs, link.get("source_name", ""))
|
|
||||||
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
|
|
||||||
|
|
||||||
if (
|
|
||||||
source_type
|
|
||||||
and sink_type
|
|
||||||
and not are_types_compatible(source_type, sink_type)
|
|
||||||
):
|
|
||||||
# Insert type converter
|
|
||||||
converter_id = str(uuid.uuid4())
|
|
||||||
target_type = type_mapping.get(sink_type, sink_type)
|
|
||||||
|
|
||||||
converter_node = {
|
|
||||||
"id": converter_id,
|
|
||||||
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
|
||||||
"input_default": {"type": target_type},
|
|
||||||
"metadata": {"position": {"x": 0, "y": 100}},
|
|
||||||
}
|
|
||||||
nodes_to_add.append(converter_node)
|
|
||||||
|
|
||||||
# source -> converter
|
|
||||||
new_links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": link["source_id"],
|
|
||||||
"source_name": link["source_name"],
|
|
||||||
"sink_id": converter_id,
|
|
||||||
"sink_name": "value",
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# converter -> sink
|
|
||||||
new_links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": converter_id,
|
|
||||||
"source_name": "value",
|
|
||||||
"sink_id": link["sink_id"],
|
|
||||||
"sink_name": link["sink_name"],
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
|
|
||||||
else:
|
|
||||||
new_links.append(link)
|
|
||||||
|
|
||||||
if nodes_to_add:
|
|
||||||
agent["nodes"] = nodes + nodes_to_add
|
|
||||||
agent["links"] = new_links
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def apply_all_fixes(
|
|
||||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Apply all fixes to an agent JSON.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent: Agent JSON dict
|
|
||||||
blocks_info: Optional list of block info dicts for advanced fixes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Fixed agent JSON
|
|
||||||
"""
|
|
||||||
# Basic fixes (no block info needed)
|
|
||||||
agent = fix_agent_ids(agent)
|
|
||||||
agent = fix_double_curly_braces(agent)
|
|
||||||
agent = fix_storevalue_before_condition(agent)
|
|
||||||
agent = fix_addtolist_blocks(agent)
|
|
||||||
agent = fix_addtodictionary_blocks(agent)
|
|
||||||
agent = fix_code_execution_output(agent)
|
|
||||||
agent = fix_data_sampling_sample_size(agent)
|
|
||||||
agent = fix_node_x_coordinates(agent)
|
|
||||||
agent = fix_getcurrentdate_offset(agent)
|
|
||||||
|
|
||||||
# Advanced fixes (require block info)
|
|
||||||
if blocks_info is None:
|
|
||||||
blocks_info = get_blocks_info()
|
|
||||||
|
|
||||||
agent = fix_ai_model_parameter(agent, blocks_info)
|
|
||||||
agent = fix_link_static_properties(agent, blocks_info)
|
|
||||||
agent = fix_data_type_mismatch(agent, blocks_info)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
"""Prompt templates for agent generation."""
|
|
||||||
|
|
||||||
DECOMPOSITION_PROMPT = """
|
|
||||||
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
|
|
||||||
|
|
||||||
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
FIRST: Analyze the user's goal and determine:
|
|
||||||
1) Design-time configuration (fixed settings that won't change per run)
|
|
||||||
2) Runtime inputs (values the agent's end-user will provide each time it runs)
|
|
||||||
|
|
||||||
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
|
|
||||||
- DO NOT ask for the actual value
|
|
||||||
- Instead, define it as an Agent Input with a clear name, type, and description
|
|
||||||
|
|
||||||
Only ask clarifying questions about design-time config that affects how you build the workflow:
|
|
||||||
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
|
|
||||||
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
|
|
||||||
- Business rules that must be hard-coded
|
|
||||||
|
|
||||||
IMPORTANT CLARIFICATIONS POLICY:
|
|
||||||
- Ask no more than five essential questions
|
|
||||||
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
|
|
||||||
- Do not ask for API keys or credentials; the platform handles those directly
|
|
||||||
- If there is enough information to infer reasonable defaults, prefer to propose defaults
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
GUIDELINES:
|
|
||||||
1. List each step as a numbered item
|
|
||||||
2. Describe the action clearly and specify inputs/outputs
|
|
||||||
3. Ensure steps are in logical, sequential order
|
|
||||||
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
|
|
||||||
5. Help the user reach their goal efficiently
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
RULES:
|
|
||||||
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
|
|
||||||
2. USE ONLY THE BLOCKS PROVIDED
|
|
||||||
3. ALL required_input fields must be provided
|
|
||||||
4. Data types of linked properties must match
|
|
||||||
5. Write expert-level prompts for AI-related blocks
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
CRITICAL BLOCK RESTRICTIONS:
|
|
||||||
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
|
|
||||||
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
|
|
||||||
3. ConditionBlock: value2 is reference, value1 is contrast
|
|
||||||
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
|
|
||||||
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
OUTPUT FORMAT:
|
|
||||||
|
|
||||||
If more information is needed:
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"type": "clarifying_questions",
|
|
||||||
"questions": [
|
|
||||||
{{
|
|
||||||
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
|
|
||||||
"keyword": "email_provider",
|
|
||||||
"example": "Gmail"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
If ready to proceed:
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"type": "instructions",
|
|
||||||
"steps": [
|
|
||||||
{{
|
|
||||||
"step_number": 1,
|
|
||||||
"block_name": "AgentShortTextInputBlock",
|
|
||||||
"description": "Get the URL of the content to analyze.",
|
|
||||||
"inputs": [{{"name": "name", "value": "URL"}}],
|
|
||||||
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
AVAILABLE BLOCKS:
|
|
||||||
{block_summaries}
|
|
||||||
"""
|
|
||||||
|
|
||||||
GENERATION_PROMPT = """
|
|
||||||
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
NODES:
|
|
||||||
Each node must include:
|
|
||||||
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
|
|
||||||
- `block_id`: The block identifier (must match an Allowed Block)
|
|
||||||
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
|
|
||||||
- `metadata`: Must contain:
|
|
||||||
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
|
|
||||||
- `customized_name`: Clear name describing this block's purpose in the workflow
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
LINKS:
|
|
||||||
Each link connects a source node's output to a sink node's input:
|
|
||||||
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
|
|
||||||
- `source_id`: ID of the source node
|
|
||||||
- `source_name`: Output field name from the source block
|
|
||||||
- `sink_id`: ID of the sink node
|
|
||||||
- `sink_name`: Input field name on the sink block
|
|
||||||
- `is_static`: true only if source block has static_output: true
|
|
||||||
|
|
||||||
CRITICAL: All IDs must be valid UUID v4 format!
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
AGENT (GRAPH):
|
|
||||||
Wrap nodes and links in:
|
|
||||||
- `id`: UUID of the agent
|
|
||||||
- `name`: Short, generic name (avoid specific company names, URLs)
|
|
||||||
- `description`: Short, generic description
|
|
||||||
- `nodes`: List of all nodes
|
|
||||||
- `links`: List of all links
|
|
||||||
- `version`: 1
|
|
||||||
- `is_active`: true
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
TIPS:
|
|
||||||
- All required_input fields must be provided via input_default or a valid link
|
|
||||||
- Ensure consistent source_id and sink_id references
|
|
||||||
- Avoid dangling links
|
|
||||||
- Input/output pins must match block schemas
|
|
||||||
- Do not invent unknown block_ids
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
ALLOWED BLOCKS:
|
|
||||||
{block_summaries}
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCH_PROMPT = """
|
|
||||||
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
|
|
||||||
|
|
||||||
CURRENT AGENT:
|
|
||||||
{current_agent}
|
|
||||||
|
|
||||||
AVAILABLE BLOCKS:
|
|
||||||
{block_summaries}
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
PATCH FORMAT:
|
|
||||||
Return a JSON object with the following structure:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"type": "patch",
|
|
||||||
"intent": "Brief description of what the patch does",
|
|
||||||
"patches": [
|
|
||||||
{{
|
|
||||||
"type": "modify",
|
|
||||||
"node_id": "uuid-of-node-to-modify",
|
|
||||||
"changes": {{
|
|
||||||
"input_default": {{"field": "new_value"}},
|
|
||||||
"metadata": {{"customized_name": "New Name"}}
|
|
||||||
}}
|
|
||||||
}},
|
|
||||||
{{
|
|
||||||
"type": "add",
|
|
||||||
"new_nodes": [
|
|
||||||
{{
|
|
||||||
"id": "new-uuid",
|
|
||||||
"block_id": "block-uuid",
|
|
||||||
"input_default": {{}},
|
|
||||||
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
|
|
||||||
}}
|
|
||||||
],
|
|
||||||
"new_links": [
|
|
||||||
{{
|
|
||||||
"id": "link-uuid",
|
|
||||||
"source_id": "source-node-id",
|
|
||||||
"source_name": "output_field",
|
|
||||||
"sink_id": "sink-node-id",
|
|
||||||
"sink_name": "input_field"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}},
|
|
||||||
{{
|
|
||||||
"type": "remove",
|
|
||||||
"node_ids": ["uuid-of-node-to-remove"],
|
|
||||||
"link_ids": ["uuid-of-link-to-remove"]
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
If you need more information, return:
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"type": "clarifying_questions",
|
|
||||||
"questions": [
|
|
||||||
{{
|
|
||||||
"question": "What specific change do you want?",
|
|
||||||
"keyword": "change_type",
|
|
||||||
"example": "Add error handling"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
Generate the minimal patch needed. Output ONLY valid JSON.
|
|
||||||
"""
|
|
||||||
@@ -0,0 +1,269 @@
|
|||||||
|
"""External Agent Generator service client.
|
||||||
|
|
||||||
|
This module provides a client for communicating with the external Agent Generator
|
||||||
|
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
||||||
|
will delegate to the external service instead of using the built-in LLM-based implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_client: httpx.AsyncClient | None = None
|
||||||
|
_settings: Settings | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_settings() -> Settings:
|
||||||
|
"""Get or create settings singleton."""
|
||||||
|
global _settings
|
||||||
|
if _settings is None:
|
||||||
|
_settings = Settings()
|
||||||
|
return _settings
|
||||||
|
|
||||||
|
|
||||||
|
def is_external_service_configured() -> bool:
|
||||||
|
"""Check if external Agent Generator service is configured."""
|
||||||
|
settings = _get_settings()
|
||||||
|
return bool(settings.config.agentgenerator_host)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_base_url() -> str:
|
||||||
|
"""Get the base URL for the external service."""
|
||||||
|
settings = _get_settings()
|
||||||
|
host = settings.config.agentgenerator_host
|
||||||
|
port = settings.config.agentgenerator_port
|
||||||
|
return f"http://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client() -> httpx.AsyncClient:
|
||||||
|
"""Get or create the HTTP client for the external service."""
|
||||||
|
global _client
|
||||||
|
if _client is None:
|
||||||
|
settings = _get_settings()
|
||||||
|
_client = httpx.AsyncClient(
|
||||||
|
base_url=_get_base_url(),
|
||||||
|
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
||||||
|
)
|
||||||
|
return _client
|
||||||
|
|
||||||
|
|
||||||
|
async def decompose_goal_external(
|
||||||
|
description: str, context: str = ""
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Call the external service to decompose a goal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
description: Natural language goal description
|
||||||
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with either:
|
||||||
|
- {"type": "clarifying_questions", "questions": [...]}
|
||||||
|
- {"type": "instructions", "steps": [...]}
|
||||||
|
- {"type": "unachievable_goal", ...}
|
||||||
|
- {"type": "vague_goal", ...}
|
||||||
|
Or None on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
# Build the request payload
|
||||||
|
payload: dict[str, Any] = {"description": description}
|
||||||
|
if context:
|
||||||
|
# The external service uses user_instruction for additional context
|
||||||
|
payload["user_instruction"] = context
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post("/api/decompose-description", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
logger.error(f"External service returned error: {data.get('error')}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Map the response to the expected format
|
||||||
|
response_type = data.get("type")
|
||||||
|
if response_type == "instructions":
|
||||||
|
return {"type": "instructions", "steps": data.get("steps", [])}
|
||||||
|
elif response_type == "clarifying_questions":
|
||||||
|
return {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": data.get("questions", []),
|
||||||
|
}
|
||||||
|
elif response_type == "unachievable_goal":
|
||||||
|
return {
|
||||||
|
"type": "unachievable_goal",
|
||||||
|
"reason": data.get("reason"),
|
||||||
|
"suggested_goal": data.get("suggested_goal"),
|
||||||
|
}
|
||||||
|
elif response_type == "vague_goal":
|
||||||
|
return {
|
||||||
|
"type": "vague_goal",
|
||||||
|
"suggested_goal": data.get("suggested_goal"),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Unknown response type from external service: {response_type}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_agent_external(
|
||||||
|
instructions: dict[str, Any]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Call the external service to generate an agent from instructions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instructions: Structured instructions from decompose_goal
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent JSON dict or None on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/generate-agent", json={"instructions": instructions}
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
logger.error(f"External service returned error: {data.get('error')}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data.get("agent_json")
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_agent_patch_external(
|
||||||
|
update_request: str, current_agent: dict[str, Any]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Call the external service to generate a patch for an existing agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_request: Natural language description of changes
|
||||||
|
current_agent: Current agent JSON
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated agent JSON, clarifying questions dict, or None on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/update-agent",
|
||||||
|
json={
|
||||||
|
"update_request": update_request,
|
||||||
|
"current_agent_json": current_agent,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
logger.error(f"External service returned error: {data.get('error')}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if it's clarifying questions
|
||||||
|
if data.get("type") == "clarifying_questions":
|
||||||
|
return {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": data.get("questions", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Otherwise return the updated agent JSON
|
||||||
|
return data.get("agent_json")
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||||
|
"""Get available blocks from the external service.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of block info dicts or None on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get("/api/blocks")
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
logger.error("External service returned error getting blocks")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data.get("blocks", [])
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"HTTP error getting blocks from external service: {e}")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error getting blocks from external service: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error getting blocks from external service: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def health_check() -> bool:
|
||||||
|
"""Check if the external service is healthy.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if healthy, False otherwise
|
||||||
|
"""
|
||||||
|
if not is_external_service_configured():
|
||||||
|
return False
|
||||||
|
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get("/health")
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"External agent generator health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def close_client() -> None:
|
||||||
|
"""Close the HTTP client."""
|
||||||
|
global _client
|
||||||
|
if _client is not None:
|
||||||
|
await _client.aclose()
|
||||||
|
_client = None
|
||||||
@@ -1,213 +0,0 @@
|
|||||||
"""Utilities for agent generation."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.data.block import get_blocks
|
|
||||||
|
|
||||||
# UUID validation regex
|
|
||||||
UUID_REGEX = re.compile(
|
|
||||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Block IDs for various fixes
|
|
||||||
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
|
|
||||||
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
|
|
||||||
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
|
|
||||||
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
|
|
||||||
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
|
|
||||||
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
|
|
||||||
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
|
|
||||||
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
|
|
||||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
|
|
||||||
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
|
||||||
|
|
||||||
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
|
|
||||||
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
|
|
||||||
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
|
||||||
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
|
||||||
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
|
|
||||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
|
||||||
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
|
|
||||||
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
|
|
||||||
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
|
|
||||||
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
|
|
||||||
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
|
|
||||||
"716a67b3-6760-42e7-86dc-18645c6e00fc",
|
|
||||||
"530cf046-2ce0-4854-ae2c-659db17c7a46",
|
|
||||||
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
|
||||||
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
|
|
||||||
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_uuid(value: str) -> bool:
|
|
||||||
"""Check if a string is a valid UUID v4."""
|
|
||||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def _compact_schema(schema: dict) -> dict[str, str]:
|
|
||||||
"""Extract compact type info from a JSON schema properties dict.
|
|
||||||
|
|
||||||
Returns a dict of {field_name: type_string} for essential info only.
|
|
||||||
"""
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
result = {}
|
|
||||||
|
|
||||||
for name, prop in props.items():
|
|
||||||
# Skip internal/complex fields
|
|
||||||
if name.startswith("_"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get type string
|
|
||||||
type_str = prop.get("type", "any")
|
|
||||||
|
|
||||||
# Handle anyOf/oneOf (optional types)
|
|
||||||
if "anyOf" in prop:
|
|
||||||
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
|
|
||||||
type_str = "|".join(types) if types else "any"
|
|
||||||
elif "allOf" in prop:
|
|
||||||
type_str = "object"
|
|
||||||
|
|
||||||
# Add array item type if present
|
|
||||||
if type_str == "array" and "items" in prop:
|
|
||||||
items = prop["items"]
|
|
||||||
if isinstance(items, dict):
|
|
||||||
item_type = items.get("type", "any")
|
|
||||||
type_str = f"array[{item_type}]"
|
|
||||||
|
|
||||||
result[name] = type_str
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_block_summaries(include_schemas: bool = True) -> str:
|
|
||||||
"""Generate compact block summaries for prompts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
include_schemas: Whether to include input/output type info
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string of block summaries (compact format)
|
|
||||||
"""
|
|
||||||
blocks = get_blocks()
|
|
||||||
summaries = []
|
|
||||||
|
|
||||||
for block_id, block_cls in blocks.items():
|
|
||||||
block = block_cls()
|
|
||||||
name = block.name
|
|
||||||
desc = getattr(block, "description", "") or ""
|
|
||||||
|
|
||||||
# Truncate description
|
|
||||||
if len(desc) > 150:
|
|
||||||
desc = desc[:147] + "..."
|
|
||||||
|
|
||||||
if not include_schemas:
|
|
||||||
summaries.append(f"- {name} (id: {block_id}): {desc}")
|
|
||||||
else:
|
|
||||||
# Compact format with type info only
|
|
||||||
inputs = {}
|
|
||||||
outputs = {}
|
|
||||||
required = []
|
|
||||||
|
|
||||||
if hasattr(block, "input_schema"):
|
|
||||||
try:
|
|
||||||
schema = block.input_schema.jsonschema()
|
|
||||||
inputs = _compact_schema(schema)
|
|
||||||
required = schema.get("required", [])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if hasattr(block, "output_schema"):
|
|
||||||
try:
|
|
||||||
schema = block.output_schema.jsonschema()
|
|
||||||
outputs = _compact_schema(schema)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Build compact line format
|
|
||||||
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
|
|
||||||
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
|
|
||||||
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
|
|
||||||
req_str = f" req=[{','.join(required)}]" if required else ""
|
|
||||||
|
|
||||||
static = " [static]" if getattr(block, "static_output", False) else ""
|
|
||||||
|
|
||||||
line = f"- {name} (id: {block_id}): {desc}"
|
|
||||||
if in_str:
|
|
||||||
line += f"\n in: {{{in_str}}}{req_str}"
|
|
||||||
if out_str:
|
|
||||||
line += f"\n out: {{{out_str}}}{static}"
|
|
||||||
|
|
||||||
summaries.append(line)
|
|
||||||
|
|
||||||
return "\n".join(summaries)
|
|
||||||
|
|
||||||
|
|
||||||
def get_blocks_info() -> list[dict[str, Any]]:
|
|
||||||
"""Get block information with schemas for validation and fixing."""
|
|
||||||
blocks = get_blocks()
|
|
||||||
blocks_info = []
|
|
||||||
for block_id, block_cls in blocks.items():
|
|
||||||
block = block_cls()
|
|
||||||
blocks_info.append(
|
|
||||||
{
|
|
||||||
"id": block_id,
|
|
||||||
"name": block.name,
|
|
||||||
"description": getattr(block, "description", ""),
|
|
||||||
"categories": getattr(block, "categories", []),
|
|
||||||
"staticOutput": getattr(block, "static_output", False),
|
|
||||||
"inputSchema": (
|
|
||||||
block.input_schema.jsonschema()
|
|
||||||
if hasattr(block, "input_schema")
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
"outputSchema": (
|
|
||||||
block.output_schema.jsonschema()
|
|
||||||
if hasattr(block, "output_schema")
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return blocks_info
|
|
||||||
|
|
||||||
|
|
||||||
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
|
|
||||||
"""Extract JSON from LLM response (handles markdown code blocks)."""
|
|
||||||
if not text:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try fenced code block
|
|
||||||
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
return json.loads(match.group(1).strip())
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Try raw text
|
|
||||||
try:
|
|
||||||
return json.loads(text.strip())
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Try finding {...} span
|
|
||||||
start = text.find("{")
|
|
||||||
end = text.rfind("}")
|
|
||||||
if start != -1 and end > start:
|
|
||||||
try:
|
|
||||||
return json.loads(text[start : end + 1])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Try finding [...] span
|
|
||||||
start = text.find("[")
|
|
||||||
end = text.rfind("]")
|
|
||||||
if start != -1 and end > start:
|
|
||||||
try:
|
|
||||||
return json.loads(text[start : end + 1])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return None
|
|
||||||
@@ -1,279 +0,0 @@
|
|||||||
"""Agent validator - Validates agent structure and connections."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from .utils import get_blocks_info
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentValidator:
|
|
||||||
"""Validator for AutoGPT agents with detailed error reporting."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.errors: list[str] = []
|
|
||||||
|
|
||||||
def add_error(self, error: str) -> None:
|
|
||||||
"""Add an error message."""
|
|
||||||
self.errors.append(error)
|
|
||||||
|
|
||||||
def validate_block_existence(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
"""Validate all block IDs exist in the blocks library."""
|
|
||||||
valid = True
|
|
||||||
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
block_id = node.get("block_id")
|
|
||||||
node_id = node.get("id")
|
|
||||||
|
|
||||||
if not block_id:
|
|
||||||
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
|
|
||||||
valid = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
if block_id not in valid_block_ids:
|
|
||||||
self.add_error(
|
|
||||||
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
|
|
||||||
"""Validate all node IDs referenced in links exist."""
|
|
||||||
valid = True
|
|
||||||
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
|
|
||||||
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
link_id = link.get("id", "Unknown")
|
|
||||||
source_id = link.get("source_id")
|
|
||||||
sink_id = link.get("sink_id")
|
|
||||||
|
|
||||||
if not source_id:
|
|
||||||
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
|
|
||||||
valid = False
|
|
||||||
elif source_id not in valid_node_ids:
|
|
||||||
self.add_error(
|
|
||||||
f"Link '{link_id}' references non-existent source_id '{source_id}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
if not sink_id:
|
|
||||||
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
|
|
||||||
valid = False
|
|
||||||
elif sink_id not in valid_node_ids:
|
|
||||||
self.add_error(
|
|
||||||
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_required_inputs(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
"""Validate required inputs are provided."""
|
|
||||||
valid = True
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
block_id = node.get("block_id")
|
|
||||||
block = block_map.get(block_id)
|
|
||||||
|
|
||||||
if not block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
|
||||||
input_defaults = node.get("input_default", {})
|
|
||||||
node_id = node.get("id")
|
|
||||||
|
|
||||||
# Get linked inputs
|
|
||||||
linked_inputs = {
|
|
||||||
link["sink_name"]
|
|
||||||
for link in agent.get("links", [])
|
|
||||||
if link.get("sink_id") == node_id
|
|
||||||
}
|
|
||||||
|
|
||||||
for req_input in required_inputs:
|
|
||||||
if (
|
|
||||||
req_input not in input_defaults
|
|
||||||
and req_input not in linked_inputs
|
|
||||||
and req_input != "credentials"
|
|
||||||
):
|
|
||||||
block_name = block.get("name", "Unknown Block")
|
|
||||||
self.add_error(
|
|
||||||
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_data_type_compatibility(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
"""Validate linked data types are compatible."""
|
|
||||||
valid = True
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
|
||||||
|
|
||||||
def get_type(schema: dict, name: str) -> str | None:
|
|
||||||
if "_#_" in name:
|
|
||||||
parent, child = name.split("_#_", 1)
|
|
||||||
parent_schema = schema.get(parent, {})
|
|
||||||
if "properties" in parent_schema:
|
|
||||||
return parent_schema["properties"].get(child, {}).get("type")
|
|
||||||
return None
|
|
||||||
return schema.get(name, {}).get("type")
|
|
||||||
|
|
||||||
def are_compatible(src: str, sink: str) -> bool:
|
|
||||||
if {src, sink} <= {"integer", "number"}:
|
|
||||||
return True
|
|
||||||
return src == sink
|
|
||||||
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
source_node = node_lookup.get(link.get("source_id"))
|
|
||||||
sink_node = node_lookup.get(link.get("sink_id"))
|
|
||||||
|
|
||||||
if not source_node or not sink_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_block = block_map.get(source_node.get("block_id"))
|
|
||||||
sink_block = block_map.get(sink_node.get("block_id"))
|
|
||||||
|
|
||||||
if not source_block or not sink_block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
|
||||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
|
||||||
|
|
||||||
source_type = get_type(source_outputs, link.get("source_name", ""))
|
|
||||||
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
|
|
||||||
|
|
||||||
if source_type and sink_type and not are_compatible(source_type, sink_type):
|
|
||||||
self.add_error(
|
|
||||||
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
|
|
||||||
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_nested_sink_links(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
"""Validate nested sink links (with _#_ notation)."""
|
|
||||||
valid = True
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
|
||||||
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
sink_name = link.get("sink_name", "")
|
|
||||||
|
|
||||||
if "_#_" in sink_name:
|
|
||||||
parent, child = sink_name.split("_#_", 1)
|
|
||||||
|
|
||||||
sink_node = node_lookup.get(link.get("sink_id"))
|
|
||||||
if not sink_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
block = block_map.get(sink_node.get("block_id"))
|
|
||||||
if not block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
input_props = block.get("inputSchema", {}).get("properties", {})
|
|
||||||
parent_schema = input_props.get(parent)
|
|
||||||
|
|
||||||
if not parent_schema:
|
|
||||||
self.add_error(
|
|
||||||
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not parent_schema.get("additionalProperties"):
|
|
||||||
if not (
|
|
||||||
isinstance(parent_schema, dict)
|
|
||||||
and "properties" in parent_schema
|
|
||||||
and child in parent_schema.get("properties", {})
|
|
||||||
):
|
|
||||||
self.add_error(
|
|
||||||
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
|
|
||||||
"""Validate prompts don't have spaces in template variables."""
|
|
||||||
valid = True
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
input_default = node.get("input_default", {})
|
|
||||||
prompt = input_default.get("prompt", "")
|
|
||||||
|
|
||||||
if not isinstance(prompt, str):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Find {{...}} with spaces
|
|
||||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
|
|
||||||
for match in matches:
|
|
||||||
content = match.group(1)
|
|
||||||
if " " in content:
|
|
||||||
self.add_error(
|
|
||||||
f"Node '{node.get('id')}' has spaces in template variable: "
|
|
||||||
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
|
||||||
) -> tuple[bool, str | None]:
|
|
||||||
"""Run all validations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_valid, error_message)
|
|
||||||
"""
|
|
||||||
self.errors = []
|
|
||||||
|
|
||||||
if blocks_info is None:
|
|
||||||
blocks_info = get_blocks_info()
|
|
||||||
|
|
||||||
checks = [
|
|
||||||
self.validate_block_existence(agent, blocks_info),
|
|
||||||
self.validate_link_node_references(agent),
|
|
||||||
self.validate_required_inputs(agent, blocks_info),
|
|
||||||
self.validate_data_type_compatibility(agent, blocks_info),
|
|
||||||
self.validate_nested_sink_links(agent, blocks_info),
|
|
||||||
self.validate_prompt_spaces(agent),
|
|
||||||
]
|
|
||||||
|
|
||||||
all_passed = all(checks)
|
|
||||||
|
|
||||||
if all_passed:
|
|
||||||
logger.info("Agent validation successful")
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
error_message = "Agent validation failed:\n"
|
|
||||||
for i, error in enumerate(self.errors, 1):
|
|
||||||
error_message += f"{i}. {error}\n"
|
|
||||||
|
|
||||||
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
|
|
||||||
return False, error_message
|
|
||||||
|
|
||||||
|
|
||||||
def validate_agent(
|
|
||||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
|
||||||
) -> tuple[bool, str | None]:
|
|
||||||
"""Convenience function to validate an agent.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_valid, error_message)
|
|
||||||
"""
|
|
||||||
validator = AgentValidator()
|
|
||||||
return validator.validate(agent, blocks_info)
|
|
||||||
@@ -5,7 +5,6 @@ import re
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
@@ -329,7 +328,6 @@ class AgentOutputTool(BaseTool):
|
|||||||
total_executions=len(available_executions) if available_executions else 1,
|
total_executions=len(available_executions) if available_executions else 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@observe(as_type="tool", name="view_agent_output")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
|
|||||||
@@ -36,6 +36,16 @@ class BaseTool:
|
|||||||
"""Whether this tool requires authentication."""
|
"""Whether this tool requires authentication."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
"""Whether this tool is long-running and should execute in background.
|
||||||
|
|
||||||
|
Long-running tools (like agent generation) are executed via background
|
||||||
|
tasks to survive SSE disconnections. The result is persisted to chat
|
||||||
|
history and visible when the user refreshes.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||||
"""Convert to OpenAI tool format."""
|
"""Convert to OpenAI tool format."""
|
||||||
return ChatCompletionToolParam(
|
return ChatCompletionToolParam(
|
||||||
|
|||||||
@@ -3,17 +3,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
apply_all_fixes,
|
AgentGeneratorNotConfiguredError,
|
||||||
decompose_goal,
|
decompose_goal,
|
||||||
generate_agent,
|
generate_agent,
|
||||||
get_blocks_info,
|
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
validate_agent,
|
|
||||||
)
|
)
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -27,9 +23,6 @@ from .models import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Maximum retries for agent generation with validation feedback
|
|
||||||
MAX_GENERATION_RETRIES = 2
|
|
||||||
|
|
||||||
|
|
||||||
class CreateAgentTool(BaseTool):
|
class CreateAgentTool(BaseTool):
|
||||||
"""Tool for creating agents from natural language descriptions."""
|
"""Tool for creating agents from natural language descriptions."""
|
||||||
@@ -49,6 +42,10 @@ class CreateAgentTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@@ -80,7 +77,6 @@ class CreateAgentTool(BaseTool):
|
|||||||
"required": ["description"],
|
"required": ["description"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@observe(as_type="tool", name="create_agent")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -91,9 +87,8 @@ class CreateAgentTool(BaseTool):
|
|||||||
|
|
||||||
Flow:
|
Flow:
|
||||||
1. Decompose the description into steps (may return clarifying questions)
|
1. Decompose the description into steps (may return clarifying questions)
|
||||||
2. Generate agent JSON from the steps
|
2. Generate agent JSON (external service handles fixing and validation)
|
||||||
3. Apply fixes to correct common LLM errors
|
3. Preview or save based on the save parameter
|
||||||
4. Preview or save based on the save parameter
|
|
||||||
"""
|
"""
|
||||||
description = kwargs.get("description", "").strip()
|
description = kwargs.get("description", "").strip()
|
||||||
context = kwargs.get("context", "")
|
context = kwargs.get("context", "")
|
||||||
@@ -110,18 +105,23 @@ class CreateAgentTool(BaseTool):
|
|||||||
# Step 1: Decompose goal into steps
|
# Step 1: Decompose goal into steps
|
||||||
try:
|
try:
|
||||||
decomposition_result = await decompose_goal(description, context)
|
decomposition_result = await decompose_goal(description, context)
|
||||||
except ValueError as e:
|
except AgentGeneratorNotConfiguredError:
|
||||||
# Handle missing API key or configuration errors
|
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Agent generation is not configured: {str(e)}",
|
message=(
|
||||||
error="configuration_error",
|
"Agent generation is not available. "
|
||||||
|
"The Agent Generator service is not configured."
|
||||||
|
),
|
||||||
|
error="service_not_configured",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if decomposition_result is None:
|
if decomposition_result is None:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to analyze the goal. Please try rephrasing.",
|
message="Failed to analyze the goal. The agent generation service may be unavailable or timed out. Please try again.",
|
||||||
error="Decomposition failed",
|
error="decomposition_failed",
|
||||||
|
details={
|
||||||
|
"description": description[:100]
|
||||||
|
}, # Include context for debugging
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -171,63 +171,26 @@ class CreateAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 2: Generate agent JSON with retry on validation failure
|
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
||||||
blocks_info = get_blocks_info()
|
try:
|
||||||
agent_json = None
|
|
||||||
validation_errors = None
|
|
||||||
|
|
||||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
|
||||||
# Generate agent (include validation errors from previous attempt)
|
|
||||||
if attempt == 0:
|
|
||||||
agent_json = await generate_agent(decomposition_result)
|
agent_json = await generate_agent(decomposition_result)
|
||||||
else:
|
except AgentGeneratorNotConfiguredError:
|
||||||
# Retry with validation error feedback
|
|
||||||
logger.info(
|
|
||||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
|
||||||
)
|
|
||||||
retry_instructions = {
|
|
||||||
**decomposition_result,
|
|
||||||
"previous_errors": validation_errors,
|
|
||||||
"retry_instructions": (
|
|
||||||
"The previous generation had validation errors. "
|
|
||||||
"Please fix these issues in the new generation:\n"
|
|
||||||
f"{validation_errors}"
|
|
||||||
),
|
|
||||||
}
|
|
||||||
agent_json = await generate_agent(retry_instructions)
|
|
||||||
|
|
||||||
if agent_json is None:
|
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to generate the agent. Please try again.",
|
|
||||||
error="Generation failed",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Step 3: Apply fixes to correct common errors
|
|
||||||
agent_json = apply_all_fixes(agent_json, blocks_info)
|
|
||||||
|
|
||||||
# Step 4: Validate the agent
|
|
||||||
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
|
|
||||||
|
|
||||||
if is_valid:
|
|
||||||
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
|
||||||
# Return error with validation details
|
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
|
"Agent generation is not available. "
|
||||||
f"Please try rephrasing your request or simplify the workflow."
|
"The Agent Generator service is not configured."
|
||||||
),
|
),
|
||||||
error="validation_failed",
|
error="service_not_configured",
|
||||||
details={"validation_errors": validation_errors},
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if agent_json is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to generate the agent. The agent generation service may be unavailable or timed out. Please try again.",
|
||||||
|
error="generation_failed",
|
||||||
|
details={
|
||||||
|
"description": description[:100]
|
||||||
|
}, # Include context for debugging
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -236,7 +199,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
node_count = len(agent_json.get("nodes", []))
|
node_count = len(agent_json.get("nodes", []))
|
||||||
link_count = len(agent_json.get("links", []))
|
link_count = len(agent_json.get("links", []))
|
||||||
|
|
||||||
# Step 4: Preview or save
|
# Step 3: Preview or save
|
||||||
if not save:
|
if not save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
|
|||||||
@@ -3,18 +3,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
apply_agent_patch,
|
AgentGeneratorNotConfiguredError,
|
||||||
apply_all_fixes,
|
|
||||||
generate_agent_patch,
|
generate_agent_patch,
|
||||||
get_agent_as_json,
|
get_agent_as_json,
|
||||||
get_blocks_info,
|
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
validate_agent,
|
|
||||||
)
|
)
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -28,9 +23,6 @@ from .models import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Maximum retries for patch generation with validation feedback
|
|
||||||
MAX_GENERATION_RETRIES = 2
|
|
||||||
|
|
||||||
|
|
||||||
class EditAgentTool(BaseTool):
|
class EditAgentTool(BaseTool):
|
||||||
"""Tool for editing existing agents using natural language."""
|
"""Tool for editing existing agents using natural language."""
|
||||||
@@ -43,13 +35,17 @@ class EditAgentTool(BaseTool):
|
|||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Edit an existing agent from the user's library using natural language. "
|
"Edit an existing agent from the user's library using natural language. "
|
||||||
"Generates a patch to update the agent while preserving unchanged parts."
|
"Generates updates to the agent while preserving unchanged parts."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@@ -87,7 +83,6 @@ class EditAgentTool(BaseTool):
|
|||||||
"required": ["agent_id", "changes"],
|
"required": ["agent_id", "changes"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@observe(as_type="tool", name="edit_agent")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -98,9 +93,8 @@ class EditAgentTool(BaseTool):
|
|||||||
|
|
||||||
Flow:
|
Flow:
|
||||||
1. Fetch the current agent
|
1. Fetch the current agent
|
||||||
2. Generate a patch based on the requested changes
|
2. Generate updated agent (external service handles fixing and validation)
|
||||||
3. Apply the patch to create an updated agent
|
3. Preview or save based on the save parameter
|
||||||
4. Preview or save based on the save parameter
|
|
||||||
"""
|
"""
|
||||||
agent_id = kwargs.get("agent_id", "").strip()
|
agent_id = kwargs.get("agent_id", "").strip()
|
||||||
changes = kwargs.get("changes", "").strip()
|
changes = kwargs.get("changes", "").strip()
|
||||||
@@ -137,52 +131,30 @@ class EditAgentTool(BaseTool):
|
|||||||
if context:
|
if context:
|
||||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||||
|
|
||||||
# Step 2: Generate patch with retry on validation failure
|
# Step 2: Generate updated agent (external service handles fixing and validation)
|
||||||
blocks_info = get_blocks_info()
|
|
||||||
updated_agent = None
|
|
||||||
validation_errors = None
|
|
||||||
intent = "Applied requested changes"
|
|
||||||
|
|
||||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
|
||||||
# Generate patch (include validation errors from previous attempt)
|
|
||||||
try:
|
try:
|
||||||
if attempt == 0:
|
result = await generate_agent_patch(update_request, current_agent)
|
||||||
patch_result = await generate_agent_patch(
|
except AgentGeneratorNotConfiguredError:
|
||||||
update_request, current_agent
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Retry with validation error feedback
|
|
||||||
logger.info(
|
|
||||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
|
||||||
)
|
|
||||||
retry_request = (
|
|
||||||
f"{update_request}\n\n"
|
|
||||||
f"IMPORTANT: The previous edit had validation errors. "
|
|
||||||
f"Please fix these issues:\n{validation_errors}"
|
|
||||||
)
|
|
||||||
patch_result = await generate_agent_patch(
|
|
||||||
retry_request, current_agent
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
# Handle missing API key or configuration errors
|
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Agent generation is not configured: {str(e)}",
|
message=(
|
||||||
error="configuration_error",
|
"Agent editing is not available. "
|
||||||
|
"The Agent Generator service is not configured."
|
||||||
|
),
|
||||||
|
error="service_not_configured",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if patch_result is None:
|
if result is None:
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to generate changes. Please try rephrasing.",
|
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
|
||||||
error="Patch generation failed",
|
error="update_generation_failed",
|
||||||
|
details={"agent_id": agent_id, "changes": changes[:100]},
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if LLM returned clarifying questions
|
# Check if LLM returned clarifying questions
|
||||||
if patch_result.get("type") == "clarifying_questions":
|
if result.get("type") == "clarifying_questions":
|
||||||
questions = patch_result.get("questions", [])
|
questions = result.get("questions", [])
|
||||||
return ClarificationNeededResponse(
|
return ClarificationNeededResponse(
|
||||||
message=(
|
message=(
|
||||||
"I need some more information about the changes. "
|
"I need some more information about the changes. "
|
||||||
@@ -199,59 +171,19 @@ class EditAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3: Apply patch and fixes
|
# Result is the updated agent JSON
|
||||||
try:
|
updated_agent = result
|
||||||
updated_agent = apply_agent_patch(current_agent, patch_result)
|
|
||||||
updated_agent = apply_all_fixes(updated_agent, blocks_info)
|
|
||||||
except Exception as e:
|
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to apply changes: {str(e)}",
|
|
||||||
error="patch_apply_failed",
|
|
||||||
details={"exception": str(e)},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
validation_errors = str(e)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Step 4: Validate the updated agent
|
|
||||||
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
|
|
||||||
|
|
||||||
if is_valid:
|
|
||||||
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
|
|
||||||
intent = patch_result.get("intent", "Applied requested changes")
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
|
||||||
# Return error with validation details
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"Updated agent has validation errors after "
|
|
||||||
f"{MAX_GENERATION_RETRIES + 1} attempts. "
|
|
||||||
f"Please try rephrasing your request or simplify the changes."
|
|
||||||
),
|
|
||||||
error="validation_failed",
|
|
||||||
details={"validation_errors": validation_errors},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
|
|
||||||
assert updated_agent is not None
|
|
||||||
|
|
||||||
agent_name = updated_agent.get("name", "Updated Agent")
|
agent_name = updated_agent.get("name", "Updated Agent")
|
||||||
agent_description = updated_agent.get("description", "")
|
agent_description = updated_agent.get("description", "")
|
||||||
node_count = len(updated_agent.get("nodes", []))
|
node_count = len(updated_agent.get("nodes", []))
|
||||||
link_count = len(updated_agent.get("links", []))
|
link_count = len(updated_agent.get("links", []))
|
||||||
|
|
||||||
# Step 5: Preview or save
|
# Step 3: Preview or save
|
||||||
if not save:
|
if not save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
f"I've updated the agent. Changes: {intent}. "
|
f"I've updated the agent. "
|
||||||
f"The agent now has {node_count} blocks. "
|
f"The agent now has {node_count} blocks. "
|
||||||
f"Review it and call edit_agent with save=true to save the changes."
|
f"Review it and call edit_agent with save=true to save the changes."
|
||||||
),
|
),
|
||||||
@@ -277,10 +209,7 @@ class EditAgentTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return AgentSavedResponse(
|
return AgentSavedResponse(
|
||||||
message=(
|
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
|
||||||
f"Updated agent '{created_graph.name}' has been saved to your library! "
|
|
||||||
f"Changes: {intent}"
|
|
||||||
),
|
|
||||||
agent_id=created_graph.id,
|
agent_id=created_graph.id,
|
||||||
agent_name=created_graph.name,
|
agent_name=created_graph.name,
|
||||||
library_agent_id=library_agent.id,
|
library_agent_id=library_agent.id,
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_search import search_agents
|
from .agent_search import search_agents
|
||||||
@@ -37,7 +35,6 @@ class FindAgentTool(BaseTool):
|
|||||||
"required": ["query"],
|
"required": ["query"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@observe(as_type="tool", name="find_agent")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self, user_id: str | None, session: ChatSession, **kwargs
|
self, user_id: str | None, session: ChatSession, **kwargs
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
from prisma.enums import ContentType
|
from prisma.enums import ContentType
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
@@ -56,7 +55,6 @@ class FindBlockTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@observe(as_type="tool", name="find_block")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -109,7 +107,8 @@ class FindBlockTool(BaseTool):
|
|||||||
block_id = result["content_id"]
|
block_id = result["content_id"]
|
||||||
block = get_block(block_id)
|
block = get_block(block_id)
|
||||||
|
|
||||||
if block:
|
# Skip disabled blocks
|
||||||
|
if block and not block.disabled:
|
||||||
# Get input/output schemas
|
# Get input/output schemas
|
||||||
input_schema = {}
|
input_schema = {}
|
||||||
output_schema = {}
|
output_schema = {}
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_search import search_agents
|
from .agent_search import search_agents
|
||||||
@@ -43,7 +41,6 @@ class FindLibraryAgentTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@observe(as_type="tool", name="find_library_agent")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self, user_id: str | None, session: ChatSession, **kwargs
|
self, user_id: str | None, session: ChatSession, **kwargs
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
from backend.api.features.chat.tools.models import (
|
from backend.api.features.chat.tools.models import (
|
||||||
@@ -73,7 +71,6 @@ class GetDocPageTool(BaseTool):
|
|||||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||||
return f"{DOCS_BASE_URL}/{url_path}"
|
return f"{DOCS_BASE_URL}/{url_path}"
|
||||||
|
|
||||||
@observe(as_type="tool", name="get_doc_page")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
|
|||||||
@@ -28,6 +28,16 @@ class ResponseType(str, Enum):
|
|||||||
BLOCK_OUTPUT = "block_output"
|
BLOCK_OUTPUT = "block_output"
|
||||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||||
DOC_PAGE = "doc_page"
|
DOC_PAGE = "doc_page"
|
||||||
|
# Workspace response types
|
||||||
|
WORKSPACE_FILE_LIST = "workspace_file_list"
|
||||||
|
WORKSPACE_FILE_CONTENT = "workspace_file_content"
|
||||||
|
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
|
||||||
|
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
||||||
|
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
||||||
|
# Long-running operation types
|
||||||
|
OPERATION_STARTED = "operation_started"
|
||||||
|
OPERATION_PENDING = "operation_pending"
|
||||||
|
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||||
|
|
||||||
|
|
||||||
# Base response model
|
# Base response model
|
||||||
@@ -334,3 +344,39 @@ class BlockOutputResponse(ToolResponseBase):
|
|||||||
block_name: str
|
block_name: str
|
||||||
outputs: dict[str, list[Any]]
|
outputs: dict[str, list[Any]]
|
||||||
success: bool = True
|
success: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# Long-running operation models
|
||||||
|
class OperationStartedResponse(ToolResponseBase):
|
||||||
|
"""Response when a long-running operation has been started in the background.
|
||||||
|
|
||||||
|
This is returned immediately to the client while the operation continues
|
||||||
|
to execute. The user can close the tab and check back later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||||
|
operation_id: str
|
||||||
|
tool_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class OperationPendingResponse(ToolResponseBase):
|
||||||
|
"""Response stored in chat history while a long-running operation is executing.
|
||||||
|
|
||||||
|
This is persisted to the database so users see a pending state when they
|
||||||
|
refresh before the operation completes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_PENDING
|
||||||
|
operation_id: str
|
||||||
|
tool_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class OperationInProgressResponse(ToolResponseBase):
|
||||||
|
"""Response when an operation is already in progress.
|
||||||
|
|
||||||
|
Returned for idempotency when the same tool_call_id is requested again
|
||||||
|
while the background task is still running.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||||
|
tool_call_id: str
|
||||||
|
|||||||
@@ -3,11 +3,14 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.config import ChatConfig
|
from backend.api.features.chat.config import ChatConfig
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.chat.tracking import (
|
||||||
|
track_agent_run_success,
|
||||||
|
track_agent_scheduled,
|
||||||
|
)
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
@@ -155,7 +158,6 @@ class RunAgentTool(BaseTool):
|
|||||||
"""All operations require authentication."""
|
"""All operations require authentication."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@observe(as_type="tool", name="run_agent")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -453,6 +455,16 @@ class RunAgentTool(BaseTool):
|
|||||||
session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
|
session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track in PostHog
|
||||||
|
track_agent_run_success(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
graph_id=library_agent.graph_id,
|
||||||
|
graph_name=library_agent.name,
|
||||||
|
execution_id=execution.id,
|
||||||
|
library_agent_id=library_agent.id,
|
||||||
|
)
|
||||||
|
|
||||||
library_agent_link = f"/library/agents/{library_agent.id}"
|
library_agent_link = f"/library/agents/{library_agent.id}"
|
||||||
return ExecutionStartedResponse(
|
return ExecutionStartedResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -534,6 +546,18 @@ class RunAgentTool(BaseTool):
|
|||||||
session.successful_agent_schedules.get(library_agent.graph_id, 0) + 1
|
session.successful_agent_schedules.get(library_agent.graph_id, 0) + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track in PostHog
|
||||||
|
track_agent_scheduled(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
graph_id=library_agent.graph_id,
|
||||||
|
graph_name=library_agent.name,
|
||||||
|
schedule_id=result.id,
|
||||||
|
schedule_name=schedule_name,
|
||||||
|
cron=cron,
|
||||||
|
library_agent_id=library_agent.id,
|
||||||
|
)
|
||||||
|
|
||||||
library_agent_link = f"/library/agents/{library_agent.id}"
|
library_agent_link = f"/library/agents/{library_agent.id}"
|
||||||
return ExecutionStartedResponse(
|
return ExecutionStartedResponse(
|
||||||
message=(
|
message=(
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def mock_embedding_functions():
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent(setup_test_data):
|
async def test_run_agent(setup_test_data):
|
||||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||||
# Use test data from fixture
|
# Use test data from fixture
|
||||||
@@ -70,7 +70,7 @@ async def test_run_agent(setup_test_data):
|
|||||||
assert result_data["graph_name"] == "Test Agent"
|
assert result_data["graph_name"] == "Test Agent"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_missing_inputs(setup_test_data):
|
async def test_run_agent_missing_inputs(setup_test_data):
|
||||||
"""Test that the run_agent tool returns error when inputs are missing"""
|
"""Test that the run_agent tool returns error when inputs are missing"""
|
||||||
# Use test data from fixture
|
# Use test data from fixture
|
||||||
@@ -106,7 +106,7 @@ async def test_run_agent_missing_inputs(setup_test_data):
|
|||||||
assert "message" in result_data
|
assert "message" in result_data
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_invalid_agent_id(setup_test_data):
|
async def test_run_agent_invalid_agent_id(setup_test_data):
|
||||||
"""Test that the run_agent tool returns error for invalid agent ID"""
|
"""Test that the run_agent tool returns error for invalid agent ID"""
|
||||||
# Use test data from fixture
|
# Use test data from fixture
|
||||||
@@ -141,7 +141,7 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
||||||
"""Test that run_agent works with an agent requiring LLM credentials"""
|
"""Test that run_agent works with an agent requiring LLM credentials"""
|
||||||
# Use test data from fixture
|
# Use test data from fixture
|
||||||
@@ -185,7 +185,7 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
|||||||
assert result_data["graph_name"] == "LLM Test Agent"
|
assert result_data["graph_name"] == "LLM Test Agent"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_data):
|
async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_data):
|
||||||
"""Test that run_agent returns available inputs when called without inputs or use_defaults."""
|
"""Test that run_agent returns available inputs when called without inputs or use_defaults."""
|
||||||
user = setup_test_data["user"]
|
user = setup_test_data["user"]
|
||||||
@@ -219,7 +219,7 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
|
|||||||
assert "inputs" in result_data["message"].lower()
|
assert "inputs" in result_data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_with_use_defaults(setup_test_data):
|
async def test_run_agent_with_use_defaults(setup_test_data):
|
||||||
"""Test that run_agent executes successfully with use_defaults=True."""
|
"""Test that run_agent executes successfully with use_defaults=True."""
|
||||||
user = setup_test_data["user"]
|
user = setup_test_data["user"]
|
||||||
@@ -251,7 +251,7 @@ async def test_run_agent_with_use_defaults(setup_test_data):
|
|||||||
assert result_data["graph_id"] == graph.id
|
assert result_data["graph_id"] == graph.id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
|
async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
|
||||||
"""Test that run_agent returns setup_requirements when credentials are missing."""
|
"""Test that run_agent returns setup_requirements when credentials are missing."""
|
||||||
user = setup_firecrawl_test_data["user"]
|
user = setup_firecrawl_test_data["user"]
|
||||||
@@ -285,7 +285,7 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
|
|||||||
assert len(setup_info["user_readiness"]["missing_credentials"]) > 0
|
assert len(setup_info["user_readiness"]["missing_credentials"]) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_invalid_slug_format(setup_test_data):
|
async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||||
"""Test that run_agent returns error for invalid slug format (no slash)."""
|
"""Test that run_agent returns error for invalid slug format (no slash)."""
|
||||||
user = setup_test_data["user"]
|
user = setup_test_data["user"]
|
||||||
@@ -313,7 +313,7 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
|||||||
assert "username/agent-name" in result_data["message"]
|
assert "username/agent-name" in result_data["message"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_unauthenticated():
|
async def test_run_agent_unauthenticated():
|
||||||
"""Test that run_agent returns need_login for unauthenticated users."""
|
"""Test that run_agent returns need_login for unauthenticated users."""
|
||||||
tool = RunAgentTool()
|
tool = RunAgentTool()
|
||||||
@@ -340,7 +340,7 @@ async def test_run_agent_unauthenticated():
|
|||||||
assert "sign in" in result_data["message"].lower()
|
assert "sign in" in result_data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_schedule_without_cron(setup_test_data):
|
async def test_run_agent_schedule_without_cron(setup_test_data):
|
||||||
"""Test that run_agent returns error when scheduling without cron expression."""
|
"""Test that run_agent returns error when scheduling without cron expression."""
|
||||||
user = setup_test_data["user"]
|
user = setup_test_data["user"]
|
||||||
@@ -372,7 +372,7 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
|
|||||||
assert "cron" in result_data["message"].lower()
|
assert "cron" in result_data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_run_agent_schedule_without_name(setup_test_data):
|
async def test_run_agent_schedule_without_name(setup_test_data):
|
||||||
"""Test that run_agent returns error when scheduling without schedule_name."""
|
"""Test that run_agent returns error when scheduling without schedule_name."""
|
||||||
user = setup_test_data["user"]
|
user = setup_test_data["user"]
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
"""Tool for executing blocks directly."""
|
"""Tool for executing blocks directly."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.block import get_block
|
from backend.data.block import get_block
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
@@ -130,7 +130,6 @@ class RunBlockTool(BaseTool):
|
|||||||
|
|
||||||
return matched_credentials, missing_credentials
|
return matched_credentials, missing_credentials
|
||||||
|
|
||||||
@observe(as_type="tool", name="run_block")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -179,6 +178,11 @@ class RunBlockTool(BaseTool):
|
|||||||
message=f"Block '{block_id}' not found",
|
message=f"Block '{block_id}' not found",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
if block.disabled:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Block '{block_id}' is disabled",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
@@ -221,11 +225,48 @@ class RunBlockTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch actual credentials and prepare kwargs for block execution
|
# Get or create user's workspace for CoPilot file operations
|
||||||
# Create execution context with defaults (blocks may require it)
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
|
||||||
|
# Generate synthetic IDs for CoPilot context
|
||||||
|
# Each chat session is treated as its own agent with one continuous run
|
||||||
|
# This means:
|
||||||
|
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
||||||
|
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
||||||
|
# - node_exec_id = unique per block execution
|
||||||
|
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||||
|
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||||
|
synthetic_node_id = f"copilot-node-{block_id}"
|
||||||
|
synthetic_node_exec_id = (
|
||||||
|
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create unified execution context with all required fields
|
||||||
|
execution_context = ExecutionContext(
|
||||||
|
# Execution identity
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=synthetic_graph_id,
|
||||||
|
graph_exec_id=synthetic_graph_exec_id,
|
||||||
|
graph_version=1, # Versions are 1-indexed
|
||||||
|
node_id=synthetic_node_id,
|
||||||
|
node_exec_id=synthetic_node_exec_id,
|
||||||
|
# Workspace with session scoping
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
session_id=session.session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare kwargs for block execution
|
||||||
|
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||||
exec_kwargs: dict[str, Any] = {
|
exec_kwargs: dict[str, Any] = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"execution_context": ExecutionContext(),
|
"execution_context": execution_context,
|
||||||
|
# Legacy: individual kwargs for blocks not yet using execution_context
|
||||||
|
"workspace_id": workspace.id,
|
||||||
|
"graph_exec_id": synthetic_graph_exec_id,
|
||||||
|
"node_exec_id": synthetic_node_exec_id,
|
||||||
|
"node_id": synthetic_node_id,
|
||||||
|
"graph_version": 1, # Versions are 1-indexed
|
||||||
|
"graph_id": synthetic_graph_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
for field_name, cred_meta in matched_credentials.items():
|
for field_name, cred_meta in matched_credentials.items():
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langfuse import observe
|
|
||||||
from prisma.enums import ContentType
|
from prisma.enums import ContentType
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
@@ -88,7 +87,6 @@ class SearchDocsTool(BaseTool):
|
|||||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||||
return f"{DOCS_BASE_URL}/{url_path}"
|
return f"{DOCS_BASE_URL}/{url_path}"
|
||||||
|
|
||||||
@observe(as_type="tool", name="search_docs")
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
|
|||||||
@@ -0,0 +1,625 @@
|
|||||||
|
"""CoPilot tools for workspace file operations."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from prisma.enums import WorkspaceFileSource
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.data.workspace import get_or_create_workspace
|
||||||
|
from backend.util.settings import Config
|
||||||
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
from backend.util.workspace import WorkspaceManager
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileInfoData(BaseModel):
|
||||||
|
"""Data model for workspace file information (not a response itself)."""
|
||||||
|
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
size_bytes: int
|
||||||
|
source: str
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileListResponse(ToolResponseBase):
|
||||||
|
"""Response containing list of workspace files."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
||||||
|
files: list[WorkspaceFileInfoData]
|
||||||
|
total_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileContentResponse(ToolResponseBase):
|
||||||
|
"""Response containing workspace file content (legacy, for small text files)."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
content_base64: str
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||||
|
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
size_bytes: int
|
||||||
|
download_url: str
|
||||||
|
preview: str | None = None # First 500 chars for text files
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceWriteResponse(ToolResponseBase):
|
||||||
|
"""Response after writing a file to workspace."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||||
|
"""Response after deleting a file from workspace."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
||||||
|
file_id: str
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
class ListWorkspaceFilesTool(BaseTool):
|
||||||
|
"""Tool for listing files in user's workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "list_workspace_files"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"List files in the user's workspace. "
|
||||||
|
"Returns file names, paths, sizes, and metadata. "
|
||||||
|
"Optionally filter by path prefix."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path_prefix": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional path prefix to filter files "
|
||||||
|
"(e.g., '/documents/' to list only files in documents folder). "
|
||||||
|
"By default, only files from the current session are listed."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of files to return (default 50, max 100)",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 100,
|
||||||
|
},
|
||||||
|
"include_all_sessions": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"If true, list files from all sessions. "
|
||||||
|
"Default is false (only current session's files)."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||||
|
limit = min(kwargs.get("limit", 50), 100)
|
||||||
|
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
files = await manager.list_files(
|
||||||
|
path=path_prefix,
|
||||||
|
limit=limit,
|
||||||
|
include_all_sessions=include_all_sessions,
|
||||||
|
)
|
||||||
|
total = await manager.get_file_count(
|
||||||
|
path=path_prefix,
|
||||||
|
include_all_sessions=include_all_sessions,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_infos = [
|
||||||
|
WorkspaceFileInfoData(
|
||||||
|
file_id=f.id,
|
||||||
|
name=f.name,
|
||||||
|
path=f.path,
|
||||||
|
mime_type=f.mimeType,
|
||||||
|
size_bytes=f.sizeBytes,
|
||||||
|
source=f.source,
|
||||||
|
)
|
||||||
|
for f in files
|
||||||
|
]
|
||||||
|
|
||||||
|
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||||
|
return WorkspaceFileListResponse(
|
||||||
|
files=file_infos,
|
||||||
|
total_count=total,
|
||||||
|
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to list workspace files: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadWorkspaceFileTool(BaseTool):
|
||||||
|
"""Tool for reading file content from workspace."""
|
||||||
|
|
||||||
|
# Size threshold for returning full content vs metadata+URL
|
||||||
|
# Files larger than this return metadata with download URL to prevent context bloat
|
||||||
|
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||||
|
# Preview size for text files
|
||||||
|
PREVIEW_SIZE = 500
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "read_workspace_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Read a file from the user's workspace. "
|
||||||
|
"Specify either file_id or path to identify the file. "
|
||||||
|
"For small text files, returns content directly. "
|
||||||
|
"For large or binary files, returns metadata and a download URL. "
|
||||||
|
"Paths are scoped to the current session by default. "
|
||||||
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file's unique ID (from list_workspace_files)",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||||
|
"Scoped to current session by default."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"force_download_url": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"If true, always return metadata+URL instead of inline content. "
|
||||||
|
"Default is false (auto-selects based on file size/type)."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [], # At least one must be provided
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _is_text_mime_type(self, mime_type: str) -> bool:
|
||||||
|
"""Check if the MIME type is a text-based type."""
|
||||||
|
text_types = [
|
||||||
|
"text/",
|
||||||
|
"application/json",
|
||||||
|
"application/xml",
|
||||||
|
"application/javascript",
|
||||||
|
"application/x-python",
|
||||||
|
"application/x-sh",
|
||||||
|
]
|
||||||
|
return any(mime_type.startswith(t) for t in text_types)
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id: Optional[str] = kwargs.get("file_id")
|
||||||
|
path: Optional[str] = kwargs.get("path")
|
||||||
|
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||||
|
|
||||||
|
if not file_id and not path:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide either file_id or path",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
# Get file info
|
||||||
|
if file_id:
|
||||||
|
file_info = await manager.get_file_info(file_id)
|
||||||
|
if file_info is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found: {file_id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
target_file_id = file_id
|
||||||
|
else:
|
||||||
|
# path is guaranteed to be non-None here due to the check above
|
||||||
|
assert path is not None
|
||||||
|
file_info = await manager.get_file_info_by_path(path)
|
||||||
|
if file_info is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found at path: {path}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
target_file_id = file_info.id
|
||||||
|
|
||||||
|
# Decide whether to return inline content or metadata+URL
|
||||||
|
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
|
||||||
|
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
||||||
|
|
||||||
|
# Return inline content for small text files (unless force_download_url)
|
||||||
|
if is_small_file and is_text_file and not force_download_url:
|
||||||
|
content = await manager.read_file_by_id(target_file_id)
|
||||||
|
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||||
|
|
||||||
|
return WorkspaceFileContentResponse(
|
||||||
|
file_id=file_info.id,
|
||||||
|
name=file_info.name,
|
||||||
|
path=file_info.path,
|
||||||
|
mime_type=file_info.mimeType,
|
||||||
|
content_base64=content_b64,
|
||||||
|
message=f"Successfully read file: {file_info.name}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return metadata + workspace:// reference for large or binary files
|
||||||
|
# This prevents context bloat (100KB file = ~133KB as base64)
|
||||||
|
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
||||||
|
download_url = f"workspace://{target_file_id}"
|
||||||
|
|
||||||
|
# Generate preview for text files
|
||||||
|
preview: str | None = None
|
||||||
|
if is_text_file:
|
||||||
|
try:
|
||||||
|
content = await manager.read_file_by_id(target_file_id)
|
||||||
|
preview_text = content[: self.PREVIEW_SIZE].decode(
|
||||||
|
"utf-8", errors="replace"
|
||||||
|
)
|
||||||
|
if len(content) > self.PREVIEW_SIZE:
|
||||||
|
preview_text += "..."
|
||||||
|
preview = preview_text
|
||||||
|
except Exception:
|
||||||
|
pass # Preview is optional
|
||||||
|
|
||||||
|
return WorkspaceFileMetadataResponse(
|
||||||
|
file_id=file_info.id,
|
||||||
|
name=file_info.name,
|
||||||
|
path=file_info.path,
|
||||||
|
mime_type=file_info.mimeType,
|
||||||
|
size_bytes=file_info.sizeBytes,
|
||||||
|
download_url=download_url,
|
||||||
|
preview=preview,
|
||||||
|
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to read workspace file: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteWorkspaceFileTool(BaseTool):
|
||||||
|
"""Tool for writing files to workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "write_workspace_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Write or create a file in the user's workspace. "
|
||||||
|
"Provide the content as a base64-encoded string. "
|
||||||
|
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||||
|
"Files are saved to the current session's folder by default. "
|
||||||
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"filename": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Name for the file (e.g., 'report.pdf')",
|
||||||
|
},
|
||||||
|
"content_base64": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Base64-encoded file content",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional virtual path where to save the file "
|
||||||
|
"(e.g., '/documents/report.pdf'). "
|
||||||
|
"Defaults to '/{filename}'. Scoped to current session."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"mime_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional MIME type of the file. "
|
||||||
|
"Auto-detected from filename if not provided."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"overwrite": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["filename", "content_base64"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
filename: str = kwargs.get("filename", "")
|
||||||
|
content_b64: str = kwargs.get("content_base64", "")
|
||||||
|
path: Optional[str] = kwargs.get("path")
|
||||||
|
mime_type: Optional[str] = kwargs.get("mime_type")
|
||||||
|
overwrite: bool = kwargs.get("overwrite", False)
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide a filename",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not content_b64:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide content_base64",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode content
|
||||||
|
try:
|
||||||
|
content = base64.b64decode(content_b64)
|
||||||
|
except Exception:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Invalid base64-encoded content",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check size
|
||||||
|
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||||
|
if len(content) > max_file_size:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Virus scan
|
||||||
|
await scan_content_safe(content, filename=filename)
|
||||||
|
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
file_record = await manager.write_file(
|
||||||
|
content=content,
|
||||||
|
filename=filename,
|
||||||
|
path=path,
|
||||||
|
mime_type=mime_type,
|
||||||
|
source=WorkspaceFileSource.COPILOT,
|
||||||
|
source_session_id=session.session_id,
|
||||||
|
overwrite=overwrite,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WorkspaceWriteResponse(
|
||||||
|
file_id=file_record.id,
|
||||||
|
name=file_record.name,
|
||||||
|
path=file_record.path,
|
||||||
|
size_bytes=file_record.sizeBytes,
|
||||||
|
message=f"Successfully wrote file: {file_record.name}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to write workspace file: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteWorkspaceFileTool(BaseTool):
|
||||||
|
"""Tool for deleting files from workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "delete_workspace_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Delete a file from the user's workspace. "
|
||||||
|
"Specify either file_id or path to identify the file. "
|
||||||
|
"Paths are scoped to the current session by default. "
|
||||||
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file's unique ID (from list_workspace_files)",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||||
|
"Scoped to current session by default."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [], # At least one must be provided
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id: Optional[str] = kwargs.get("file_id")
|
||||||
|
path: Optional[str] = kwargs.get("path")
|
||||||
|
|
||||||
|
if not file_id and not path:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide either file_id or path",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
# Determine the file_id to delete
|
||||||
|
target_file_id: str
|
||||||
|
if file_id:
|
||||||
|
target_file_id = file_id
|
||||||
|
else:
|
||||||
|
# path is guaranteed to be non-None here due to the check above
|
||||||
|
assert path is not None
|
||||||
|
file_info = await manager.get_file_info_by_path(path)
|
||||||
|
if file_info is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found at path: {path}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
target_file_id = file_info.id
|
||||||
|
|
||||||
|
success = await manager.delete_file(target_file_id)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found: {target_file_id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WorkspaceDeleteResponse(
|
||||||
|
file_id=target_file_id,
|
||||||
|
success=True,
|
||||||
|
message="File deleted successfully",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to delete workspace file: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
250
autogpt_platform/backend/backend/api/features/chat/tracking.py
Normal file
250
autogpt_platform/backend/backend/api/features/chat/tracking.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""PostHog analytics tracking for the chat system."""
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from posthog import Posthog
|
||||||
|
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
# PostHog client instance (lazily initialized)
|
||||||
|
_posthog_client: Posthog | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _shutdown_posthog() -> None:
|
||||||
|
"""Flush and shutdown PostHog client on process exit."""
|
||||||
|
if _posthog_client is not None:
|
||||||
|
_posthog_client.flush()
|
||||||
|
_posthog_client.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
atexit.register(_shutdown_posthog)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_posthog_client() -> Posthog | None:
|
||||||
|
"""Get or create the PostHog client instance."""
|
||||||
|
global _posthog_client
|
||||||
|
if _posthog_client is not None:
|
||||||
|
return _posthog_client
|
||||||
|
|
||||||
|
if not settings.secrets.posthog_api_key:
|
||||||
|
logger.debug("PostHog API key not configured, analytics disabled")
|
||||||
|
return None
|
||||||
|
|
||||||
|
_posthog_client = Posthog(
|
||||||
|
settings.secrets.posthog_api_key,
|
||||||
|
host=settings.secrets.posthog_host,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"PostHog client initialized with host: {settings.secrets.posthog_host}"
|
||||||
|
)
|
||||||
|
return _posthog_client
|
||||||
|
|
||||||
|
|
||||||
|
def _get_base_properties() -> dict[str, Any]:
|
||||||
|
"""Get base properties included in all events."""
|
||||||
|
return {
|
||||||
|
"environment": settings.config.app_env.value,
|
||||||
|
"source": "chat_copilot",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def track_user_message(
|
||||||
|
user_id: str | None,
|
||||||
|
session_id: str,
|
||||||
|
message_length: int,
|
||||||
|
) -> None:
|
||||||
|
"""Track when a user sends a message in chat.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID (or None for anonymous)
|
||||||
|
session_id: The chat session ID
|
||||||
|
message_length: Length of the user's message
|
||||||
|
"""
|
||||||
|
client = _get_posthog_client()
|
||||||
|
if not client:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
properties = {
|
||||||
|
**_get_base_properties(),
|
||||||
|
"session_id": session_id,
|
||||||
|
"message_length": message_length,
|
||||||
|
}
|
||||||
|
client.capture(
|
||||||
|
distinct_id=user_id or f"anonymous_{session_id}",
|
||||||
|
event="copilot_message_sent",
|
||||||
|
properties=properties,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to track user message: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def track_tool_called(
|
||||||
|
user_id: str | None,
|
||||||
|
session_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
tool_call_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Track when a tool is called in chat.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID (or None for anonymous)
|
||||||
|
session_id: The chat session ID
|
||||||
|
tool_name: Name of the tool being called
|
||||||
|
tool_call_id: Unique ID of the tool call
|
||||||
|
"""
|
||||||
|
client = _get_posthog_client()
|
||||||
|
if not client:
|
||||||
|
logger.info("PostHog client not available for tool tracking")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
properties = {
|
||||||
|
**_get_base_properties(),
|
||||||
|
"session_id": session_id,
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
}
|
||||||
|
distinct_id = user_id or f"anonymous_{session_id}"
|
||||||
|
logger.info(
|
||||||
|
f"Sending copilot_tool_called event to PostHog: distinct_id={distinct_id}, "
|
||||||
|
f"tool_name={tool_name}"
|
||||||
|
)
|
||||||
|
client.capture(
|
||||||
|
distinct_id=distinct_id,
|
||||||
|
event="copilot_tool_called",
|
||||||
|
properties=properties,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to track tool call: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def track_agent_run_success(
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
graph_id: str,
|
||||||
|
graph_name: str,
|
||||||
|
execution_id: str,
|
||||||
|
library_agent_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Track when an agent is successfully run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
session_id: The chat session ID
|
||||||
|
graph_id: ID of the agent graph
|
||||||
|
graph_name: Name of the agent
|
||||||
|
execution_id: ID of the execution
|
||||||
|
library_agent_id: ID of the library agent
|
||||||
|
"""
|
||||||
|
client = _get_posthog_client()
|
||||||
|
if not client:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
properties = {
|
||||||
|
**_get_base_properties(),
|
||||||
|
"session_id": session_id,
|
||||||
|
"graph_id": graph_id,
|
||||||
|
"graph_name": graph_name,
|
||||||
|
"execution_id": execution_id,
|
||||||
|
"library_agent_id": library_agent_id,
|
||||||
|
}
|
||||||
|
client.capture(
|
||||||
|
distinct_id=user_id,
|
||||||
|
event="copilot_agent_run_success",
|
||||||
|
properties=properties,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to track agent run: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def track_agent_scheduled(
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
graph_id: str,
|
||||||
|
graph_name: str,
|
||||||
|
schedule_id: str,
|
||||||
|
schedule_name: str,
|
||||||
|
cron: str,
|
||||||
|
library_agent_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Track when an agent is successfully scheduled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
session_id: The chat session ID
|
||||||
|
graph_id: ID of the agent graph
|
||||||
|
graph_name: Name of the agent
|
||||||
|
schedule_id: ID of the schedule
|
||||||
|
schedule_name: Name of the schedule
|
||||||
|
cron: Cron expression for the schedule
|
||||||
|
library_agent_id: ID of the library agent
|
||||||
|
"""
|
||||||
|
client = _get_posthog_client()
|
||||||
|
if not client:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
properties = {
|
||||||
|
**_get_base_properties(),
|
||||||
|
"session_id": session_id,
|
||||||
|
"graph_id": graph_id,
|
||||||
|
"graph_name": graph_name,
|
||||||
|
"schedule_id": schedule_id,
|
||||||
|
"schedule_name": schedule_name,
|
||||||
|
"cron": cron,
|
||||||
|
"library_agent_id": library_agent_id,
|
||||||
|
}
|
||||||
|
client.capture(
|
||||||
|
distinct_id=user_id,
|
||||||
|
event="copilot_agent_scheduled",
|
||||||
|
properties=properties,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to track agent schedule: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def track_trigger_setup(
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
graph_id: str,
|
||||||
|
graph_name: str,
|
||||||
|
trigger_type: str,
|
||||||
|
library_agent_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Track when a trigger is set up for an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
session_id: The chat session ID
|
||||||
|
graph_id: ID of the agent graph
|
||||||
|
graph_name: Name of the agent
|
||||||
|
trigger_type: Type of trigger (e.g., 'webhook')
|
||||||
|
library_agent_id: ID of the library agent
|
||||||
|
"""
|
||||||
|
client = _get_posthog_client()
|
||||||
|
if not client:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
properties = {
|
||||||
|
**_get_base_properties(),
|
||||||
|
"session_id": session_id,
|
||||||
|
"graph_id": graph_id,
|
||||||
|
"graph_name": graph_name,
|
||||||
|
"trigger_type": trigger_type,
|
||||||
|
"library_agent_id": library_agent_id,
|
||||||
|
}
|
||||||
|
client.capture(
|
||||||
|
distinct_id=user_id,
|
||||||
|
event="copilot_trigger_setup",
|
||||||
|
properties=properties,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to track trigger setup: {e}")
|
||||||
@@ -23,6 +23,7 @@ class PendingHumanReviewModel(BaseModel):
|
|||||||
id: Unique identifier for the review record
|
id: Unique identifier for the review record
|
||||||
user_id: ID of the user who must perform the review
|
user_id: ID of the user who must perform the review
|
||||||
node_exec_id: ID of the node execution that created this review
|
node_exec_id: ID of the node execution that created this review
|
||||||
|
node_id: ID of the node definition (for grouping reviews from same node)
|
||||||
graph_exec_id: ID of the graph execution containing the node
|
graph_exec_id: ID of the graph execution containing the node
|
||||||
graph_id: ID of the graph template being executed
|
graph_id: ID of the graph template being executed
|
||||||
graph_version: Version number of the graph template
|
graph_version: Version number of the graph template
|
||||||
@@ -37,6 +38,10 @@ class PendingHumanReviewModel(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
node_exec_id: str = Field(description="Node execution ID (primary key)")
|
node_exec_id: str = Field(description="Node execution ID (primary key)")
|
||||||
|
node_id: str = Field(
|
||||||
|
description="Node definition ID (for grouping)",
|
||||||
|
default="", # Temporary default for test compatibility
|
||||||
|
)
|
||||||
user_id: str = Field(description="User ID associated with the review")
|
user_id: str = Field(description="User ID associated with the review")
|
||||||
graph_exec_id: str = Field(description="Graph execution ID")
|
graph_exec_id: str = Field(description="Graph execution ID")
|
||||||
graph_id: str = Field(description="Graph ID")
|
graph_id: str = Field(description="Graph ID")
|
||||||
@@ -66,7 +71,9 @@ class PendingHumanReviewModel(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db(cls, review: "PendingHumanReview") -> "PendingHumanReviewModel":
|
def from_db(
|
||||||
|
cls, review: "PendingHumanReview", node_id: str
|
||||||
|
) -> "PendingHumanReviewModel":
|
||||||
"""
|
"""
|
||||||
Convert a database model to a response model.
|
Convert a database model to a response model.
|
||||||
|
|
||||||
@@ -74,9 +81,14 @@ class PendingHumanReviewModel(BaseModel):
|
|||||||
payload, instructions, and editable flag.
|
payload, instructions, and editable flag.
|
||||||
|
|
||||||
Handles invalid data gracefully by using safe defaults.
|
Handles invalid data gracefully by using safe defaults.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
review: Database review object
|
||||||
|
node_id: Node definition ID (fetched from NodeExecution)
|
||||||
"""
|
"""
|
||||||
return cls(
|
return cls(
|
||||||
node_exec_id=review.nodeExecId,
|
node_exec_id=review.nodeExecId,
|
||||||
|
node_id=node_id,
|
||||||
user_id=review.userId,
|
user_id=review.userId,
|
||||||
graph_exec_id=review.graphExecId,
|
graph_exec_id=review.graphExecId,
|
||||||
graph_id=review.graphId,
|
graph_id=review.graphId,
|
||||||
@@ -107,6 +119,13 @@ class ReviewItem(BaseModel):
|
|||||||
reviewed_data: SafeJsonData | None = Field(
|
reviewed_data: SafeJsonData | None = Field(
|
||||||
None, description="Optional edited data (ignored if approved=False)"
|
None, description="Optional edited data (ignored if approved=False)"
|
||||||
)
|
)
|
||||||
|
auto_approve_future: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"If true and this review is approved, future executions of this same "
|
||||||
|
"block (node) will be automatically approved. This only affects approved reviews."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("reviewed_data")
|
@field_validator("reviewed_data")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -174,6 +193,9 @@ class ReviewRequest(BaseModel):
|
|||||||
This request must include ALL pending reviews for a graph execution.
|
This request must include ALL pending reviews for a graph execution.
|
||||||
Each review will be either approved (with optional data modifications)
|
Each review will be either approved (with optional data modifications)
|
||||||
or rejected (data ignored). The execution will resume only after ALL reviews are processed.
|
or rejected (data ignored). The execution will resume only after ALL reviews are processed.
|
||||||
|
|
||||||
|
Each review item can individually specify whether to auto-approve future executions
|
||||||
|
of the same block via the `auto_approve_future` field on ReviewItem.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
reviews: List[ReviewItem] = Field(
|
reviews: List[ReviewItem] = Field(
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,27 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import Any, List
|
||||||
|
|
||||||
import autogpt_libs.auth as autogpt_auth_lib
|
import autogpt_libs.auth as autogpt_auth_lib
|
||||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
|
|
||||||
from backend.data.execution import get_graph_execution_meta
|
from backend.data.execution import (
|
||||||
|
ExecutionContext,
|
||||||
|
ExecutionStatus,
|
||||||
|
get_graph_execution_meta,
|
||||||
|
)
|
||||||
|
from backend.data.graph import get_graph_settings
|
||||||
from backend.data.human_review import (
|
from backend.data.human_review import (
|
||||||
|
create_auto_approval_record,
|
||||||
get_pending_reviews_for_execution,
|
get_pending_reviews_for_execution,
|
||||||
get_pending_reviews_for_user,
|
get_pending_reviews_for_user,
|
||||||
|
get_reviews_by_node_exec_ids,
|
||||||
has_pending_reviews_for_graph_exec,
|
has_pending_reviews_for_graph_exec,
|
||||||
process_all_reviews_for_execution,
|
process_all_reviews_for_execution,
|
||||||
)
|
)
|
||||||
|
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||||
|
from backend.data.user import get_user_by_id
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
|
|
||||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||||
@@ -127,17 +137,70 @@ async def process_review_action(
|
|||||||
detail="At least one review must be provided",
|
detail="At least one review must be provided",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build review decisions map
|
# Batch fetch all requested reviews (regardless of status for idempotent handling)
|
||||||
|
reviews_map = await get_reviews_by_node_exec_ids(
|
||||||
|
list(all_request_node_ids), user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate all reviews were found (must exist, any status is OK for now)
|
||||||
|
missing_ids = all_request_node_ids - set(reviews_map.keys())
|
||||||
|
if missing_ids:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Review(s) not found: {', '.join(missing_ids)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate all reviews belong to the same execution
|
||||||
|
graph_exec_ids = {review.graph_exec_id for review in reviews_map.values()}
|
||||||
|
if len(graph_exec_ids) > 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="All reviews in a single request must belong to the same execution.",
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_exec_id = next(iter(graph_exec_ids))
|
||||||
|
|
||||||
|
# Validate execution status before processing reviews
|
||||||
|
graph_exec_meta = await get_graph_execution_meta(
|
||||||
|
user_id=user_id, execution_id=graph_exec_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not graph_exec_meta:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Graph execution #{graph_exec_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only allow processing reviews if execution is paused for review
|
||||||
|
# or incomplete (partial execution with some reviews already processed)
|
||||||
|
if graph_exec_meta.status not in (
|
||||||
|
ExecutionStatus.REVIEW,
|
||||||
|
ExecutionStatus.INCOMPLETE,
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||||
|
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||||
|
f"Current status: {graph_exec_meta.status}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build review decisions map and track which reviews requested auto-approval
|
||||||
|
# Auto-approved reviews use original data (no modifications allowed)
|
||||||
review_decisions = {}
|
review_decisions = {}
|
||||||
|
auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag
|
||||||
|
|
||||||
for review in request.reviews:
|
for review in request.reviews:
|
||||||
review_status = (
|
review_status = (
|
||||||
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
|
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
|
||||||
)
|
)
|
||||||
|
# If this review requested auto-approval, don't allow data modifications
|
||||||
|
reviewed_data = None if review.auto_approve_future else review.reviewed_data
|
||||||
review_decisions[review.node_exec_id] = (
|
review_decisions[review.node_exec_id] = (
|
||||||
review_status,
|
review_status,
|
||||||
review.reviewed_data,
|
reviewed_data,
|
||||||
review.message,
|
review.message,
|
||||||
)
|
)
|
||||||
|
auto_approve_requests[review.node_exec_id] = review.auto_approve_future
|
||||||
|
|
||||||
# Process all reviews
|
# Process all reviews
|
||||||
updated_reviews = await process_all_reviews_for_execution(
|
updated_reviews = await process_all_reviews_for_execution(
|
||||||
@@ -145,6 +208,87 @@ async def process_review_action(
|
|||||||
review_decisions=review_decisions,
|
review_decisions=review_decisions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create auto-approval records for approved reviews that requested it
|
||||||
|
# Deduplicate by node_id to avoid race conditions when multiple reviews
|
||||||
|
# for the same node are processed in parallel
|
||||||
|
async def create_auto_approval_for_node(
|
||||||
|
node_id: str, review_result
|
||||||
|
) -> tuple[str, bool]:
|
||||||
|
"""
|
||||||
|
Create auto-approval record for a node.
|
||||||
|
Returns (node_id, success) tuple for tracking failures.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await create_auto_approval_record(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=review_result.graph_exec_id,
|
||||||
|
graph_id=review_result.graph_id,
|
||||||
|
graph_version=review_result.graph_version,
|
||||||
|
node_id=node_id,
|
||||||
|
payload=review_result.payload,
|
||||||
|
)
|
||||||
|
return (node_id, True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create auto-approval record for node {node_id}",
|
||||||
|
exc_info=e,
|
||||||
|
)
|
||||||
|
return (node_id, False)
|
||||||
|
|
||||||
|
# Collect node_exec_ids that need auto-approval
|
||||||
|
node_exec_ids_needing_auto_approval = [
|
||||||
|
node_exec_id
|
||||||
|
for node_exec_id, review_result in updated_reviews.items()
|
||||||
|
if review_result.status == ReviewStatus.APPROVED
|
||||||
|
and auto_approve_requests.get(node_exec_id, False)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Batch-fetch node executions to get node_ids
|
||||||
|
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||||
|
if node_exec_ids_needing_auto_approval:
|
||||||
|
from backend.data.execution import get_node_executions
|
||||||
|
|
||||||
|
node_execs = await get_node_executions(
|
||||||
|
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||||
|
)
|
||||||
|
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
||||||
|
|
||||||
|
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||||
|
node_exec = node_exec_map.get(node_exec_id)
|
||||||
|
if node_exec:
|
||||||
|
review_result = updated_reviews[node_exec_id]
|
||||||
|
# Use the first approved review for this node (deduplicate by node_id)
|
||||||
|
if node_exec.node_id not in nodes_needing_auto_approval:
|
||||||
|
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create auto-approval record for {node_exec_id}: "
|
||||||
|
f"Node execution not found. This may indicate a race condition "
|
||||||
|
f"or data inconsistency."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||||
|
auto_approval_results = await asyncio.gather(
|
||||||
|
*[
|
||||||
|
create_auto_approval_for_node(node_id, review_result)
|
||||||
|
for node_id, review_result in nodes_needing_auto_approval.items()
|
||||||
|
],
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count auto-approval failures
|
||||||
|
auto_approval_failed_count = 0
|
||||||
|
for result in auto_approval_results:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
# Unexpected exception during auto-approval creation
|
||||||
|
auto_approval_failed_count += 1
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected exception during auto-approval creation: {result}"
|
||||||
|
)
|
||||||
|
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||||
|
# Auto-approval creation failed (returned False)
|
||||||
|
auto_approval_failed_count += 1
|
||||||
|
|
||||||
# Count results
|
# Count results
|
||||||
approved_count = sum(
|
approved_count = sum(
|
||||||
1
|
1
|
||||||
@@ -157,30 +301,53 @@ async def process_review_action(
|
|||||||
if review.status == ReviewStatus.REJECTED
|
if review.status == ReviewStatus.REJECTED
|
||||||
)
|
)
|
||||||
|
|
||||||
# Resume execution if we processed some reviews
|
# Resume execution only if ALL pending reviews for this execution have been processed
|
||||||
if updated_reviews:
|
if updated_reviews:
|
||||||
# Get graph execution ID from any processed review
|
|
||||||
first_review = next(iter(updated_reviews.values()))
|
|
||||||
graph_exec_id = first_review.graph_exec_id
|
|
||||||
|
|
||||||
# Check if any pending reviews remain for this execution
|
|
||||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||||
|
|
||||||
if not still_has_pending:
|
if not still_has_pending:
|
||||||
# Resume execution
|
# Get the graph_id from any processed review
|
||||||
|
first_review = next(iter(updated_reviews.values()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Fetch user and settings to build complete execution context
|
||||||
|
user = await get_user_by_id(user_id)
|
||||||
|
settings = await get_graph_settings(
|
||||||
|
user_id=user_id, graph_id=first_review.graph_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Preserve user's timezone preference when resuming execution
|
||||||
|
user_timezone = (
|
||||||
|
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_context = ExecutionContext(
|
||||||
|
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||||
|
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||||
|
user_timezone=user_timezone,
|
||||||
|
)
|
||||||
|
|
||||||
await add_graph_execution(
|
await add_graph_execution(
|
||||||
graph_id=first_review.graph_id,
|
graph_id=first_review.graph_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
graph_exec_id=graph_exec_id,
|
graph_exec_id=graph_exec_id,
|
||||||
|
execution_context=execution_context,
|
||||||
)
|
)
|
||||||
logger.info(f"Resumed execution {graph_exec_id}")
|
logger.info(f"Resumed execution {graph_exec_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to resume execution {graph_exec_id}: {str(e)}")
|
logger.error(f"Failed to resume execution {graph_exec_id}: {str(e)}")
|
||||||
|
|
||||||
|
# Build error message if auto-approvals failed
|
||||||
|
error_message = None
|
||||||
|
if auto_approval_failed_count > 0:
|
||||||
|
error_message = (
|
||||||
|
f"{auto_approval_failed_count} auto-approval setting(s) could not be saved. "
|
||||||
|
f"You may need to manually approve these reviews in future executions."
|
||||||
|
)
|
||||||
|
|
||||||
return ReviewResponse(
|
return ReviewResponse(
|
||||||
approved_count=approved_count,
|
approved_count=approved_count,
|
||||||
rejected_count=rejected_count,
|
rejected_count=rejected_count,
|
||||||
failed_count=0,
|
failed_count=auto_approval_failed_count,
|
||||||
error=None,
|
error=error_message,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -583,7 +583,13 @@ async def update_library_agent(
|
|||||||
)
|
)
|
||||||
update_fields["isDeleted"] = is_deleted
|
update_fields["isDeleted"] = is_deleted
|
||||||
if settings is not None:
|
if settings is not None:
|
||||||
update_fields["settings"] = SafeJson(settings.model_dump())
|
existing_agent = await get_library_agent(id=library_agent_id, user_id=user_id)
|
||||||
|
current_settings_dict = (
|
||||||
|
existing_agent.settings.model_dump() if existing_agent.settings else {}
|
||||||
|
)
|
||||||
|
new_settings = settings.model_dump(exclude_unset=True)
|
||||||
|
merged_settings = {**current_settings_dict, **new_settings}
|
||||||
|
update_fields["settings"] = SafeJson(merged_settings)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# If graph_version is provided, update to that specific version
|
# If graph_version is provided, update to that specific version
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import AsyncGenerator
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||||
from prisma.enums import APIKeyPermission
|
from prisma.enums import APIKeyPermission
|
||||||
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken
|
||||||
@@ -38,13 +39,13 @@ keysmith = APIKeySmith()
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(scope="session")
|
||||||
def test_user_id() -> str:
|
def test_user_id() -> str:
|
||||||
"""Test user ID for OAuth tests."""
|
"""Test user ID for OAuth tests."""
|
||||||
return str(uuid.uuid4())
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||||
async def test_user(server, test_user_id: str):
|
async def test_user(server, test_user_id: str):
|
||||||
"""Create a test user in the database."""
|
"""Create a test user in the database."""
|
||||||
await PrismaUser.prisma().create(
|
await PrismaUser.prisma().create(
|
||||||
@@ -67,7 +68,7 @@ async def test_user(server, test_user_id: str):
|
|||||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def test_oauth_app(test_user: str):
|
async def test_oauth_app(test_user: str):
|
||||||
"""Create a test OAuth application in the database."""
|
"""Create a test OAuth application in the database."""
|
||||||
app_id = str(uuid.uuid4())
|
app_id = str(uuid.uuid4())
|
||||||
@@ -122,7 +123,7 @@ def pkce_credentials() -> tuple[str, str]:
|
|||||||
return generate_pkce()
|
return generate_pkce()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
|
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||||
"""
|
"""
|
||||||
Create an async HTTP client that talks directly to the FastAPI app.
|
Create an async HTTP client that talks directly to the FastAPI app.
|
||||||
@@ -287,7 +288,7 @@ async def test_authorize_invalid_client_returns_error(
|
|||||||
assert query_params["error"][0] == "invalid_client"
|
assert query_params["error"][0] == "invalid_client"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def inactive_oauth_app(test_user: str):
|
async def inactive_oauth_app(test_user: str):
|
||||||
"""Create an inactive test OAuth application in the database."""
|
"""Create an inactive test OAuth application in the database."""
|
||||||
app_id = str(uuid.uuid4())
|
app_id = str(uuid.uuid4())
|
||||||
@@ -1004,7 +1005,7 @@ async def test_token_refresh_revoked(
|
|||||||
assert "revoked" in response.json()["detail"].lower()
|
assert "revoked" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest_asyncio.fixture
|
||||||
async def other_oauth_app(test_user: str):
|
async def other_oauth_app(test_user: str):
|
||||||
"""Create a second OAuth application for cross-app tests."""
|
"""Create a second OAuth application for cross-app tests."""
|
||||||
app_id = str(uuid.uuid4())
|
app_id = str(uuid.uuid4())
|
||||||
|
|||||||
@@ -188,6 +188,10 @@ class BlockHandler(ContentHandler):
|
|||||||
try:
|
try:
|
||||||
block_instance = block_cls()
|
block_instance = block_cls()
|
||||||
|
|
||||||
|
# Skip disabled blocks - they shouldn't be indexed
|
||||||
|
if block_instance.disabled:
|
||||||
|
continue
|
||||||
|
|
||||||
# Build searchable text from block metadata
|
# Build searchable text from block metadata
|
||||||
parts = []
|
parts = []
|
||||||
if hasattr(block_instance, "name") and block_instance.name:
|
if hasattr(block_instance, "name") and block_instance.name:
|
||||||
@@ -248,12 +252,19 @@ class BlockHandler(ContentHandler):
|
|||||||
from backend.data.block import get_blocks
|
from backend.data.block import get_blocks
|
||||||
|
|
||||||
all_blocks = get_blocks()
|
all_blocks = get_blocks()
|
||||||
total_blocks = len(all_blocks)
|
|
||||||
|
# Filter out disabled blocks - they're not indexed
|
||||||
|
enabled_block_ids = [
|
||||||
|
block_id
|
||||||
|
for block_id, block_cls in all_blocks.items()
|
||||||
|
if not block_cls().disabled
|
||||||
|
]
|
||||||
|
total_blocks = len(enabled_block_ids)
|
||||||
|
|
||||||
if total_blocks == 0:
|
if total_blocks == 0:
|
||||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||||
|
|
||||||
block_ids = list(all_blocks.keys())
|
block_ids = enabled_block_ids
|
||||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||||
|
|
||||||
embedded_result = await query_raw_with_schema(
|
embedded_result = await query_raw_with_schema(
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ async def test_block_handler_get_missing_items(mocker):
|
|||||||
mock_block_instance.name = "Calculator Block"
|
mock_block_instance.name = "Calculator Block"
|
||||||
mock_block_instance.description = "Performs calculations"
|
mock_block_instance.description = "Performs calculations"
|
||||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||||
|
mock_block_instance.disabled = False
|
||||||
mock_block_instance.input_schema.model_json_schema.return_value = {
|
mock_block_instance.input_schema.model_json_schema.return_value = {
|
||||||
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
||||||
}
|
}
|
||||||
@@ -116,11 +117,18 @@ async def test_block_handler_get_stats(mocker):
|
|||||||
"""Test BlockHandler returns correct stats."""
|
"""Test BlockHandler returns correct stats."""
|
||||||
handler = BlockHandler()
|
handler = BlockHandler()
|
||||||
|
|
||||||
# Mock get_blocks
|
# Mock get_blocks - each block class returns an instance with disabled=False
|
||||||
|
def make_mock_block_class():
|
||||||
|
mock_class = MagicMock()
|
||||||
|
mock_instance = MagicMock()
|
||||||
|
mock_instance.disabled = False
|
||||||
|
mock_class.return_value = mock_instance
|
||||||
|
return mock_class
|
||||||
|
|
||||||
mock_blocks = {
|
mock_blocks = {
|
||||||
"block-1": MagicMock(),
|
"block-1": make_mock_block_class(),
|
||||||
"block-2": MagicMock(),
|
"block-2": make_mock_block_class(),
|
||||||
"block-3": MagicMock(),
|
"block-3": make_mock_block_class(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock embedded count query (2 blocks have embeddings)
|
# Mock embedded count query (2 blocks have embeddings)
|
||||||
@@ -309,6 +317,7 @@ async def test_block_handler_handles_missing_attributes():
|
|||||||
mock_block_class = MagicMock()
|
mock_block_class = MagicMock()
|
||||||
mock_block_instance = MagicMock()
|
mock_block_instance = MagicMock()
|
||||||
mock_block_instance.name = "Minimal Block"
|
mock_block_instance.name = "Minimal Block"
|
||||||
|
mock_block_instance.disabled = False
|
||||||
# No description, categories, or schema
|
# No description, categories, or schema
|
||||||
del mock_block_instance.description
|
del mock_block_instance.description
|
||||||
del mock_block_instance.categories
|
del mock_block_instance.categories
|
||||||
@@ -342,6 +351,7 @@ async def test_block_handler_skips_failed_blocks():
|
|||||||
good_instance.name = "Good Block"
|
good_instance.name = "Good Block"
|
||||||
good_instance.description = "Works fine"
|
good_instance.description = "Works fine"
|
||||||
good_instance.categories = []
|
good_instance.categories = []
|
||||||
|
good_instance.disabled = False
|
||||||
good_block.return_value = good_instance
|
good_block.return_value = good_instance
|
||||||
|
|
||||||
bad_block = MagicMock()
|
bad_block = MagicMock()
|
||||||
|
|||||||
@@ -1552,7 +1552,7 @@ async def review_store_submission(
|
|||||||
|
|
||||||
# Generate embedding for approved listing (blocking - admin operation)
|
# Generate embedding for approved listing (blocking - admin operation)
|
||||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||||
embedding_success = await ensure_embedding(
|
await ensure_embedding(
|
||||||
version_id=store_listing_version_id,
|
version_id=store_listing_version_id,
|
||||||
name=store_listing_version.name,
|
name=store_listing_version.name,
|
||||||
description=store_listing_version.description,
|
description=store_listing_version.description,
|
||||||
@@ -1560,12 +1560,6 @@ async def review_store_submission(
|
|||||||
categories=store_listing_version.categories or [],
|
categories=store_listing_version.categories or [],
|
||||||
tx=tx,
|
tx=tx,
|
||||||
)
|
)
|
||||||
if not embedding_success:
|
|
||||||
raise ValueError(
|
|
||||||
f"Failed to generate embedding for listing {store_listing_version_id}. "
|
|
||||||
"This is likely due to OpenAI API being unavailable. "
|
|
||||||
"Please try again later or contact support if the issue persists."
|
|
||||||
)
|
|
||||||
|
|
||||||
await prisma.models.StoreListing.prisma(tx).update(
|
await prisma.models.StoreListing.prisma(tx).update(
|
||||||
where={"id": store_listing_version.StoreListing.id},
|
where={"id": store_listing_version.StoreListing.id},
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from backend.util.json import dumps
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# OpenAI embedding model configuration
|
# OpenAI embedding model configuration
|
||||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||||
# Embedding dimension for the model above
|
# Embedding dimension for the model above
|
||||||
@@ -63,18 +62,15 @@ def build_searchable_text(
|
|||||||
return " ".join(parts)
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
async def generate_embedding(text: str) -> list[float] | None:
|
async def generate_embedding(text: str) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Generate embedding for text using OpenAI API.
|
Generate embedding for text using OpenAI API.
|
||||||
|
|
||||||
Returns None if embedding generation fails.
|
Raises exceptions on failure - caller should handle.
|
||||||
Fail-fast: no retries to maintain consistency with approval flow.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
client = get_openai_client()
|
client = get_openai_client()
|
||||||
if not client:
|
if not client:
|
||||||
logger.error("openai_internal_api_key not set, cannot generate embedding")
|
raise RuntimeError("openai_internal_api_key not set, cannot generate embedding")
|
||||||
return None
|
|
||||||
|
|
||||||
# Truncate text to token limit using tiktoken
|
# Truncate text to token limit using tiktoken
|
||||||
# Character-based truncation is insufficient because token ratios vary by content type
|
# Character-based truncation is insufficient because token ratios vary by content type
|
||||||
@@ -103,10 +99,6 @@ async def generate_embedding(text: str) -> list[float] | None:
|
|||||||
)
|
)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to generate embedding: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def store_embedding(
|
async def store_embedding(
|
||||||
version_id: str,
|
version_id: str,
|
||||||
@@ -144,8 +136,9 @@ async def store_content_embedding(
|
|||||||
|
|
||||||
New function for unified content embedding storage.
|
New function for unified content embedding storage.
|
||||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||||
|
|
||||||
|
Raises exceptions on failure - caller should handle.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
client = tx if tx else prisma.get_client()
|
client = tx if tx else prisma.get_client()
|
||||||
|
|
||||||
# Convert embedding to PostgreSQL vector format
|
# Convert embedding to PostgreSQL vector format
|
||||||
@@ -183,10 +176,6 @@ async def store_content_embedding(
|
|||||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
@@ -217,8 +206,9 @@ async def get_content_embedding(
|
|||||||
|
|
||||||
New function for unified content embedding retrieval.
|
New function for unified content embedding retrieval.
|
||||||
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
||||||
|
|
||||||
|
Raises exceptions on failure - caller should handle.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
result = await query_raw_with_schema(
|
result = await query_raw_with_schema(
|
||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
@@ -242,10 +232,6 @@ async def get_content_embedding(
|
|||||||
return result[0]
|
return result[0]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def ensure_embedding(
|
async def ensure_embedding(
|
||||||
version_id: str,
|
version_id: str,
|
||||||
@@ -272,9 +258,10 @@ async def ensure_embedding(
|
|||||||
tx: Optional transaction client
|
tx: Optional transaction client
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if embedding exists/was created, False on failure
|
True if embedding exists/was created
|
||||||
|
|
||||||
|
Raises exceptions on failure - caller should handle.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
# Check if embedding already exists
|
# Check if embedding already exists
|
||||||
if not force:
|
if not force:
|
||||||
existing = await get_embedding(version_id)
|
existing = await get_embedding(version_id)
|
||||||
@@ -283,15 +270,10 @@ async def ensure_embedding(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# Build searchable text for embedding
|
# Build searchable text for embedding
|
||||||
searchable_text = build_searchable_text(
|
searchable_text = build_searchable_text(name, description, sub_heading, categories)
|
||||||
name, description, sub_heading, categories
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate new embedding
|
# Generate new embedding
|
||||||
embedding = await generate_embedding(searchable_text)
|
embedding = await generate_embedding(searchable_text)
|
||||||
if embedding is None:
|
|
||||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Store the embedding with metadata using new function
|
# Store the embedding with metadata using new function
|
||||||
metadata = {
|
metadata = {
|
||||||
@@ -309,10 +291,6 @@ async def ensure_embedding(
|
|||||||
tx=tx,
|
tx=tx,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_embedding(version_id: str) -> bool:
|
async def delete_embedding(version_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -521,6 +499,24 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
|||||||
success = sum(1 for result in results if result is True)
|
success = sum(1 for result in results if result is True)
|
||||||
failed = len(results) - success
|
failed = len(results) - success
|
||||||
|
|
||||||
|
# Aggregate unique errors to avoid Sentry spam
|
||||||
|
if failed > 0:
|
||||||
|
# Group errors by type and message
|
||||||
|
error_summary: dict[str, int] = {}
|
||||||
|
for result in results:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
error_key = f"{type(result).__name__}: {str(result)}"
|
||||||
|
error_summary[error_key] = error_summary.get(error_key, 0) + 1
|
||||||
|
|
||||||
|
# Log aggregated error summary
|
||||||
|
error_details = ", ".join(
|
||||||
|
f"{error} ({count}x)" for error, count in error_summary.items()
|
||||||
|
)
|
||||||
|
logger.error(
|
||||||
|
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
|
||||||
|
f"Errors: {error_details}"
|
||||||
|
)
|
||||||
|
|
||||||
results_by_type[content_type.value] = {
|
results_by_type[content_type.value] = {
|
||||||
"processed": len(missing_items),
|
"processed": len(missing_items),
|
||||||
"success": success,
|
"success": success,
|
||||||
@@ -557,11 +553,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def embed_query(query: str) -> list[float] | None:
|
async def embed_query(query: str) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Generate embedding for a search query.
|
Generate embedding for a search query.
|
||||||
|
|
||||||
Same as generate_embedding but with clearer intent.
|
Same as generate_embedding but with clearer intent.
|
||||||
|
Raises exceptions on failure - caller should handle.
|
||||||
"""
|
"""
|
||||||
return await generate_embedding(query)
|
return await generate_embedding(query)
|
||||||
|
|
||||||
@@ -594,25 +591,19 @@ async def ensure_content_embedding(
|
|||||||
tx: Optional transaction client
|
tx: Optional transaction client
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if embedding exists/was created, False on failure
|
True if embedding exists/was created
|
||||||
|
|
||||||
|
Raises exceptions on failure - caller should handle.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
# Check if embedding already exists
|
# Check if embedding already exists
|
||||||
if not force:
|
if not force:
|
||||||
existing = await get_content_embedding(content_type, content_id, user_id)
|
existing = await get_content_embedding(content_type, content_id, user_id)
|
||||||
if existing and existing.get("embedding"):
|
if existing and existing.get("embedding"):
|
||||||
logger.debug(
|
logger.debug(f"Embedding for {content_type}:{content_id} already exists")
|
||||||
f"Embedding for {content_type}:{content_id} already exists"
|
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Generate new embedding
|
# Generate new embedding
|
||||||
embedding = await generate_embedding(searchable_text)
|
embedding = await generate_embedding(searchable_text)
|
||||||
if embedding is None:
|
|
||||||
logger.warning(
|
|
||||||
f"Could not generate embedding for {content_type}:{content_id}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Store the embedding
|
# Store the embedding
|
||||||
return await store_content_embedding(
|
return await store_content_embedding(
|
||||||
@@ -625,10 +616,6 @@ async def ensure_content_embedding(
|
|||||||
tx=tx,
|
tx=tx,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -854,9 +841,8 @@ async def semantic_search(
|
|||||||
limit = 100
|
limit = 100
|
||||||
|
|
||||||
# Generate query embedding
|
# Generate query embedding
|
||||||
|
try:
|
||||||
query_embedding = await embed_query(query)
|
query_embedding = await embed_query(query)
|
||||||
|
|
||||||
if query_embedding is not None:
|
|
||||||
# Semantic search with embeddings
|
# Semantic search with embeddings
|
||||||
embedding_str = embedding_to_vector_string(query_embedding)
|
embedding_str = embedding_to_vector_string(query_embedding)
|
||||||
|
|
||||||
@@ -907,7 +893,6 @@ async def semantic_search(
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
results = await query_raw_with_schema(sql, *params)
|
results = await query_raw_with_schema(sql, *params)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
@@ -920,11 +905,9 @@ async def semantic_search(
|
|||||||
for row in results
|
for row in results
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Semantic search failed: {e}")
|
logger.warning(f"Semantic search failed, falling back to lexical search: {e}")
|
||||||
# Fall through to lexical search below
|
|
||||||
|
|
||||||
# Fallback to lexical search if embeddings unavailable
|
# Fallback to lexical search if embeddings unavailable
|
||||||
logger.warning("Falling back to lexical search (embeddings unavailable)")
|
|
||||||
|
|
||||||
params_lexical: list[Any] = [limit]
|
params_lexical: list[Any] = [limit]
|
||||||
user_filter = ""
|
user_filter = ""
|
||||||
|
|||||||
@@ -298,7 +298,9 @@ async def test_schema_handling_error_cases():
|
|||||||
mock_client.execute_raw.side_effect = Exception("Database error")
|
mock_client.execute_raw.side_effect = Exception("Database error")
|
||||||
mock_get_client.return_value = mock_client
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
result = await embeddings.store_content_embedding(
|
# Should raise exception on error
|
||||||
|
with pytest.raises(Exception, match="Database error"):
|
||||||
|
await embeddings.store_content_embedding(
|
||||||
content_type=ContentType.STORE_AGENT,
|
content_type=ContentType.STORE_AGENT,
|
||||||
content_id="test-id",
|
content_id="test-id",
|
||||||
embedding=[0.1] * EMBEDDING_DIM,
|
embedding=[0.1] * EMBEDDING_DIM,
|
||||||
@@ -307,9 +309,6 @@ async def test_schema_handling_error_cases():
|
|||||||
user_id=None,
|
user_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should return False on error, not raise
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v", "-s"])
|
pytest.main([__file__, "-v", "-s"])
|
||||||
|
|||||||
@@ -80,9 +80,8 @@ async def test_generate_embedding_no_api_key():
|
|||||||
) as mock_get_client:
|
) as mock_get_client:
|
||||||
mock_get_client.return_value = None
|
mock_get_client.return_value = None
|
||||||
|
|
||||||
result = await embeddings.generate_embedding("test text")
|
with pytest.raises(RuntimeError, match="openai_internal_api_key not set"):
|
||||||
|
await embeddings.generate_embedding("test text")
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
@@ -97,9 +96,8 @@ async def test_generate_embedding_api_error():
|
|||||||
) as mock_get_client:
|
) as mock_get_client:
|
||||||
mock_get_client.return_value = mock_client
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
result = await embeddings.generate_embedding("test text")
|
with pytest.raises(Exception, match="API Error"):
|
||||||
|
await embeddings.generate_embedding("test text")
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
@@ -173,12 +171,11 @@ async def test_store_embedding_database_error(mocker):
|
|||||||
|
|
||||||
embedding = [0.1, 0.2, 0.3]
|
embedding = [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
result = await embeddings.store_embedding(
|
with pytest.raises(Exception, match="Database error"):
|
||||||
|
await embeddings.store_embedding(
|
||||||
version_id="test-version-id", embedding=embedding, tx=mock_client
|
version_id="test-version-id", embedding=embedding, tx=mock_client
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_get_embedding_success():
|
async def test_get_embedding_success():
|
||||||
@@ -277,9 +274,10 @@ async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
|
|||||||
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
|
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
|
||||||
"""Test ensure_embedding when generation fails."""
|
"""Test ensure_embedding when generation fails."""
|
||||||
mock_get.return_value = None
|
mock_get.return_value = None
|
||||||
mock_generate.return_value = None
|
mock_generate.side_effect = Exception("Generation failed")
|
||||||
|
|
||||||
result = await embeddings.ensure_embedding(
|
with pytest.raises(Exception, match="Generation failed"):
|
||||||
|
await embeddings.ensure_embedding(
|
||||||
version_id="test-id",
|
version_id="test-id",
|
||||||
name="Test",
|
name="Test",
|
||||||
description="Test description",
|
description="Test description",
|
||||||
@@ -287,8 +285,6 @@ async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
|
|||||||
categories=["test"],
|
categories=["test"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
async def test_get_embedding_stats():
|
async def test_get_embedding_stats():
|
||||||
|
|||||||
@@ -186,13 +186,12 @@ async def unified_hybrid_search(
|
|||||||
|
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
# Generate query embedding
|
# Generate query embedding with graceful degradation
|
||||||
|
try:
|
||||||
query_embedding = await embed_query(query)
|
query_embedding = await embed_query(query)
|
||||||
|
except Exception as e:
|
||||||
# Graceful degradation if embedding unavailable
|
|
||||||
if query_embedding is None or not query_embedding:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to generate query embedding - falling back to lexical-only search. "
|
f"Failed to generate query embedding - falling back to lexical-only search: {e}. "
|
||||||
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
|
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
|
||||||
)
|
)
|
||||||
query_embedding = [0.0] * EMBEDDING_DIM
|
query_embedding = [0.0] * EMBEDDING_DIM
|
||||||
@@ -464,13 +463,12 @@ async def hybrid_search(
|
|||||||
|
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
# Generate query embedding
|
# Generate query embedding with graceful degradation
|
||||||
|
try:
|
||||||
query_embedding = await embed_query(query)
|
query_embedding = await embed_query(query)
|
||||||
|
except Exception as e:
|
||||||
# Graceful degradation
|
|
||||||
if query_embedding is None or not query_embedding:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to generate query embedding - falling back to lexical-only search."
|
f"Failed to generate query embedding - falling back to lexical-only search: {e}"
|
||||||
)
|
)
|
||||||
query_embedding = [0.0] * EMBEDDING_DIM
|
query_embedding = [0.0] * EMBEDDING_DIM
|
||||||
total_non_semantic = (
|
total_non_semantic = (
|
||||||
|
|||||||
@@ -172,8 +172,8 @@ async def test_hybrid_search_without_embeddings():
|
|||||||
with patch(
|
with patch(
|
||||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||||
) as mock_query:
|
) as mock_query:
|
||||||
# Simulate embedding failure
|
# Simulate embedding failure by raising exception
|
||||||
mock_embed.return_value = None
|
mock_embed.side_effect = Exception("Embedding generation failed")
|
||||||
mock_query.return_value = mock_results
|
mock_query.return_value = mock_results
|
||||||
|
|
||||||
# Should NOT raise - graceful degradation
|
# Should NOT raise - graceful degradation
|
||||||
@@ -613,7 +613,9 @@ async def test_unified_hybrid_search_graceful_degradation():
|
|||||||
"backend.api.features.store.hybrid_search.embed_query"
|
"backend.api.features.store.hybrid_search.embed_query"
|
||||||
) as mock_embed:
|
) as mock_embed:
|
||||||
mock_query.return_value = mock_results
|
mock_query.return_value = mock_results
|
||||||
mock_embed.return_value = None # Embedding failure
|
mock_embed.side_effect = Exception(
|
||||||
|
"Embedding generation failed"
|
||||||
|
) # Embedding failure
|
||||||
|
|
||||||
# Should NOT raise - graceful degradation
|
# Should NOT raise - graceful degradation
|
||||||
results, total = await unified_hybrid_search(
|
results, total = await unified_hybrid_search(
|
||||||
|
|||||||
@@ -265,9 +265,13 @@ async def get_onboarding_agents(
|
|||||||
"/onboarding/enabled",
|
"/onboarding/enabled",
|
||||||
summary="Is onboarding enabled",
|
summary="Is onboarding enabled",
|
||||||
tags=["onboarding", "public"],
|
tags=["onboarding", "public"],
|
||||||
dependencies=[Security(requires_user)],
|
|
||||||
)
|
)
|
||||||
async def is_onboarding_enabled() -> bool:
|
async def is_onboarding_enabled(
|
||||||
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
) -> bool:
|
||||||
|
# If chat is enabled for user, skip legacy onboarding
|
||||||
|
if await is_feature_enabled(Flag.CHAT, user_id, False):
|
||||||
|
return False
|
||||||
return await onboarding_enabled()
|
return await onboarding_enabled()
|
||||||
|
|
||||||
|
|
||||||
@@ -364,6 +368,8 @@ async def execute_graph_block(
|
|||||||
obj = get_block(block_id)
|
obj = get_block(block_id)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||||
|
if obj.disabled:
|
||||||
|
raise HTTPException(status_code=403, detail=f"Block #{block_id} is disabled.")
|
||||||
|
|
||||||
user = await get_user_by_id(user_id)
|
user = await get_user_by_id(user_id)
|
||||||
if not user:
|
if not user:
|
||||||
|
|||||||
@@ -138,6 +138,7 @@ def test_execute_graph_block(
|
|||||||
"""Test execute block endpoint"""
|
"""Test execute block endpoint"""
|
||||||
# Mock block
|
# Mock block
|
||||||
mock_block = Mock()
|
mock_block = Mock()
|
||||||
|
mock_block.disabled = False
|
||||||
|
|
||||||
async def mock_execute(*args, **kwargs):
|
async def mock_execute(*args, **kwargs):
|
||||||
yield "output1", {"data": "result1"}
|
yield "output1", {"data": "result1"}
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
# Workspace API feature module
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
"""
|
||||||
|
Workspace API routes for managing user file storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Annotated
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from backend.data.workspace import get_workspace, get_workspace_file
|
||||||
|
from backend.util.workspace_storage import get_workspace_storage
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_filename_for_header(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize filename for Content-Disposition header to prevent header injection.
|
||||||
|
|
||||||
|
Removes/replaces characters that could break the header or inject new headers.
|
||||||
|
Uses RFC5987 encoding for non-ASCII characters.
|
||||||
|
"""
|
||||||
|
# Remove CR, LF, and null bytes (header injection prevention)
|
||||||
|
sanitized = re.sub(r"[\r\n\x00]", "", filename)
|
||||||
|
# Escape quotes
|
||||||
|
sanitized = sanitized.replace('"', '\\"')
|
||||||
|
# For non-ASCII, use RFC5987 filename* parameter
|
||||||
|
# Check if filename has non-ASCII characters
|
||||||
|
try:
|
||||||
|
sanitized.encode("ascii")
|
||||||
|
return f'attachment; filename="{sanitized}"'
|
||||||
|
except UnicodeEncodeError:
|
||||||
|
# Use RFC5987 encoding for UTF-8 filenames
|
||||||
|
encoded = quote(sanitized, safe="")
|
||||||
|
return f"attachment; filename*=UTF-8''{encoded}"
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = fastapi.APIRouter(
|
||||||
|
dependencies=[fastapi.Security(requires_user)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_streaming_response(content: bytes, file) -> Response:
|
||||||
|
"""Create a streaming response for file content."""
|
||||||
|
return Response(
|
||||||
|
content=content,
|
||||||
|
media_type=file.mimeType,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||||
|
"Content-Length": str(len(content)),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_file_download_response(file) -> Response:
|
||||||
|
"""
|
||||||
|
Create a download response for a workspace file.
|
||||||
|
|
||||||
|
Handles both local storage (direct streaming) and GCS (signed URL redirect
|
||||||
|
with fallback to streaming).
|
||||||
|
"""
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
|
||||||
|
# For local storage, stream the file directly
|
||||||
|
if file.storagePath.startswith("local://"):
|
||||||
|
content = await storage.retrieve(file.storagePath)
|
||||||
|
return _create_streaming_response(content, file)
|
||||||
|
|
||||||
|
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||||
|
try:
|
||||||
|
url = await storage.get_download_url(file.storagePath, expires_in=300)
|
||||||
|
# If we got back an API path (fallback), stream directly instead
|
||||||
|
if url.startswith("/api/"):
|
||||||
|
content = await storage.retrieve(file.storagePath)
|
||||||
|
return _create_streaming_response(content, file)
|
||||||
|
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||||
|
except Exception as e:
|
||||||
|
# Log the signed URL failure with context
|
||||||
|
logger.error(
|
||||||
|
f"Failed to get signed URL for file {file.id} "
|
||||||
|
f"(storagePath={file.storagePath}): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# Fall back to streaming directly from GCS
|
||||||
|
try:
|
||||||
|
content = await storage.retrieve(file.storagePath)
|
||||||
|
return _create_streaming_response(content, file)
|
||||||
|
except Exception as fallback_error:
|
||||||
|
logger.error(
|
||||||
|
f"Fallback streaming also failed for file {file.id} "
|
||||||
|
f"(storagePath={file.storagePath}): {fallback_error}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/files/{file_id}/download",
|
||||||
|
summary="Download file by ID",
|
||||||
|
)
|
||||||
|
async def download_file(
|
||||||
|
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||||
|
file_id: str,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
Download a file by its ID.
|
||||||
|
|
||||||
|
Returns the file content directly or redirects to a signed URL for GCS.
|
||||||
|
"""
|
||||||
|
workspace = await get_workspace(user_id)
|
||||||
|
if workspace is None:
|
||||||
|
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||||
|
|
||||||
|
file = await get_workspace_file(file_id, workspace.id)
|
||||||
|
if file is None:
|
||||||
|
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
return await _create_file_download_response(file)
|
||||||
@@ -32,6 +32,7 @@ import backend.api.features.postmark.postmark
|
|||||||
import backend.api.features.store.model
|
import backend.api.features.store.model
|
||||||
import backend.api.features.store.routes
|
import backend.api.features.store.routes
|
||||||
import backend.api.features.v1
|
import backend.api.features.v1
|
||||||
|
import backend.api.features.workspace.routes as workspace_routes
|
||||||
import backend.data.block
|
import backend.data.block
|
||||||
import backend.data.db
|
import backend.data.db
|
||||||
import backend.data.graph
|
import backend.data.graph
|
||||||
@@ -52,6 +53,7 @@ from backend.util.exceptions import (
|
|||||||
)
|
)
|
||||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||||
from backend.util.service import UnhealthyServiceError
|
from backend.util.service import UnhealthyServiceError
|
||||||
|
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||||
|
|
||||||
from .external.fastapi_app import external_api
|
from .external.fastapi_app import external_api
|
||||||
from .features.analytics import router as analytics_router
|
from .features.analytics import router as analytics_router
|
||||||
@@ -124,6 +126,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error shutting down cloud storage handler: {e}")
|
logger.warning(f"Error shutting down cloud storage handler: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await shutdown_workspace_storage()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error shutting down workspace storage: {e}")
|
||||||
|
|
||||||
await backend.data.db.disconnect()
|
await backend.data.db.disconnect()
|
||||||
|
|
||||||
|
|
||||||
@@ -315,6 +322,11 @@ app.include_router(
|
|||||||
tags=["v2", "chat"],
|
tags=["v2", "chat"],
|
||||||
prefix="/api/chat",
|
prefix="/api/chat",
|
||||||
)
|
)
|
||||||
|
app.include_router(
|
||||||
|
workspace_routes.router,
|
||||||
|
tags=["v2", "workspace"],
|
||||||
|
prefix="/api/workspace",
|
||||||
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.oauth.router,
|
backend.api.features.oauth.router,
|
||||||
tags=["oauth"],
|
tags=["oauth"],
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block):
|
|||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
test_output=[
|
test_output=[
|
||||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
|
# Use data URI to avoid HTTP requests during tests
|
||||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||||
"https://replicate.delivery/generated-image.jpg"
|
""
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
@@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
@@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block):
|
|||||||
processed_images = await asyncio.gather(
|
processed_images = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
store_media_file(
|
store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=img,
|
file=img,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True,
|
return_format="for_external_api", # Get content for Replicate API
|
||||||
)
|
)
|
||||||
for img in input_data.images
|
for img in input_data.images
|
||||||
)
|
)
|
||||||
@@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block):
|
|||||||
aspect_ratio=input_data.aspect_ratio.value,
|
aspect_ratio=input_data.aspect_ratio.value,
|
||||||
output_format=input_data.output_format.value,
|
output_format=input_data.output_format.value,
|
||||||
)
|
)
|
||||||
yield "image_url", result
|
|
||||||
|
# Store the generated image to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=result,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", stored_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "error", str(e)
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient
|
|||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -13,6 +14,8 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.file import store_media_file
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
|
|
||||||
class ImageSize(str, Enum):
|
class ImageSize(str, Enum):
|
||||||
@@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block):
|
|||||||
test_output=[
|
test_output=[
|
||||||
(
|
(
|
||||||
"image_url",
|
"image_url",
|
||||||
"https://replicate.delivery/generated-image.webp",
|
# Test output is a data URI since we now store images
|
||||||
|
lambda x: x.startswith(""
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block):
|
|||||||
style_text = style_map.get(style, "")
|
style_text = style_map.get(style, "")
|
||||||
return f"{style_text} of" if style_text else ""
|
return f"{style_text} of" if style_text else ""
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
url = await self.generate_image(input_data, credentials)
|
url = await self.generate_image(input_data, credentials)
|
||||||
if url:
|
if url:
|
||||||
yield "image_url", url
|
# Store the generated image to the user's workspace/execution folder
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", stored_url
|
||||||
else:
|
else:
|
||||||
yield "error", "Image generation returned an empty result."
|
yield "error", "Image generation returned an empty result."
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -21,7 +22,9 @@ from backend.data.model import (
|
|||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.exceptions import BlockExecutionError
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
@@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
"voice": Voice.LILY,
|
"voice": Voice.LILY,
|
||||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/video.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/video.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create a new Webhook.site URL
|
# Create a new Webhook.site URL
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
@@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
)
|
)
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
logger.debug(f"Video ready: {video_url}")
|
logger.debug(f"Video ready: {video_url}")
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|
||||||
|
|
||||||
class AIAdMakerVideoCreatorBlock(Block):
|
class AIAdMakerVideoCreatorBlock(Block):
|
||||||
@@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/ad.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
raise RuntimeError("Failed to create video: No project ID returned")
|
raise RuntimeError("Failed to create video: No project ID returned")
|
||||||
|
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|
||||||
|
|
||||||
class AIScreenshotToVideoAdBlock(Block):
|
class AIScreenshotToVideoAdBlock(Block):
|
||||||
@@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
"script": "Amazing numbers!",
|
"script": "Amazing numbers!",
|
||||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/screenshot.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
raise RuntimeError("Failed to create video: No project ID returned")
|
raise RuntimeError("Failed to create video: No project ID returned")
|
||||||
|
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.sdk import (
|
from backend.sdk import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
Block,
|
Block,
|
||||||
@@ -17,6 +18,8 @@ from backend.sdk import (
|
|||||||
Requests,
|
Requests,
|
||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
|
from backend.util.file import store_media_file
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
from ._config import bannerbear
|
from ._config import bannerbear
|
||||||
|
|
||||||
@@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block):
|
|||||||
},
|
},
|
||||||
test_output=[
|
test_output=[
|
||||||
("success", True),
|
("success", True),
|
||||||
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
("uid", "test-uid-123"),
|
("uid", "test-uid-123"),
|
||||||
("status", "completed"),
|
("status", "completed"),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
|
# Use data URI to avoid HTTP requests during tests
|
||||||
"_make_api_request": lambda *args, **kwargs: {
|
"_make_api_request": lambda *args, **kwargs: {
|
||||||
"uid": "test-uid-123",
|
"uid": "test-uid-123",
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
|
"image_url": "",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
@@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block):
|
|||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Build the modifications array
|
# Build the modifications array
|
||||||
modifications = []
|
modifications = []
|
||||||
@@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block):
|
|||||||
|
|
||||||
# Synchronous request - image should be ready
|
# Synchronous request - image should be ready
|
||||||
yield "success", True
|
yield "success", True
|
||||||
yield "image_url", data.get("image_url", "")
|
|
||||||
|
# Store the generated image to workspace for persistence
|
||||||
|
image_url = data.get("image_url", "")
|
||||||
|
if image_url:
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(image_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", stored_url
|
||||||
|
else:
|
||||||
|
yield "image_url", ""
|
||||||
|
|
||||||
yield "uid", data.get("uid", "")
|
yield "uid", data.get("uid", "")
|
||||||
yield "status", data.get("status", "completed")
|
yield "status", data.get("status", "completed")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.type import MediaFileType, convert
|
from backend.util.type import MediaFileType, convert
|
||||||
@@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert
|
|||||||
class FileStoreBlock(Block):
|
class FileStoreBlock(Block):
|
||||||
class Input(BlockSchemaInput):
|
class Input(BlockSchemaInput):
|
||||||
file_in: MediaFileType = SchemaField(
|
file_in: MediaFileType = SchemaField(
|
||||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
description="The file to download and store. Can be a URL (https://...), data URI, or local path."
|
||||||
)
|
)
|
||||||
base_64: bool = SchemaField(
|
base_64: bool = SchemaField(
|
||||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
|
||||||
default=False,
|
default=False,
|
||||||
advanced=True,
|
advanced=True,
|
||||||
title="Produce Base64 Output",
|
title="Produce Base64 Output",
|
||||||
@@ -28,13 +29,18 @@ class FileStoreBlock(Block):
|
|||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
file_out: MediaFileType = SchemaField(
|
file_out: MediaFileType = SchemaField(
|
||||||
description="The relative path to the stored file in the temporary directory."
|
description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
|
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
|
||||||
description="Stores the input file in the temporary directory.",
|
description=(
|
||||||
|
"Downloads and stores a file from a URL, data URI, or local path. "
|
||||||
|
"Use this to fetch images, documents, or other files for processing. "
|
||||||
|
"In CoPilot: saves to workspace (use list_workspace_files to see it). "
|
||||||
|
"In graphs: outputs a data URI to pass to other blocks."
|
||||||
|
),
|
||||||
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
|
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
|
||||||
input_schema=FileStoreBlock.Input,
|
input_schema=FileStoreBlock.Input,
|
||||||
output_schema=FileStoreBlock.Output,
|
output_schema=FileStoreBlock.Output,
|
||||||
@@ -45,15 +51,18 @@ class FileStoreBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
# Determine return format based on user preference
|
||||||
|
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
|
||||||
|
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
|
||||||
|
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
|
||||||
|
|
||||||
yield "file_out", await store_media_file(
|
yield "file_out", await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.file_in,
|
file=input_data.file_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.base_64,
|
return_format=return_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -116,6 +125,7 @@ class PrintToConsoleBlock(Block):
|
|||||||
input_schema=PrintToConsoleBlock.Input,
|
input_schema=PrintToConsoleBlock.Input,
|
||||||
output_schema=PrintToConsoleBlock.Output,
|
output_schema=PrintToConsoleBlock.Output,
|
||||||
test_input={"text": "Hello, World!"},
|
test_input={"text": "Hello, World!"},
|
||||||
|
is_sensitive_action=True,
|
||||||
test_output=[
|
test_output=[
|
||||||
("output", "Hello, World!"),
|
("output", "Hello, World!"),
|
||||||
("status", "printed"),
|
("status", "printed"),
|
||||||
|
|||||||
659
autogpt_platform/backend/backend/blocks/claude_code.py
Normal file
659
autogpt_platform/backend/backend/blocks/claude_code.py
Normal file
@@ -0,0 +1,659 @@
|
|||||||
|
import json
|
||||||
|
import shlex
|
||||||
|
import uuid
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||||
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.model import (
|
||||||
|
APIKeyCredentials,
|
||||||
|
CredentialsField,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
SchemaField,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeCodeExecutionError(Exception):
|
||||||
|
"""Exception raised when Claude Code execution fails.
|
||||||
|
|
||||||
|
Carries the sandbox_id so it can be returned to the user for cleanup
|
||||||
|
when dispose_sandbox=False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, sandbox_id: str = ""):
|
||||||
|
super().__init__(message)
|
||||||
|
self.sandbox_id = sandbox_id
|
||||||
|
|
||||||
|
|
||||||
|
# Test credentials for E2B
|
||||||
|
TEST_E2B_CREDENTIALS = APIKeyCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="e2b",
|
||||||
|
api_key=SecretStr("mock-e2b-api-key"),
|
||||||
|
title="Mock E2B API key",
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
TEST_E2B_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_E2B_CREDENTIALS.provider,
|
||||||
|
"id": TEST_E2B_CREDENTIALS.id,
|
||||||
|
"type": TEST_E2B_CREDENTIALS.type,
|
||||||
|
"title": TEST_E2B_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test credentials for Anthropic
|
||||||
|
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
|
||||||
|
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
|
||||||
|
provider="anthropic",
|
||||||
|
api_key=SecretStr("mock-anthropic-api-key"),
|
||||||
|
title="Mock Anthropic API key",
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
|
||||||
|
"id": TEST_ANTHROPIC_CREDENTIALS.id,
|
||||||
|
"type": TEST_ANTHROPIC_CREDENTIALS.type,
|
||||||
|
"title": TEST_ANTHROPIC_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeCodeBlock(Block):
|
||||||
|
"""
|
||||||
|
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
|
||||||
|
|
||||||
|
Claude Code can create files, install tools, run commands, and perform complex
|
||||||
|
coding tasks autonomously within a secure sandbox environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Use base template - we'll install Claude Code ourselves for latest version
|
||||||
|
DEFAULT_TEMPLATE = "base"
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
e2b_credentials: CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.E2B], Literal["api_key"]
|
||||||
|
] = CredentialsField(
|
||||||
|
description=(
|
||||||
|
"API key for the E2B platform to create the sandbox. "
|
||||||
|
"Get one on the [e2b website](https://e2b.dev/docs)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
anthropic_credentials: CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
|
||||||
|
] = CredentialsField(
|
||||||
|
description=(
|
||||||
|
"API key for Anthropic to power Claude Code. "
|
||||||
|
"Get one at [Anthropic's website](https://console.anthropic.com)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"The task or instruction for Claude Code to execute. "
|
||||||
|
"Claude Code can create files, install packages, run commands, "
|
||||||
|
"and perform complex coding tasks."
|
||||||
|
),
|
||||||
|
placeholder="Create a hello world index.html file",
|
||||||
|
default="",
|
||||||
|
advanced=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
timeout: int = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Sandbox timeout in seconds. Claude Code tasks can take "
|
||||||
|
"a while, so set this appropriately for your task complexity. "
|
||||||
|
"Note: This only applies when creating a new sandbox. "
|
||||||
|
"When reconnecting to an existing sandbox via sandbox_id, "
|
||||||
|
"the original timeout is retained."
|
||||||
|
),
|
||||||
|
default=300, # 5 minutes default
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
setup_commands: list[str] = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Optional shell commands to run before executing Claude Code. "
|
||||||
|
"Useful for installing dependencies or setting up the environment."
|
||||||
|
),
|
||||||
|
default_factory=list,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
working_directory: str = SchemaField(
|
||||||
|
description="Working directory for Claude Code to operate in.",
|
||||||
|
default="/home/user",
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Session/continuation support
|
||||||
|
session_id: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Session ID to resume a previous conversation. "
|
||||||
|
"Leave empty for a new conversation. "
|
||||||
|
"Use the session_id from a previous run to continue that conversation."
|
||||||
|
),
|
||||||
|
default="",
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
sandbox_id: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Sandbox ID to reconnect to an existing sandbox. "
|
||||||
|
"Required when resuming a session (along with session_id). "
|
||||||
|
"Use the sandbox_id from a previous run where dispose_sandbox was False."
|
||||||
|
),
|
||||||
|
default="",
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_history: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Previous conversation history to continue from. "
|
||||||
|
"Use this to restore context on a fresh sandbox if the previous one timed out. "
|
||||||
|
"Pass the conversation_history output from a previous run."
|
||||||
|
),
|
||||||
|
default="",
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
dispose_sandbox: bool = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Whether to dispose of the sandbox immediately after execution. "
|
||||||
|
"Set to False if you want to continue the conversation later "
|
||||||
|
"(you'll need both sandbox_id and session_id from the output)."
|
||||||
|
),
|
||||||
|
default=True,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class FileOutput(BaseModel):
|
||||||
|
"""A file extracted from the sandbox."""
|
||||||
|
|
||||||
|
path: str
|
||||||
|
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
||||||
|
name: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
response: str = SchemaField(
|
||||||
|
description="The output/response from Claude Code execution"
|
||||||
|
)
|
||||||
|
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
||||||
|
description=(
|
||||||
|
"List of text files created/modified by Claude Code during this execution. "
|
||||||
|
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conversation_history: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Full conversation history including this turn. "
|
||||||
|
"Pass this to conversation_history input to continue on a fresh sandbox "
|
||||||
|
"if the previous sandbox timed out."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session_id: str = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Session ID for this conversation. "
|
||||||
|
"Pass this back along with sandbox_id to continue the conversation."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sandbox_id: Optional[str] = SchemaField(
|
||||||
|
description=(
|
||||||
|
"ID of the sandbox instance. "
|
||||||
|
"Pass this back along with session_id to continue the conversation. "
|
||||||
|
"This is None if dispose_sandbox was True (sandbox was disposed)."
|
||||||
|
),
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if execution failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
|
||||||
|
description=(
|
||||||
|
"Execute tasks using Claude Code in an E2B sandbox. "
|
||||||
|
"Claude Code can create files, install tools, run commands, "
|
||||||
|
"and perform complex coding tasks autonomously."
|
||||||
|
),
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
|
||||||
|
input_schema=ClaudeCodeBlock.Input,
|
||||||
|
output_schema=ClaudeCodeBlock.Output,
|
||||||
|
test_credentials={
|
||||||
|
"e2b_credentials": TEST_E2B_CREDENTIALS,
|
||||||
|
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
|
||||||
|
},
|
||||||
|
test_input={
|
||||||
|
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
|
||||||
|
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
|
||||||
|
"prompt": "Create a hello world HTML file",
|
||||||
|
"timeout": 300,
|
||||||
|
"setup_commands": [],
|
||||||
|
"working_directory": "/home/user",
|
||||||
|
"session_id": "",
|
||||||
|
"sandbox_id": "",
|
||||||
|
"conversation_history": "",
|
||||||
|
"dispose_sandbox": True,
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
("response", "Created index.html with hello world content"),
|
||||||
|
(
|
||||||
|
"files",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"path": "/home/user/index.html",
|
||||||
|
"relative_path": "index.html",
|
||||||
|
"name": "index.html",
|
||||||
|
"content": "<html>Hello World</html>",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"conversation_history",
|
||||||
|
"User: Create a hello world HTML file\n"
|
||||||
|
"Claude: Created index.html with hello world content",
|
||||||
|
),
|
||||||
|
("session_id", str),
|
||||||
|
("sandbox_id", None), # None because dispose_sandbox=True in test_input
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"execute_claude_code": lambda *args, **kwargs: (
|
||||||
|
"Created index.html with hello world content", # response
|
||||||
|
[
|
||||||
|
ClaudeCodeBlock.FileOutput(
|
||||||
|
path="/home/user/index.html",
|
||||||
|
relative_path="index.html",
|
||||||
|
name="index.html",
|
||||||
|
content="<html>Hello World</html>",
|
||||||
|
)
|
||||||
|
], # files
|
||||||
|
"User: Create a hello world HTML file\n"
|
||||||
|
"Claude: Created index.html with hello world content", # conversation_history
|
||||||
|
"test-session-id", # session_id
|
||||||
|
"sandbox_id", # sandbox_id
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute_claude_code(
|
||||||
|
self,
|
||||||
|
e2b_api_key: str,
|
||||||
|
anthropic_api_key: str,
|
||||||
|
prompt: str,
|
||||||
|
timeout: int,
|
||||||
|
setup_commands: list[str],
|
||||||
|
working_directory: str,
|
||||||
|
session_id: str,
|
||||||
|
existing_sandbox_id: str,
|
||||||
|
conversation_history: str,
|
||||||
|
dispose_sandbox: bool,
|
||||||
|
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
|
||||||
|
"""
|
||||||
|
Execute Claude Code in an E2B sandbox.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (response, files, conversation_history, session_id, sandbox_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Validate that sandbox_id is provided when resuming a session
|
||||||
|
if session_id and not existing_sandbox_id:
|
||||||
|
raise ValueError(
|
||||||
|
"sandbox_id is required when resuming a session with session_id. "
|
||||||
|
"The session state is stored in the original sandbox. "
|
||||||
|
"If the sandbox has timed out, use conversation_history instead "
|
||||||
|
"to restore context on a fresh sandbox."
|
||||||
|
)
|
||||||
|
|
||||||
|
sandbox = None
|
||||||
|
sandbox_id = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Either reconnect to existing sandbox or create a new one
|
||||||
|
if existing_sandbox_id:
|
||||||
|
# Reconnect to existing sandbox for conversation continuation
|
||||||
|
sandbox = await BaseAsyncSandbox.connect(
|
||||||
|
sandbox_id=existing_sandbox_id,
|
||||||
|
api_key=e2b_api_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create new sandbox
|
||||||
|
sandbox = await BaseAsyncSandbox.create(
|
||||||
|
template=self.DEFAULT_TEMPLATE,
|
||||||
|
api_key=e2b_api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Install Claude Code from npm (ensures we get the latest version)
|
||||||
|
install_result = await sandbox.commands.run(
|
||||||
|
"npm install -g @anthropic-ai/claude-code@latest",
|
||||||
|
timeout=120, # 2 min timeout for install
|
||||||
|
)
|
||||||
|
if install_result.exit_code != 0:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to install Claude Code: {install_result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run any user-provided setup commands
|
||||||
|
for cmd in setup_commands:
|
||||||
|
setup_result = await sandbox.commands.run(cmd)
|
||||||
|
if setup_result.exit_code != 0:
|
||||||
|
raise Exception(
|
||||||
|
f"Setup command failed: {cmd}\n"
|
||||||
|
f"Exit code: {setup_result.exit_code}\n"
|
||||||
|
f"Stdout: {setup_result.stdout}\n"
|
||||||
|
f"Stderr: {setup_result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture sandbox_id immediately after creation/connection
|
||||||
|
# so it's available for error recovery if dispose_sandbox=False
|
||||||
|
sandbox_id = sandbox.sandbox_id
|
||||||
|
|
||||||
|
# Generate or use provided session ID
|
||||||
|
current_session_id = session_id if session_id else str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Build base Claude flags
|
||||||
|
base_flags = "-p --dangerously-skip-permissions --output-format json"
|
||||||
|
|
||||||
|
# Add conversation history context if provided (for fresh sandbox continuation)
|
||||||
|
history_flag = ""
|
||||||
|
if conversation_history and not session_id:
|
||||||
|
# Inject previous conversation as context via system prompt
|
||||||
|
# Use consistent escaping via _escape_prompt helper
|
||||||
|
escaped_history = self._escape_prompt(
|
||||||
|
f"Previous conversation context: {conversation_history}"
|
||||||
|
)
|
||||||
|
history_flag = f" --append-system-prompt {escaped_history}"
|
||||||
|
|
||||||
|
# Build Claude command based on whether we're resuming or starting new
|
||||||
|
# Use shlex.quote for working_directory and session IDs to prevent injection
|
||||||
|
safe_working_dir = shlex.quote(working_directory)
|
||||||
|
if session_id:
|
||||||
|
# Resuming existing session (sandbox still alive)
|
||||||
|
safe_session_id = shlex.quote(session_id)
|
||||||
|
claude_command = (
|
||||||
|
f"cd {safe_working_dir} && "
|
||||||
|
f"echo {self._escape_prompt(prompt)} | "
|
||||||
|
f"claude --resume {safe_session_id} {base_flags}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# New session with specific ID
|
||||||
|
safe_current_session_id = shlex.quote(current_session_id)
|
||||||
|
claude_command = (
|
||||||
|
f"cd {safe_working_dir} && "
|
||||||
|
f"echo {self._escape_prompt(prompt)} | "
|
||||||
|
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture timestamp before running Claude Code to filter files later
|
||||||
|
# Capture timestamp 1 second in the past to avoid race condition with file creation
|
||||||
|
timestamp_result = await sandbox.commands.run(
|
||||||
|
"date -u -d '1 second ago' +%Y-%m-%dT%H:%M:%S"
|
||||||
|
)
|
||||||
|
if timestamp_result.exit_code != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to capture timestamp: {timestamp_result.stderr}"
|
||||||
|
)
|
||||||
|
start_timestamp = (
|
||||||
|
timestamp_result.stdout.strip() if timestamp_result.stdout else None
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await sandbox.commands.run(
|
||||||
|
claude_command,
|
||||||
|
timeout=0, # No command timeout - let sandbox timeout handle it
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for command failure
|
||||||
|
if result.exit_code != 0:
|
||||||
|
error_msg = result.stderr or result.stdout or "Unknown error"
|
||||||
|
raise Exception(
|
||||||
|
f"Claude Code command failed with exit code {result.exit_code}:\n"
|
||||||
|
f"{error_msg}"
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_output = result.stdout or ""
|
||||||
|
|
||||||
|
# Parse JSON output to extract response and build conversation history
|
||||||
|
response = ""
|
||||||
|
new_conversation_history = conversation_history or ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# The JSON output contains the result
|
||||||
|
output_data = json.loads(raw_output)
|
||||||
|
response = output_data.get("result", raw_output)
|
||||||
|
|
||||||
|
# Build conversation history entry
|
||||||
|
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||||
|
if new_conversation_history:
|
||||||
|
new_conversation_history = (
|
||||||
|
f"{new_conversation_history}\n\n{turn_entry}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_conversation_history = turn_entry
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If not valid JSON, use raw output
|
||||||
|
response = raw_output
|
||||||
|
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||||
|
if new_conversation_history:
|
||||||
|
new_conversation_history = (
|
||||||
|
f"{new_conversation_history}\n\n{turn_entry}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_conversation_history = turn_entry
|
||||||
|
|
||||||
|
# Extract files created/modified during this run
|
||||||
|
files = await self._extract_files(
|
||||||
|
sandbox, working_directory, start_timestamp
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
response,
|
||||||
|
files,
|
||||||
|
new_conversation_history,
|
||||||
|
current_session_id,
|
||||||
|
sandbox_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Wrap exception with sandbox_id so caller can access/cleanup
|
||||||
|
# the preserved sandbox when dispose_sandbox=False
|
||||||
|
raise ClaudeCodeExecutionError(str(e), sandbox_id) from e
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if dispose_sandbox and sandbox:
|
||||||
|
await sandbox.kill()
|
||||||
|
|
||||||
|
async def _extract_files(
|
||||||
|
self,
|
||||||
|
sandbox: BaseAsyncSandbox,
|
||||||
|
working_directory: str,
|
||||||
|
since_timestamp: str | None = None,
|
||||||
|
) -> list["ClaudeCodeBlock.FileOutput"]:
|
||||||
|
"""
|
||||||
|
Extract text files created/modified during this Claude Code execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sandbox: The E2B sandbox instance
|
||||||
|
working_directory: Directory to search for files
|
||||||
|
since_timestamp: ISO timestamp - only return files modified after this time
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of FileOutput objects with path, relative_path, name, and content
|
||||||
|
"""
|
||||||
|
files: list[ClaudeCodeBlock.FileOutput] = []
|
||||||
|
|
||||||
|
# Text file extensions we can safely read as text
|
||||||
|
text_extensions = {
|
||||||
|
".txt",
|
||||||
|
".md",
|
||||||
|
".html",
|
||||||
|
".htm",
|
||||||
|
".css",
|
||||||
|
".js",
|
||||||
|
".ts",
|
||||||
|
".jsx",
|
||||||
|
".tsx",
|
||||||
|
".json",
|
||||||
|
".xml",
|
||||||
|
".yaml",
|
||||||
|
".yml",
|
||||||
|
".toml",
|
||||||
|
".ini",
|
||||||
|
".cfg",
|
||||||
|
".conf",
|
||||||
|
".py",
|
||||||
|
".rb",
|
||||||
|
".php",
|
||||||
|
".java",
|
||||||
|
".c",
|
||||||
|
".cpp",
|
||||||
|
".h",
|
||||||
|
".hpp",
|
||||||
|
".cs",
|
||||||
|
".go",
|
||||||
|
".rs",
|
||||||
|
".swift",
|
||||||
|
".kt",
|
||||||
|
".scala",
|
||||||
|
".sh",
|
||||||
|
".bash",
|
||||||
|
".zsh",
|
||||||
|
".sql",
|
||||||
|
".graphql",
|
||||||
|
".env",
|
||||||
|
".gitignore",
|
||||||
|
".dockerfile",
|
||||||
|
"Dockerfile",
|
||||||
|
".vue",
|
||||||
|
".svelte",
|
||||||
|
".astro",
|
||||||
|
".mdx",
|
||||||
|
".rst",
|
||||||
|
".tex",
|
||||||
|
".csv",
|
||||||
|
".log",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# List files recursively using find command
|
||||||
|
# Exclude node_modules and .git directories, but allow hidden files
|
||||||
|
# like .env and .gitignore (they're filtered by text_extensions later)
|
||||||
|
# Filter by timestamp to only get files created/modified during this run
|
||||||
|
safe_working_dir = shlex.quote(working_directory)
|
||||||
|
timestamp_filter = ""
|
||||||
|
if since_timestamp:
|
||||||
|
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
|
||||||
|
find_result = await sandbox.commands.run(
|
||||||
|
f"find {safe_working_dir} -type f "
|
||||||
|
f"{timestamp_filter}"
|
||||||
|
f"-not -path '*/node_modules/*' "
|
||||||
|
f"-not -path '*/.git/*' "
|
||||||
|
f"2>/dev/null"
|
||||||
|
)
|
||||||
|
|
||||||
|
if find_result.stdout:
|
||||||
|
for file_path in find_result.stdout.strip().split("\n"):
|
||||||
|
if not file_path:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if it's a text file we can read
|
||||||
|
is_text = any(
|
||||||
|
file_path.endswith(ext) for ext in text_extensions
|
||||||
|
) or file_path.endswith("Dockerfile")
|
||||||
|
|
||||||
|
if is_text:
|
||||||
|
try:
|
||||||
|
content = await sandbox.files.read(file_path)
|
||||||
|
# Handle bytes or string
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
content = content.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
# Extract filename from path
|
||||||
|
file_name = file_path.split("/")[-1]
|
||||||
|
|
||||||
|
# Calculate relative path by stripping working directory
|
||||||
|
relative_path = file_path
|
||||||
|
if file_path.startswith(working_directory):
|
||||||
|
relative_path = file_path[len(working_directory) :]
|
||||||
|
# Remove leading slash if present
|
||||||
|
if relative_path.startswith("/"):
|
||||||
|
relative_path = relative_path[1:]
|
||||||
|
|
||||||
|
files.append(
|
||||||
|
ClaudeCodeBlock.FileOutput(
|
||||||
|
path=file_path,
|
||||||
|
relative_path=relative_path,
|
||||||
|
name=file_name,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Skip files that can't be read
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# If file extraction fails, return empty results
|
||||||
|
pass
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
|
def _escape_prompt(self, prompt: str) -> str:
|
||||||
|
"""Escape the prompt for safe shell execution."""
|
||||||
|
# Use single quotes and escape any single quotes in the prompt
|
||||||
|
escaped = prompt.replace("'", "'\"'\"'")
|
||||||
|
return f"'{escaped}'"
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
e2b_credentials: APIKeyCredentials,
|
||||||
|
anthropic_credentials: APIKeyCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
(
|
||||||
|
response,
|
||||||
|
files,
|
||||||
|
conversation_history,
|
||||||
|
session_id,
|
||||||
|
sandbox_id,
|
||||||
|
) = await self.execute_claude_code(
|
||||||
|
e2b_api_key=e2b_credentials.api_key.get_secret_value(),
|
||||||
|
anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
|
||||||
|
prompt=input_data.prompt,
|
||||||
|
timeout=input_data.timeout,
|
||||||
|
setup_commands=input_data.setup_commands,
|
||||||
|
working_directory=input_data.working_directory,
|
||||||
|
session_id=input_data.session_id,
|
||||||
|
existing_sandbox_id=input_data.sandbox_id,
|
||||||
|
conversation_history=input_data.conversation_history,
|
||||||
|
dispose_sandbox=input_data.dispose_sandbox,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "response", response
|
||||||
|
# Always yield files (empty list if none) to match Output schema
|
||||||
|
yield "files", [f.model_dump() for f in files]
|
||||||
|
# Always yield conversation_history so user can restore context on fresh sandbox
|
||||||
|
yield "conversation_history", conversation_history
|
||||||
|
# Always yield session_id so user can continue conversation
|
||||||
|
yield "session_id", session_id
|
||||||
|
# Always yield sandbox_id (None if disposed) to match Output schema
|
||||||
|
yield "sandbox_id", sandbox_id if not input_data.dispose_sandbox else None
|
||||||
|
|
||||||
|
except ClaudeCodeExecutionError as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
# If sandbox was preserved (dispose_sandbox=False), yield sandbox_id
|
||||||
|
# so user can reconnect to or clean up the orphaned sandbox
|
||||||
|
if not input_data.dispose_sandbox and e.sandbox_id:
|
||||||
|
yield "sandbox_id", e.sandbox_id
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
@@ -15,6 +15,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import APIKeyCredentials, SchemaField
|
from backend.data.model import APIKeyCredentials, SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
@@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block):
|
|||||||
file: MediaFileType,
|
file: MediaFileType,
|
||||||
filename: str,
|
filename: str,
|
||||||
message_content: str,
|
message_content: str,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
intents.guilds = True
|
intents.guilds = True
|
||||||
@@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block):
|
|||||||
# Local file path - read from stored media file
|
# Local file path - read from stored media file
|
||||||
# This would be a path from a previous block's output
|
# This would be a path from a previous block's output
|
||||||
stored_file = await store_media_file(
|
stored_file = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=file,
|
file=file,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True, # Get as data URI
|
return_format="for_external_api", # Get content to send to Discord
|
||||||
)
|
)
|
||||||
# Now process as data URI
|
# Now process as data URI
|
||||||
header, encoded = stored_file.split(",", 1)
|
header, encoded = stored_file.split(",", 1)
|
||||||
@@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
@@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block):
|
|||||||
file=input_data.file,
|
file=input_data.file,
|
||||||
filename=input_data.filename,
|
filename=input_data.filename,
|
||||||
message_content=input_data.message_content,
|
message_content=input_data.message_content,
|
||||||
graph_exec_id=graph_exec_id,
|
execution_context=execution_context,
|
||||||
user_id=user_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "status", result.get("status", "Unknown error")
|
yield "status", result.get("status", "Unknown error")
|
||||||
|
|||||||
@@ -17,8 +17,11 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import ClientResponseError, Requests
|
from backend.util.request import ClientResponseError, Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block):
|
|||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
|
test_output=[
|
||||||
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("video_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block):
|
|||||||
raise RuntimeError(f"API request failed: {str(e)}")
|
raise RuntimeError(f"API request failed: {str(e)}")
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: FalCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: FalCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
video_url = await self.generate_video(input_data, credentials)
|
video_url = await self.generate_video(input_data, credentials)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
yield "error", error_message
|
yield "error", error_message
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -121,10 +122,12 @@ class AIImageEditorBlock(Block):
|
|||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
test_output=[
|
test_output=[
|
||||||
("output_image", "https://replicate.com/output/edited-image.png"),
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"run_model": lambda *args, **kwargs: "",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
@@ -134,8 +137,7 @@ class AIImageEditorBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
result = await self.run_model(
|
result = await self.run_model(
|
||||||
@@ -144,20 +146,25 @@ class AIImageEditorBlock(Block):
|
|||||||
prompt=input_data.prompt,
|
prompt=input_data.prompt,
|
||||||
input_image_b64=(
|
input_image_b64=(
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.input_image,
|
file=input_data.input_image,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True,
|
return_format="for_external_api", # Get content for Replicate API
|
||||||
)
|
)
|
||||||
if input_data.input_image
|
if input_data.input_image
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
aspect_ratio=input_data.aspect_ratio.value,
|
aspect_ratio=input_data.aspect_ratio.value,
|
||||||
seed=input_data.seed,
|
seed=input_data.seed,
|
||||||
user_id=user_id,
|
user_id=execution_context.user_id or "",
|
||||||
graph_exec_id=graph_exec_id,
|
graph_exec_id=execution_context.graph_exec_id or "",
|
||||||
)
|
)
|
||||||
yield "output_image", result
|
# Store the generated image to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=result,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "output_image", stored_url
|
||||||
|
|
||||||
async def run_model(
|
async def run_model(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
@@ -95,8 +96,7 @@ def _make_mime_text(
|
|||||||
|
|
||||||
async def create_mime_message(
|
async def create_mime_message(
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||||
|
|
||||||
@@ -117,12 +117,12 @@ async def create_mime_message(
|
|||||||
if input_data.attachments:
|
if input_data.attachments:
|
||||||
for attach in input_data.attachments:
|
for attach in input_data.attachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=attach,
|
file=attach,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
@@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._send_email(
|
result = await self._send_email(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "result", result
|
yield "result", result
|
||||||
|
|
||||||
async def _send_email(
|
async def _send_email(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to or not input_data.subject or not input_data.body:
|
if not input_data.to or not input_data.subject or not input_data.body:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"At least one recipient, subject, and body are required for sending an email"
|
"At least one recipient, subject, and body are required for sending an email"
|
||||||
)
|
)
|
||||||
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
raw_message = await create_mime_message(input_data, execution_context)
|
||||||
sent_message = await asyncio.to_thread(
|
sent_message = await asyncio.to_thread(
|
||||||
lambda: service.users()
|
lambda: service.users()
|
||||||
.messages()
|
.messages()
|
||||||
@@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._create_draft(
|
result = await self._create_draft(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "result", GmailDraftResult(
|
yield "result", GmailDraftResult(
|
||||||
id=result["id"], message_id=result["message"]["id"], status="draft_created"
|
id=result["id"], message_id=result["message"]["id"], status="draft_created"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _create_draft(
|
async def _create_draft(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to or not input_data.subject:
|
if not input_data.to or not input_data.subject:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"At least one recipient and subject are required for creating a draft"
|
"At least one recipient and subject are required for creating a draft"
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
raw_message = await create_mime_message(input_data, execution_context)
|
||||||
draft = await asyncio.to_thread(
|
draft = await asyncio.to_thread(
|
||||||
lambda: service.users()
|
lambda: service.users()
|
||||||
.drafts()
|
.drafts()
|
||||||
@@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase):
|
|||||||
|
|
||||||
|
|
||||||
async def _build_reply_message(
|
async def _build_reply_message(
|
||||||
service, input_data, graph_exec_id: str, user_id: str
|
service, input_data, execution_context: ExecutionContext
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Builds a reply MIME message for Gmail threads.
|
Builds a reply MIME message for Gmail threads.
|
||||||
@@ -1190,12 +1186,12 @@ async def _build_reply_message(
|
|||||||
# Handle attachments
|
# Handle attachments
|
||||||
for attach in input_data.attachments:
|
for attach in input_data.attachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=attach,
|
file=attach,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
@@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
message = await self._reply(
|
message = await self._reply(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "messageId", message["id"]
|
yield "messageId", message["id"]
|
||||||
yield "threadId", message.get("threadId", input_data.threadId)
|
yield "threadId", message.get("threadId", input_data.threadId)
|
||||||
@@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase):
|
|||||||
yield "email", email
|
yield "email", email
|
||||||
|
|
||||||
async def _reply(
|
async def _reply(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Build the reply message using the shared helper
|
# Build the reply message using the shared helper
|
||||||
raw, thread_id = await _build_reply_message(
|
raw, thread_id = await _build_reply_message(
|
||||||
service, input_data, graph_exec_id, user_id
|
service, input_data, execution_context
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send the message
|
# Send the message
|
||||||
@@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
draft = await self._create_draft_reply(
|
draft = await self._create_draft_reply(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "draftId", draft["id"]
|
yield "draftId", draft["id"]
|
||||||
yield "messageId", draft["message"]["id"]
|
yield "messageId", draft["message"]["id"]
|
||||||
@@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase):
|
|||||||
yield "status", "draft_created"
|
yield "status", "draft_created"
|
||||||
|
|
||||||
async def _create_draft_reply(
|
async def _create_draft_reply(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Build the reply message using the shared helper
|
# Build the reply message using the shared helper
|
||||||
raw, thread_id = await _build_reply_message(
|
raw, thread_id = await _build_reply_message(
|
||||||
service, input_data, graph_exec_id, user_id
|
service, input_data, execution_context
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create draft with proper thread association
|
# Create draft with proper thread association
|
||||||
@@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._forward_message(
|
result = await self._forward_message(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "messageId", result["id"]
|
yield "messageId", result["id"]
|
||||||
yield "threadId", result.get("threadId", "")
|
yield "threadId", result.get("threadId", "")
|
||||||
yield "status", "forwarded"
|
yield "status", "forwarded"
|
||||||
|
|
||||||
async def _forward_message(
|
async def _forward_message(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to:
|
if not input_data.to:
|
||||||
raise ValueError("At least one recipient is required for forwarding")
|
raise ValueError("At least one recipient is required for forwarding")
|
||||||
@@ -1727,12 +1717,12 @@ To: {original_to}
|
|||||||
# Add any additional attachments
|
# Add any additional attachments
|
||||||
for attach in input_data.additionalAttachments:
|
for attach in input_data.additionalAttachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=attach,
|
file=attach,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import Any, Optional
|
|||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
from backend.data.execution import ExecutionStatus
|
||||||
from backend.data.human_review import ReviewResult
|
from backend.data.human_review import ReviewResult
|
||||||
from backend.executor.manager import async_update_node_execution_status
|
from backend.executor.manager import async_update_node_execution_status
|
||||||
from backend.util.clients import get_database_manager_async_client
|
from backend.util.clients import get_database_manager_async_client
|
||||||
@@ -28,6 +28,11 @@ class ReviewDecision(BaseModel):
|
|||||||
class HITLReviewHelper:
|
class HITLReviewHelper:
|
||||||
"""Helper class for Human-In-The-Loop review operations."""
|
"""Helper class for Human-In-The-Loop review operations."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_approval(**kwargs) -> Optional[ReviewResult]:
|
||||||
|
"""Check if there's an existing approval for this node execution."""
|
||||||
|
return await get_database_manager_async_client().check_approval(**kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
|
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
|
||||||
"""Create or retrieve a human review from the database."""
|
"""Create or retrieve a human review from the database."""
|
||||||
@@ -55,11 +60,11 @@ class HITLReviewHelper:
|
|||||||
async def _handle_review_request(
|
async def _handle_review_request(
|
||||||
input_data: Any,
|
input_data: Any,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
node_id: str,
|
||||||
node_exec_id: str,
|
node_exec_id: str,
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
graph_version: int,
|
graph_version: int,
|
||||||
execution_context: ExecutionContext,
|
|
||||||
block_name: str = "Block",
|
block_name: str = "Block",
|
||||||
editable: bool = False,
|
editable: bool = False,
|
||||||
) -> Optional[ReviewResult]:
|
) -> Optional[ReviewResult]:
|
||||||
@@ -69,11 +74,11 @@ class HITLReviewHelper:
|
|||||||
Args:
|
Args:
|
||||||
input_data: The input data to be reviewed
|
input_data: The input data to be reviewed
|
||||||
user_id: ID of the user requesting the review
|
user_id: ID of the user requesting the review
|
||||||
|
node_id: ID of the node in the graph definition
|
||||||
node_exec_id: ID of the node execution
|
node_exec_id: ID of the node execution
|
||||||
graph_exec_id: ID of the graph execution
|
graph_exec_id: ID of the graph execution
|
||||||
graph_id: ID of the graph
|
graph_id: ID of the graph
|
||||||
graph_version: Version of the graph
|
graph_version: Version of the graph
|
||||||
execution_context: Current execution context
|
|
||||||
block_name: Name of the block requesting review
|
block_name: Name of the block requesting review
|
||||||
editable: Whether the reviewer can edit the data
|
editable: Whether the reviewer can edit the data
|
||||||
|
|
||||||
@@ -83,15 +88,41 @@ class HITLReviewHelper:
|
|||||||
Raises:
|
Raises:
|
||||||
Exception: If review creation or status update fails
|
Exception: If review creation or status update fails
|
||||||
"""
|
"""
|
||||||
# Skip review if safe mode is disabled - return auto-approved result
|
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
|
||||||
if not execution_context.human_in_the_loop_safe_mode:
|
# are handled by the caller:
|
||||||
|
# - HITL blocks check human_in_the_loop_safe_mode in their run() method
|
||||||
|
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
|
||||||
|
# This function only handles checking for existing approvals.
|
||||||
|
|
||||||
|
# Check if this node has already been approved (normal or auto-approval)
|
||||||
|
if approval_result := await HITLReviewHelper.check_approval(
|
||||||
|
node_exec_id=node_exec_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
node_id=node_id,
|
||||||
|
user_id=user_id,
|
||||||
|
input_data=input_data,
|
||||||
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
|
f"Block {block_name} skipping review for node {node_exec_id} - "
|
||||||
|
f"found existing approval"
|
||||||
|
)
|
||||||
|
# Return a new ReviewResult with the current node_exec_id but approved status
|
||||||
|
# For auto-approvals, always use current input_data
|
||||||
|
# For normal approvals, use approval_result.data unless it's None
|
||||||
|
is_auto_approval = approval_result.node_exec_id != node_exec_id
|
||||||
|
approved_data = (
|
||||||
|
input_data
|
||||||
|
if is_auto_approval
|
||||||
|
else (
|
||||||
|
approval_result.data
|
||||||
|
if approval_result.data is not None
|
||||||
|
else input_data
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return ReviewResult(
|
return ReviewResult(
|
||||||
data=input_data,
|
data=approved_data,
|
||||||
status=ReviewStatus.APPROVED,
|
status=ReviewStatus.APPROVED,
|
||||||
message="Auto-approved (safe mode disabled)",
|
message=approval_result.message,
|
||||||
processed=True,
|
processed=True,
|
||||||
node_exec_id=node_exec_id,
|
node_exec_id=node_exec_id,
|
||||||
)
|
)
|
||||||
@@ -103,7 +134,7 @@ class HITLReviewHelper:
|
|||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
graph_version=graph_version,
|
graph_version=graph_version,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
message=f"Review required for {block_name} execution",
|
message=block_name, # Use block_name directly as the message
|
||||||
editable=editable,
|
editable=editable,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -129,11 +160,11 @@ class HITLReviewHelper:
|
|||||||
async def handle_review_decision(
|
async def handle_review_decision(
|
||||||
input_data: Any,
|
input_data: Any,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
node_id: str,
|
||||||
node_exec_id: str,
|
node_exec_id: str,
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
graph_version: int,
|
graph_version: int,
|
||||||
execution_context: ExecutionContext,
|
|
||||||
block_name: str = "Block",
|
block_name: str = "Block",
|
||||||
editable: bool = False,
|
editable: bool = False,
|
||||||
) -> Optional[ReviewDecision]:
|
) -> Optional[ReviewDecision]:
|
||||||
@@ -143,11 +174,11 @@ class HITLReviewHelper:
|
|||||||
Args:
|
Args:
|
||||||
input_data: The input data to be reviewed
|
input_data: The input data to be reviewed
|
||||||
user_id: ID of the user requesting the review
|
user_id: ID of the user requesting the review
|
||||||
|
node_id: ID of the node in the graph definition
|
||||||
node_exec_id: ID of the node execution
|
node_exec_id: ID of the node execution
|
||||||
graph_exec_id: ID of the graph execution
|
graph_exec_id: ID of the graph execution
|
||||||
graph_id: ID of the graph
|
graph_id: ID of the graph
|
||||||
graph_version: Version of the graph
|
graph_version: Version of the graph
|
||||||
execution_context: Current execution context
|
|
||||||
block_name: Name of the block requesting review
|
block_name: Name of the block requesting review
|
||||||
editable: Whether the reviewer can edit the data
|
editable: Whether the reviewer can edit the data
|
||||||
|
|
||||||
@@ -158,11 +189,11 @@ class HITLReviewHelper:
|
|||||||
review_result = await HITLReviewHelper._handle_review_request(
|
review_result = await HITLReviewHelper._handle_review_request(
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
node_id=node_id,
|
||||||
node_exec_id=node_exec_id,
|
node_exec_id=node_exec_id,
|
||||||
graph_exec_id=graph_exec_id,
|
graph_exec_id=graph_exec_id,
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
graph_version=graph_version,
|
graph_version=graph_version,
|
||||||
execution_context=execution_context,
|
|
||||||
block_name=block_name,
|
block_name=block_name,
|
||||||
editable=editable,
|
editable=editable,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
@@ -116,10 +117,9 @@ class SendWebRequestBlock(Block):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _prepare_files(
|
async def _prepare_files(
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
files_name: str,
|
files_name: str,
|
||||||
files: list[MediaFileType],
|
files: list[MediaFileType],
|
||||||
user_id: str,
|
|
||||||
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
|
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
|
||||||
"""
|
"""
|
||||||
Prepare files for the request by storing them and reading their content.
|
Prepare files for the request by storing them and reading their content.
|
||||||
@@ -127,11 +127,16 @@ class SendWebRequestBlock(Block):
|
|||||||
(files_name, (filename, BytesIO, mime_type))
|
(files_name, (filename, BytesIO, mime_type))
|
||||||
"""
|
"""
|
||||||
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
if graph_exec_id is None:
|
||||||
|
raise ValueError("graph_exec_id is required for file operations")
|
||||||
|
|
||||||
for media in files:
|
for media in files:
|
||||||
# Normalise to a list so we can repeat the same key
|
# Normalise to a list so we can repeat the same key
|
||||||
rel_path = await store_media_file(
|
rel_path = await store_media_file(
|
||||||
graph_exec_id, media, user_id, return_content=False
|
file=media,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, rel_path)
|
abs_path = get_exec_file_path(graph_exec_id, rel_path)
|
||||||
async with aiofiles.open(abs_path, "rb") as f:
|
async with aiofiles.open(abs_path, "rb") as f:
|
||||||
@@ -143,7 +148,7 @@ class SendWebRequestBlock(Block):
|
|||||||
return files_payload
|
return files_payload
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
|
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# ─── Parse/normalise body ────────────────────────────────────
|
# ─── Parse/normalise body ────────────────────────────────────
|
||||||
body = input_data.body
|
body = input_data.body
|
||||||
@@ -174,7 +179,7 @@ class SendWebRequestBlock(Block):
|
|||||||
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
||||||
if use_files:
|
if use_files:
|
||||||
files_payload = await self._prepare_files(
|
files_payload = await self._prepare_files(
|
||||||
graph_exec_id, input_data.files_name, input_data.files, user_id
|
execution_context, input_data.files_name, input_data.files
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enforce body format rules
|
# Enforce body format rules
|
||||||
@@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
credentials: HostScopedCredentials,
|
credentials: HostScopedCredentials,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
||||||
@@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
|||||||
|
|
||||||
# Use parent class run method
|
# Use parent class run method
|
||||||
async for output_name, output_data in super().run(
|
async for output_name, output_data in super().run(
|
||||||
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
|
base_input, execution_context=execution_context, **kwargs
|
||||||
):
|
):
|
||||||
yield output_name, output_data
|
yield output_name, output_data
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ class HumanInTheLoopBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
node_id: str,
|
||||||
node_exec_id: str,
|
node_exec_id: str,
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
@@ -115,12 +116,12 @@ class HumanInTheLoopBlock(Block):
|
|||||||
decision = await self.handle_review_decision(
|
decision = await self.handle_review_decision(
|
||||||
input_data=input_data.data,
|
input_data=input_data.data,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
node_id=node_id,
|
||||||
node_exec_id=node_exec_id,
|
node_exec_id=node_exec_id,
|
||||||
graph_exec_id=graph_exec_id,
|
graph_exec_id=graph_exec_id,
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
graph_version=graph_version,
|
graph_version=graph_version,
|
||||||
execution_context=execution_context,
|
block_name=input_data.name, # Use user-provided name instead of block type
|
||||||
block_name=self.name,
|
|
||||||
editable=input_data.editable,
|
editable=input_data.editable,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.mock import MockObject
|
from backend.util.mock import MockObject
|
||||||
@@ -462,18 +463,23 @@ class AgentFileInputBlock(AgentInputBlock):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
if not input_data.value:
|
if not input_data.value:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Determine return format based on user preference
|
||||||
|
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
|
||||||
|
# for_local_processing: returns local file path
|
||||||
|
return_format = (
|
||||||
|
"for_external_api" if input_data.base_64 else "for_local_processing"
|
||||||
|
)
|
||||||
|
|
||||||
yield "result", await store_media_file(
|
yield "result", await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.value,
|
file=input_data.value,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.base_64,
|
return_format=return_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Literal, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
from moviepy.video.fx.Loop import Loop
|
from moviepy.video.fx.Loop import Loop
|
||||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
@@ -46,18 +47,19 @@ class MediaDurationBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# 1) Store the input media locally
|
# 1) Store the input media locally
|
||||||
local_media_path = await store_media_file(
|
local_media_path = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.media_in,
|
file=input_data.media_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=False,
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
media_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_media_path
|
||||||
)
|
)
|
||||||
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
|
|
||||||
|
|
||||||
# 2) Load the clip
|
# 2) Load the clip
|
||||||
if input_data.is_video:
|
if input_data.is_video:
|
||||||
@@ -88,10 +90,6 @@ class LoopVideoBlock(Block):
|
|||||||
default=None,
|
default=None,
|
||||||
ge=1,
|
ge=1,
|
||||||
)
|
)
|
||||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
|
||||||
description="How to return the output video. Either a relative path or base64 data URI.",
|
|
||||||
default="file_path",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
video_out: str = SchemaField(
|
video_out: str = SchemaField(
|
||||||
@@ -111,17 +109,19 @@ class LoopVideoBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
node_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
graph_exec_id: str,
|
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
# 1) Store the input video locally
|
# 1) Store the input video locally
|
||||||
local_video_path = await store_media_file(
|
local_video_path = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.video_in,
|
file=input_data.video_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=False,
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||||
|
|
||||||
@@ -149,12 +149,11 @@ class LoopVideoBlock(Block):
|
|||||||
looped_clip = looped_clip.with_audio(clip.audio)
|
looped_clip = looped_clip.with_audio(clip.audio)
|
||||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||||
|
|
||||||
# Return as data URI
|
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
video_out = await store_media_file(
|
video_out = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=output_filename,
|
file=output_filename,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.output_return_type == "data_uri",
|
return_format="for_block_output",
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "video_out", video_out
|
yield "video_out", video_out
|
||||||
@@ -177,10 +176,6 @@ class AddAudioToVideoBlock(Block):
|
|||||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||||
default=1.0,
|
default=1.0,
|
||||||
)
|
)
|
||||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
|
||||||
description="Return the final output as a relative path or base64 data URI.",
|
|
||||||
default="file_path",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
video_out: MediaFileType = SchemaField(
|
video_out: MediaFileType = SchemaField(
|
||||||
@@ -200,23 +195,24 @@ class AddAudioToVideoBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
node_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
graph_exec_id: str,
|
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
# 1) Store the inputs locally
|
# 1) Store the inputs locally
|
||||||
local_video_path = await store_media_file(
|
local_video_path = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.video_in,
|
file=input_data.video_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=False,
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
local_audio_path = await store_media_file(
|
local_audio_path = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.audio_in,
|
file=input_data.audio_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=False,
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||||
@@ -240,12 +236,11 @@ class AddAudioToVideoBlock(Block):
|
|||||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||||
|
|
||||||
# 5) Return either path or data URI
|
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
video_out = await store_media_file(
|
video_out = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=output_filename,
|
file=output_filename,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.output_return_type == "data_uri",
|
return_format="for_block_output",
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "video_out", video_out
|
yield "video_out", video_out
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def take_screenshot(
|
async def take_screenshot(
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
url: str,
|
url: str,
|
||||||
viewport_width: int,
|
viewport_width: int,
|
||||||
viewport_height: int,
|
viewport_height: int,
|
||||||
@@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block):
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"image": await store_media_file(
|
"image": await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=MediaFileType(
|
file=MediaFileType(
|
||||||
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
|
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
|
||||||
),
|
),
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True,
|
return_format="for_block_output",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
screenshot_data = await self.take_screenshot(
|
screenshot_data = await self.take_screenshot(
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
graph_exec_id=graph_exec_id,
|
execution_context=execution_context,
|
||||||
user_id=user_id,
|
|
||||||
url=input_data.url,
|
url=input_data.url,
|
||||||
viewport_width=input_data.viewport_width,
|
viewport_width=input_data.viewport_width,
|
||||||
viewport_height=input_data.viewport_height,
|
viewport_height=input_data.viewport_height,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import ContributorDetails, SchemaField
|
from backend.data.model import ContributorDetails, SchemaField
|
||||||
from backend.util.file import get_exec_file_path, store_media_file
|
from backend.util.file import get_exec_file_path, store_media_file
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
@@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
|
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
import csv
|
import csv
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
@@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block):
|
|||||||
# Determine data source - prefer file_input if provided, otherwise use contents
|
# Determine data source - prefer file_input if provided, otherwise use contents
|
||||||
if input_data.file_input:
|
if input_data.file_input:
|
||||||
stored_file_path = await store_media_file(
|
stored_file_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.file_input,
|
file=input_data.file_input,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get full file path
|
# Get full file path
|
||||||
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
file_path = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, stored_file_path
|
||||||
|
)
|
||||||
if not Path(file_path).exists():
|
if not Path(file_path).exists():
|
||||||
raise ValueError(f"File does not exist: {file_path}")
|
raise ValueError(f"File does not exist: {file_path}")
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -17,7 +18,9 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
@@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
test_output=[
|
test_output=[
|
||||||
(
|
(
|
||||||
"video_url",
|
"video_url",
|
||||||
"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
@@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
|
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
|
||||||
"status": "created",
|
"status": "created",
|
||||||
},
|
},
|
||||||
|
# Use data URI to avoid HTTP requests during tests
|
||||||
"get_clip_status": lambda *args, **kwargs: {
|
"get_clip_status": lambda *args, **kwargs: {
|
||||||
"status": "done",
|
"status": "done",
|
||||||
"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
|
"result_url": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
@@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create the clip
|
# Create the clip
|
||||||
payload = {
|
payload = {
|
||||||
@@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
for _ in range(input_data.max_polling_attempts):
|
for _ in range(input_data.max_polling_attempts):
|
||||||
status_response = await self.get_clip_status(credentials.api_key, clip_id)
|
status_response = await self.get_clip_status(credentials.api_key, clip_id)
|
||||||
if status_response["status"] == "done":
|
if status_response["status"] == "done":
|
||||||
yield "video_url", status_response["result_url"]
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
video_url = status_response["result_url"]
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
return
|
return
|
||||||
elif status_response["status"] == "error":
|
elif status_response["status"] == "error":
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock
|
|||||||
from backend.blocks.llm import AITextSummarizerBlock
|
from backend.blocks.llm import AITextSummarizerBlock
|
||||||
from backend.blocks.text import ExtractTextInformationBlock
|
from backend.blocks.text import ExtractTextInformationBlock
|
||||||
from backend.blocks.xml_parser import XMLParserBlock
|
from backend.blocks.xml_parser import XMLParserBlock
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
@@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity:
|
|||||||
|
|
||||||
with pytest.raises(ValueError, match="File too large"):
|
with pytest.raises(ValueError, match="File too large"):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id="test",
|
|
||||||
file=MediaFileType(large_data_uri),
|
file=MediaFileType(large_data_uri),
|
||||||
|
execution_context=ExecutionContext(
|
||||||
user_id="test_user",
|
user_id="test_user",
|
||||||
|
graph_exec_id="test",
|
||||||
|
),
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("backend.util.file.Path")
|
@patch("backend.util.file.Path")
|
||||||
@@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity:
|
|||||||
# Should raise an error when directory size exceeds limit
|
# Should raise an error when directory size exceeds limit
|
||||||
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
|
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id="test",
|
|
||||||
file=MediaFileType(
|
file=MediaFileType(
|
||||||
"data:text/plain;base64,dGVzdA=="
|
"data:text/plain;base64,dGVzdA=="
|
||||||
), # Small test file
|
), # Small test file
|
||||||
|
execution_context=ExecutionContext(
|
||||||
user_id="test_user",
|
user_id="test_user",
|
||||||
|
graph_exec_id="test",
|
||||||
|
),
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,10 +11,22 @@ from backend.blocks.http import (
|
|||||||
HttpMethod,
|
HttpMethod,
|
||||||
SendAuthenticatedWebRequestBlock,
|
SendAuthenticatedWebRequestBlock,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import HostScopedCredentials
|
from backend.data.model import HostScopedCredentials
|
||||||
from backend.util.request import Response
|
from backend.util.request import Response
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_context(
|
||||||
|
graph_exec_id: str = "test-exec-id",
|
||||||
|
user_id: str = "test-user-id",
|
||||||
|
) -> ExecutionContext:
|
||||||
|
"""Helper to create test ExecutionContext."""
|
||||||
|
return ExecutionContext(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestHttpBlockWithHostScopedCredentials:
|
class TestHttpBlockWithHostScopedCredentials:
|
||||||
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
||||||
|
|
||||||
@@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=exact_match_credentials,
|
credentials=exact_match_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=wildcard_credentials,
|
credentials=wildcard_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=non_matching_credentials,
|
credentials=non_matching_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=exact_match_credentials,
|
credentials=exact_match_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=auto_discovered_creds, # Execution manager found these
|
credentials=auto_discovered_creds, # Execution manager found these
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=multi_header_creds,
|
credentials=multi_header_creds,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=test_creds,
|
credentials=test_creds,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util import json, text
|
from backend.util import json, text
|
||||||
from backend.util.file import get_exec_file_path, store_media_file
|
from backend.util.file import get_exec_file_path, store_media_file
|
||||||
@@ -444,18 +445,21 @@ class FileReadBlock(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
|
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Store the media file properly (handles URLs, data URIs, etc.)
|
# Store the media file properly (handles URLs, data URIs, etc.)
|
||||||
stored_file_path = await store_media_file(
|
stored_file_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.file_input,
|
file=input_data.file_input,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get full file path
|
# Get full file path (graph_exec_id validated by store_media_file above)
|
||||||
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
|
if not execution_context.graph_exec_id:
|
||||||
|
raise ValueError("execution_context.graph_exec_id is required")
|
||||||
|
file_path = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, stored_file_path
|
||||||
|
)
|
||||||
|
|
||||||
if not Path(file_path).exists():
|
if not Path(file_path).exists():
|
||||||
raise ValueError(f"File does not exist: {file_path}")
|
raise ValueError(f"File does not exist: {file_path}")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest_asyncio
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from backend.util.logging import configure_logging
|
from backend.util.logging import configure_logging
|
||||||
@@ -19,7 +19,7 @@ if not os.getenv("PRISMA_DEBUG"):
|
|||||||
prisma_logger.setLevel(logging.INFO)
|
prisma_logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||||
async def server():
|
async def server():
|
||||||
from backend.util.test import SpinTestServer
|
from backend.util.test import SpinTestServer
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ async def server():
|
|||||||
yield server
|
yield server
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
|
||||||
async def graph_cleanup(server):
|
async def graph_cleanup(server):
|
||||||
created_graph_ids = []
|
created_graph_ids = []
|
||||||
original_create_graph = server.agent_server.test_create_graph
|
original_create_graph = server.agent_server.test_create_graph
|
||||||
|
|||||||
@@ -441,6 +441,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
static_output: bool = False,
|
static_output: bool = False,
|
||||||
block_type: BlockType = BlockType.STANDARD,
|
block_type: BlockType = BlockType.STANDARD,
|
||||||
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
|
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
|
||||||
|
is_sensitive_action: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the block with the given schema.
|
Initialize the block with the given schema.
|
||||||
@@ -473,8 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
self.static_output = static_output
|
self.static_output = static_output
|
||||||
self.block_type = block_type
|
self.block_type = block_type
|
||||||
self.webhook_config = webhook_config
|
self.webhook_config = webhook_config
|
||||||
|
self.is_sensitive_action = is_sensitive_action
|
||||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||||
self.is_sensitive_action: bool = False
|
|
||||||
|
|
||||||
if self.webhook_config:
|
if self.webhook_config:
|
||||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||||
@@ -622,6 +623,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
input_data: BlockInput,
|
input_data: BlockInput,
|
||||||
*,
|
*,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
node_id: str,
|
||||||
node_exec_id: str,
|
node_exec_id: str,
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
@@ -648,11 +650,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
|||||||
decision = await HITLReviewHelper.handle_review_decision(
|
decision = await HITLReviewHelper.handle_review_decision(
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
node_id=node_id,
|
||||||
node_exec_id=node_exec_id,
|
node_exec_id=node_exec_id,
|
||||||
graph_exec_id=graph_exec_id,
|
graph_exec_id=graph_exec_id,
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
graph_version=graph_version,
|
graph_version=graph_version,
|
||||||
execution_context=execution_context,
|
|
||||||
block_name=self.name,
|
block_name=self.name,
|
||||||
editable=True,
|
editable=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -83,12 +83,29 @@ class ExecutionContext(BaseModel):
|
|||||||
|
|
||||||
model_config = {"extra": "ignore"}
|
model_config = {"extra": "ignore"}
|
||||||
|
|
||||||
|
# Execution identity
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
graph_id: Optional[str] = None
|
||||||
|
graph_exec_id: Optional[str] = None
|
||||||
|
graph_version: Optional[int] = None
|
||||||
|
node_id: Optional[str] = None
|
||||||
|
node_exec_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Safety settings
|
||||||
human_in_the_loop_safe_mode: bool = True
|
human_in_the_loop_safe_mode: bool = True
|
||||||
sensitive_action_safe_mode: bool = False
|
sensitive_action_safe_mode: bool = False
|
||||||
|
|
||||||
|
# User settings
|
||||||
user_timezone: str = "UTC"
|
user_timezone: str = "UTC"
|
||||||
|
|
||||||
|
# Execution hierarchy
|
||||||
root_execution_id: Optional[str] = None
|
root_execution_id: Optional[str] = None
|
||||||
parent_execution_id: Optional[str] = None
|
parent_execution_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Workspace
|
||||||
|
workspace_id: Optional[str] = None
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# -------------------------- Models -------------------------- #
|
# -------------------------- Models -------------------------- #
|
||||||
|
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ Handles all database operations for pending human reviews.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from prisma.enums import ReviewStatus
|
from prisma.enums import ReviewStatus
|
||||||
from prisma.models import PendingHumanReview
|
from prisma.models import AgentNodeExecution, PendingHumanReview
|
||||||
from prisma.types import PendingHumanReviewUpdateInput
|
from prisma.types import PendingHumanReviewUpdateInput
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -17,8 +17,12 @@ from backend.api.features.executions.review.model import (
|
|||||||
PendingHumanReviewModel,
|
PendingHumanReviewModel,
|
||||||
SafeJsonData,
|
SafeJsonData,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import get_graph_execution_meta
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -32,6 +36,125 @@ class ReviewResult(BaseModel):
|
|||||||
node_exec_id: str
|
node_exec_id: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_auto_approve_key(graph_exec_id: str, node_id: str) -> str:
|
||||||
|
"""Generate the special nodeExecId key for auto-approval records."""
|
||||||
|
return f"auto_approve_{graph_exec_id}_{node_id}"
|
||||||
|
|
||||||
|
|
||||||
|
async def check_approval(
|
||||||
|
node_exec_id: str,
|
||||||
|
graph_exec_id: str,
|
||||||
|
node_id: str,
|
||||||
|
user_id: str,
|
||||||
|
input_data: SafeJsonData | None = None,
|
||||||
|
) -> Optional[ReviewResult]:
|
||||||
|
"""
|
||||||
|
Check if there's an existing approval for this node execution.
|
||||||
|
|
||||||
|
Checks both:
|
||||||
|
1. Normal approval by node_exec_id (previous run of the same node execution)
|
||||||
|
2. Auto-approval by special key pattern "auto_approve_{graph_exec_id}_{node_id}"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_exec_id: ID of the node execution
|
||||||
|
graph_exec_id: ID of the graph execution
|
||||||
|
node_id: ID of the node definition (not execution)
|
||||||
|
user_id: ID of the user (for data isolation)
|
||||||
|
input_data: Current input data (used for auto-approvals to avoid stale data)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReviewResult if approval found (either normal or auto), None otherwise
|
||||||
|
"""
|
||||||
|
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
|
||||||
|
|
||||||
|
# Check for either normal approval or auto-approval in a single query
|
||||||
|
existing_review = await PendingHumanReview.prisma().find_first(
|
||||||
|
where={
|
||||||
|
"OR": [
|
||||||
|
{"nodeExecId": node_exec_id},
|
||||||
|
{"nodeExecId": auto_approve_key},
|
||||||
|
],
|
||||||
|
"status": ReviewStatus.APPROVED,
|
||||||
|
"userId": user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_review:
|
||||||
|
is_auto_approval = existing_review.nodeExecId == auto_approve_key
|
||||||
|
logger.info(
|
||||||
|
f"Found {'auto-' if is_auto_approval else ''}approval for node {node_id} "
|
||||||
|
f"(exec: {node_exec_id}) in execution {graph_exec_id}"
|
||||||
|
)
|
||||||
|
# For auto-approvals, use current input_data to avoid replaying stale payload
|
||||||
|
# For normal approvals, use the stored payload (which may have been edited)
|
||||||
|
return ReviewResult(
|
||||||
|
data=(
|
||||||
|
input_data
|
||||||
|
if is_auto_approval and input_data is not None
|
||||||
|
else existing_review.payload
|
||||||
|
),
|
||||||
|
status=ReviewStatus.APPROVED,
|
||||||
|
message=(
|
||||||
|
"Auto-approved (user approved all future actions for this node)"
|
||||||
|
if is_auto_approval
|
||||||
|
else existing_review.reviewMessage or ""
|
||||||
|
),
|
||||||
|
processed=True,
|
||||||
|
node_exec_id=existing_review.nodeExecId,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def create_auto_approval_record(
|
||||||
|
user_id: str,
|
||||||
|
graph_exec_id: str,
|
||||||
|
graph_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
node_id: str,
|
||||||
|
payload: SafeJsonData,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create an auto-approval record for a node in this execution.
|
||||||
|
|
||||||
|
This is stored as a PendingHumanReview with a special nodeExecId pattern
|
||||||
|
and status=APPROVED, so future executions of the same node can skip review.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the graph execution doesn't belong to the user
|
||||||
|
"""
|
||||||
|
# Validate that the graph execution belongs to this user (defense in depth)
|
||||||
|
graph_exec = await get_graph_execution_meta(
|
||||||
|
user_id=user_id, execution_id=graph_exec_id
|
||||||
|
)
|
||||||
|
if not graph_exec:
|
||||||
|
raise ValueError(
|
||||||
|
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
|
||||||
|
|
||||||
|
await PendingHumanReview.prisma().upsert(
|
||||||
|
where={"nodeExecId": auto_approve_key},
|
||||||
|
data={
|
||||||
|
"create": {
|
||||||
|
"nodeExecId": auto_approve_key,
|
||||||
|
"userId": user_id,
|
||||||
|
"graphExecId": graph_exec_id,
|
||||||
|
"graphId": graph_id,
|
||||||
|
"graphVersion": graph_version,
|
||||||
|
"payload": SafeJson(payload),
|
||||||
|
"instructions": "Auto-approval record",
|
||||||
|
"editable": False,
|
||||||
|
"status": ReviewStatus.APPROVED,
|
||||||
|
"processed": True,
|
||||||
|
"reviewedAt": datetime.now(timezone.utc),
|
||||||
|
},
|
||||||
|
"update": {}, # Already exists, no update needed
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_or_create_human_review(
|
async def get_or_create_human_review(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
node_exec_id: str,
|
node_exec_id: str,
|
||||||
@@ -108,6 +231,89 @@ async def get_or_create_human_review(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_pending_review_by_node_exec_id(
|
||||||
|
node_exec_id: str, user_id: str
|
||||||
|
) -> Optional["PendingHumanReviewModel"]:
|
||||||
|
"""
|
||||||
|
Get a pending review by its node execution ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_exec_id: The node execution ID to look up
|
||||||
|
user_id: User ID for authorization (only returns if review belongs to this user)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The pending review if found and belongs to user, None otherwise
|
||||||
|
"""
|
||||||
|
review = await PendingHumanReview.prisma().find_first(
|
||||||
|
where={
|
||||||
|
"nodeExecId": node_exec_id,
|
||||||
|
"userId": user_id,
|
||||||
|
"status": ReviewStatus.WAITING,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not review:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Local import to avoid event loop conflicts in tests
|
||||||
|
from backend.data.execution import get_node_execution
|
||||||
|
|
||||||
|
node_exec = await get_node_execution(review.nodeExecId)
|
||||||
|
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||||
|
return PendingHumanReviewModel.from_db(review, node_id=node_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_reviews_by_node_exec_ids(
|
||||||
|
node_exec_ids: list[str], user_id: str
|
||||||
|
) -> dict[str, "PendingHumanReviewModel"]:
|
||||||
|
"""
|
||||||
|
Get multiple reviews by their node execution IDs regardless of status.
|
||||||
|
|
||||||
|
Unlike get_pending_reviews_by_node_exec_ids, this returns reviews in any status
|
||||||
|
(WAITING, APPROVED, REJECTED). Used for validation in idempotent operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_exec_ids: List of node execution IDs to look up
|
||||||
|
user_id: User ID for authorization (only returns reviews belonging to this user)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping node_exec_id -> PendingHumanReviewModel for found reviews
|
||||||
|
"""
|
||||||
|
if not node_exec_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
reviews = await PendingHumanReview.prisma().find_many(
|
||||||
|
where={
|
||||||
|
"nodeExecId": {"in": node_exec_ids},
|
||||||
|
"userId": user_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not reviews:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Batch fetch all node executions to avoid N+1 queries
|
||||||
|
node_exec_ids_to_fetch = [review.nodeExecId for review in reviews]
|
||||||
|
node_execs = await AgentNodeExecution.prisma().find_many(
|
||||||
|
where={"id": {"in": node_exec_ids_to_fetch}},
|
||||||
|
include={"Node": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create mapping from node_exec_id to node_id
|
||||||
|
node_exec_id_to_node_id = {
|
||||||
|
node_exec.id: node_exec.agentNodeId for node_exec in node_execs
|
||||||
|
}
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for review in reviews:
|
||||||
|
node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId)
|
||||||
|
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
|
||||||
|
review, node_id=node_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
|
async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a graph execution has any pending reviews.
|
Check if a graph execution has any pending reviews.
|
||||||
@@ -137,8 +343,11 @@ async def get_pending_reviews_for_user(
|
|||||||
page_size: Number of reviews per page
|
page_size: Number of reviews per page
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of pending review models
|
List of pending review models with node_id included
|
||||||
"""
|
"""
|
||||||
|
# Local import to avoid event loop conflicts in tests
|
||||||
|
from backend.data.execution import get_node_execution
|
||||||
|
|
||||||
# Calculate offset for pagination
|
# Calculate offset for pagination
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
@@ -149,7 +358,14 @@ async def get_pending_reviews_for_user(
|
|||||||
take=page_size,
|
take=page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [PendingHumanReviewModel.from_db(review) for review in reviews]
|
# Fetch node_id for each review from NodeExecution
|
||||||
|
result = []
|
||||||
|
for review in reviews:
|
||||||
|
node_exec = await get_node_execution(review.nodeExecId)
|
||||||
|
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||||
|
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def get_pending_reviews_for_execution(
|
async def get_pending_reviews_for_execution(
|
||||||
@@ -163,8 +379,11 @@ async def get_pending_reviews_for_execution(
|
|||||||
user_id: User ID for security validation
|
user_id: User ID for security validation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of pending review models
|
List of pending review models with node_id included
|
||||||
"""
|
"""
|
||||||
|
# Local import to avoid event loop conflicts in tests
|
||||||
|
from backend.data.execution import get_node_execution
|
||||||
|
|
||||||
reviews = await PendingHumanReview.prisma().find_many(
|
reviews = await PendingHumanReview.prisma().find_many(
|
||||||
where={
|
where={
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
@@ -174,7 +393,14 @@ async def get_pending_reviews_for_execution(
|
|||||||
order={"createdAt": "asc"},
|
order={"createdAt": "asc"},
|
||||||
)
|
)
|
||||||
|
|
||||||
return [PendingHumanReviewModel.from_db(review) for review in reviews]
|
# Fetch node_id for each review from NodeExecution
|
||||||
|
result = []
|
||||||
|
for review in reviews:
|
||||||
|
node_exec = await get_node_execution(review.nodeExecId)
|
||||||
|
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||||
|
result.append(PendingHumanReviewModel.from_db(review, node_id=node_id))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def process_all_reviews_for_execution(
|
async def process_all_reviews_for_execution(
|
||||||
@@ -183,38 +409,68 @@ async def process_all_reviews_for_execution(
|
|||||||
) -> dict[str, PendingHumanReviewModel]:
|
) -> dict[str, PendingHumanReviewModel]:
|
||||||
"""Process all pending reviews for an execution with approve/reject decisions.
|
"""Process all pending reviews for an execution with approve/reject decisions.
|
||||||
|
|
||||||
|
Handles race conditions gracefully: if a review was already processed with the
|
||||||
|
same decision by a concurrent request, it's treated as success rather than error.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID for ownership validation
|
user_id: User ID for ownership validation
|
||||||
review_decisions: Map of node_exec_id -> (status, reviewed_data, message)
|
review_decisions: Map of node_exec_id -> (status, reviewed_data, message)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict of node_exec_id -> updated review model
|
Dict of node_exec_id -> updated review model (includes already-processed reviews)
|
||||||
"""
|
"""
|
||||||
if not review_decisions:
|
if not review_decisions:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
node_exec_ids = list(review_decisions.keys())
|
node_exec_ids = list(review_decisions.keys())
|
||||||
|
|
||||||
# Get all reviews for validation
|
# Get all reviews (both WAITING and already processed) for the user
|
||||||
reviews = await PendingHumanReview.prisma().find_many(
|
all_reviews = await PendingHumanReview.prisma().find_many(
|
||||||
where={
|
where={
|
||||||
"nodeExecId": {"in": node_exec_ids},
|
"nodeExecId": {"in": node_exec_ids},
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
"status": ReviewStatus.WAITING,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate all reviews can be processed
|
# Separate into pending and already-processed reviews
|
||||||
if len(reviews) != len(node_exec_ids):
|
reviews_to_process = []
|
||||||
missing_ids = set(node_exec_ids) - {review.nodeExecId for review in reviews}
|
already_processed = []
|
||||||
|
for review in all_reviews:
|
||||||
|
if review.status == ReviewStatus.WAITING:
|
||||||
|
reviews_to_process.append(review)
|
||||||
|
else:
|
||||||
|
already_processed.append(review)
|
||||||
|
|
||||||
|
# Check for truly missing reviews (not found at all)
|
||||||
|
found_ids = {review.nodeExecId for review in all_reviews}
|
||||||
|
missing_ids = set(node_exec_ids) - found_ids
|
||||||
|
if missing_ids:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Reviews not found, access denied, or not in WAITING status: {', '.join(missing_ids)}"
|
f"Reviews not found or access denied: {', '.join(missing_ids)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create parallel update tasks
|
# Validate already-processed reviews have compatible status (same decision)
|
||||||
|
# This handles race conditions where another request processed the same reviews
|
||||||
|
for review in already_processed:
|
||||||
|
requested_status = review_decisions[review.nodeExecId][0]
|
||||||
|
if review.status != requested_status:
|
||||||
|
raise ValueError(
|
||||||
|
f"Review {review.nodeExecId} was already processed with status "
|
||||||
|
f"{review.status}, cannot change to {requested_status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log if we're handling a race condition (some reviews already processed)
|
||||||
|
if already_processed:
|
||||||
|
already_processed_ids = [r.nodeExecId for r in already_processed]
|
||||||
|
logger.info(
|
||||||
|
f"Race condition handled: {len(already_processed)} review(s) already "
|
||||||
|
f"processed by concurrent request: {already_processed_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create parallel update tasks for reviews that still need processing
|
||||||
update_tasks = []
|
update_tasks = []
|
||||||
|
|
||||||
for review in reviews:
|
for review in reviews_to_process:
|
||||||
new_status, reviewed_data, message = review_decisions[review.nodeExecId]
|
new_status, reviewed_data, message = review_decisions[review.nodeExecId]
|
||||||
has_data_changes = reviewed_data is not None and reviewed_data != review.payload
|
has_data_changes = reviewed_data is not None and reviewed_data != review.payload
|
||||||
|
|
||||||
@@ -239,16 +495,27 @@ async def process_all_reviews_for_execution(
|
|||||||
update_tasks.append(task)
|
update_tasks.append(task)
|
||||||
|
|
||||||
# Execute all updates in parallel and get updated reviews
|
# Execute all updates in parallel and get updated reviews
|
||||||
updated_reviews = await asyncio.gather(*update_tasks)
|
updated_reviews = await asyncio.gather(*update_tasks) if update_tasks else []
|
||||||
|
|
||||||
# Note: Execution resumption is now handled at the API layer after ALL reviews
|
# Note: Execution resumption is now handled at the API layer after ALL reviews
|
||||||
# for an execution are processed (both approved and rejected)
|
# for an execution are processed (both approved and rejected)
|
||||||
|
|
||||||
# Return as dict for easy access
|
# Fetch node_id for each review and return as dict for easy access
|
||||||
return {
|
# Local import to avoid event loop conflicts in tests
|
||||||
review.nodeExecId: PendingHumanReviewModel.from_db(review)
|
from backend.data.execution import get_node_execution
|
||||||
for review in updated_reviews
|
|
||||||
}
|
# Combine updated reviews with already-processed ones (for idempotent response)
|
||||||
|
all_result_reviews = list(updated_reviews) + already_processed
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for review in all_result_reviews:
|
||||||
|
node_exec = await get_node_execution(review.nodeExecId)
|
||||||
|
node_id = node_exec.node_id if node_exec else review.nodeExecId
|
||||||
|
result[review.nodeExecId] = PendingHumanReviewModel.from_db(
|
||||||
|
review, node_id=node_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def update_review_processed_status(node_exec_id: str, processed: bool) -> None:
|
async def update_review_processed_status(node_exec_id: str, processed: bool) -> None:
|
||||||
@@ -256,3 +523,44 @@ async def update_review_processed_status(node_exec_id: str, processed: bool) ->
|
|||||||
await PendingHumanReview.prisma().update(
|
await PendingHumanReview.prisma().update(
|
||||||
where={"nodeExecId": node_exec_id}, data={"processed": processed}
|
where={"nodeExecId": node_exec_id}, data={"processed": processed}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str) -> int:
|
||||||
|
"""
|
||||||
|
Cancel all pending reviews for a graph execution (e.g., when execution is stopped).
|
||||||
|
|
||||||
|
Marks all WAITING reviews as REJECTED with a message indicating the execution was stopped.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph_exec_id: The graph execution ID
|
||||||
|
user_id: User ID who owns the execution (for security validation)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of reviews cancelled
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the graph execution doesn't belong to the user
|
||||||
|
"""
|
||||||
|
# Validate user ownership before cancelling reviews
|
||||||
|
graph_exec = await get_graph_execution_meta(
|
||||||
|
user_id=user_id, execution_id=graph_exec_id
|
||||||
|
)
|
||||||
|
if not graph_exec:
|
||||||
|
raise ValueError(
|
||||||
|
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await PendingHumanReview.prisma().update_many(
|
||||||
|
where={
|
||||||
|
"graphExecId": graph_exec_id,
|
||||||
|
"userId": user_id,
|
||||||
|
"status": ReviewStatus.WAITING,
|
||||||
|
},
|
||||||
|
data={
|
||||||
|
"status": ReviewStatus.REJECTED,
|
||||||
|
"reviewMessage": "Execution was stopped by user",
|
||||||
|
"processed": True,
|
||||||
|
"reviewedAt": datetime.now(timezone.utc),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ def sample_db_review():
|
|||||||
return mock_review
|
return mock_review
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_get_or_create_human_review_new(
|
async def test_get_or_create_human_review_new(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
sample_db_review,
|
sample_db_review,
|
||||||
@@ -46,8 +46,8 @@ async def test_get_or_create_human_review_new(
|
|||||||
sample_db_review.status = ReviewStatus.WAITING
|
sample_db_review.status = ReviewStatus.WAITING
|
||||||
sample_db_review.processed = False
|
sample_db_review.processed = False
|
||||||
|
|
||||||
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||||
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||||
|
|
||||||
result = await get_or_create_human_review(
|
result = await get_or_create_human_review(
|
||||||
user_id="test-user-123",
|
user_id="test-user-123",
|
||||||
@@ -64,7 +64,7 @@ async def test_get_or_create_human_review_new(
|
|||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_get_or_create_human_review_approved(
|
async def test_get_or_create_human_review_approved(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
sample_db_review,
|
sample_db_review,
|
||||||
@@ -75,8 +75,8 @@ async def test_get_or_create_human_review_approved(
|
|||||||
sample_db_review.processed = False
|
sample_db_review.processed = False
|
||||||
sample_db_review.reviewMessage = "Looks good"
|
sample_db_review.reviewMessage = "Looks good"
|
||||||
|
|
||||||
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||||
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||||
|
|
||||||
result = await get_or_create_human_review(
|
result = await get_or_create_human_review(
|
||||||
user_id="test-user-123",
|
user_id="test-user-123",
|
||||||
@@ -96,7 +96,7 @@ async def test_get_or_create_human_review_approved(
|
|||||||
assert result.message == "Looks good"
|
assert result.message == "Looks good"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_has_pending_reviews_for_graph_exec_true(
|
async def test_has_pending_reviews_for_graph_exec_true(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
):
|
):
|
||||||
@@ -109,7 +109,7 @@ async def test_has_pending_reviews_for_graph_exec_true(
|
|||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_has_pending_reviews_for_graph_exec_false(
|
async def test_has_pending_reviews_for_graph_exec_false(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
):
|
):
|
||||||
@@ -122,7 +122,7 @@ async def test_has_pending_reviews_for_graph_exec_false(
|
|||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_get_pending_reviews_for_user(
|
async def test_get_pending_reviews_for_user(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
sample_db_review,
|
sample_db_review,
|
||||||
@@ -131,10 +131,19 @@ async def test_get_pending_reviews_for_user(
|
|||||||
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||||
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
|
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
|
||||||
|
|
||||||
|
# Mock get_node_execution to return node with node_id (async function)
|
||||||
|
mock_node_exec = Mock()
|
||||||
|
mock_node_exec.node_id = "test_node_def_789"
|
||||||
|
mocker.patch(
|
||||||
|
"backend.data.execution.get_node_execution",
|
||||||
|
new=AsyncMock(return_value=mock_node_exec),
|
||||||
|
)
|
||||||
|
|
||||||
result = await get_pending_reviews_for_user("test_user", page=2, page_size=10)
|
result = await get_pending_reviews_for_user("test_user", page=2, page_size=10)
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].node_exec_id == "test_node_123"
|
assert result[0].node_exec_id == "test_node_123"
|
||||||
|
assert result[0].node_id == "test_node_def_789"
|
||||||
|
|
||||||
# Verify pagination parameters
|
# Verify pagination parameters
|
||||||
call_args = mock_find_many.return_value.find_many.call_args
|
call_args = mock_find_many.return_value.find_many.call_args
|
||||||
@@ -142,7 +151,7 @@ async def test_get_pending_reviews_for_user(
|
|||||||
assert call_args.kwargs["take"] == 10
|
assert call_args.kwargs["take"] == 10
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_get_pending_reviews_for_execution(
|
async def test_get_pending_reviews_for_execution(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
sample_db_review,
|
sample_db_review,
|
||||||
@@ -151,12 +160,21 @@ async def test_get_pending_reviews_for_execution(
|
|||||||
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||||
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
|
mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review])
|
||||||
|
|
||||||
|
# Mock get_node_execution to return node with node_id (async function)
|
||||||
|
mock_node_exec = Mock()
|
||||||
|
mock_node_exec.node_id = "test_node_def_789"
|
||||||
|
mocker.patch(
|
||||||
|
"backend.data.execution.get_node_execution",
|
||||||
|
new=AsyncMock(return_value=mock_node_exec),
|
||||||
|
)
|
||||||
|
|
||||||
result = await get_pending_reviews_for_execution(
|
result = await get_pending_reviews_for_execution(
|
||||||
"test_graph_exec_456", "test-user-123"
|
"test_graph_exec_456", "test-user-123"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0].graph_exec_id == "test_graph_exec_456"
|
assert result[0].graph_exec_id == "test_graph_exec_456"
|
||||||
|
assert result[0].node_id == "test_node_def_789"
|
||||||
|
|
||||||
# Verify it filters by execution and user
|
# Verify it filters by execution and user
|
||||||
call_args = mock_find_many.return_value.find_many.call_args
|
call_args = mock_find_many.return_value.find_many.call_args
|
||||||
@@ -166,7 +184,7 @@ async def test_get_pending_reviews_for_execution(
|
|||||||
assert where_clause["status"] == ReviewStatus.WAITING
|
assert where_clause["status"] == ReviewStatus.WAITING
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_process_all_reviews_for_execution_success(
|
async def test_process_all_reviews_for_execution_success(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
sample_db_review,
|
sample_db_review,
|
||||||
@@ -201,6 +219,14 @@ async def test_process_all_reviews_for_execution_success(
|
|||||||
new=AsyncMock(return_value=[updated_review]),
|
new=AsyncMock(return_value=[updated_review]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mock get_node_execution to return node with node_id (async function)
|
||||||
|
mock_node_exec = Mock()
|
||||||
|
mock_node_exec.node_id = "test_node_def_789"
|
||||||
|
mocker.patch(
|
||||||
|
"backend.data.execution.get_node_execution",
|
||||||
|
new=AsyncMock(return_value=mock_node_exec),
|
||||||
|
)
|
||||||
|
|
||||||
result = await process_all_reviews_for_execution(
|
result = await process_all_reviews_for_execution(
|
||||||
user_id="test-user-123",
|
user_id="test-user-123",
|
||||||
review_decisions={
|
review_decisions={
|
||||||
@@ -211,9 +237,10 @@ async def test_process_all_reviews_for_execution_success(
|
|||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert "test_node_123" in result
|
assert "test_node_123" in result
|
||||||
assert result["test_node_123"].status == ReviewStatus.APPROVED
|
assert result["test_node_123"].status == ReviewStatus.APPROVED
|
||||||
|
assert result["test_node_123"].node_id == "test_node_def_789"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_process_all_reviews_for_execution_validation_errors(
|
async def test_process_all_reviews_for_execution_validation_errors(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
):
|
):
|
||||||
@@ -233,7 +260,7 @@ async def test_process_all_reviews_for_execution_validation_errors(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_process_all_reviews_edit_permission_error(
|
async def test_process_all_reviews_edit_permission_error(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
sample_db_review,
|
sample_db_review,
|
||||||
@@ -259,7 +286,7 @@ async def test_process_all_reviews_edit_permission_error(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio(loop_scope="function")
|
||||||
async def test_process_all_reviews_mixed_approval_rejection(
|
async def test_process_all_reviews_mixed_approval_rejection(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
sample_db_review,
|
sample_db_review,
|
||||||
@@ -329,6 +356,14 @@ async def test_process_all_reviews_mixed_approval_rejection(
|
|||||||
new=AsyncMock(return_value=[approved_review, rejected_review]),
|
new=AsyncMock(return_value=[approved_review, rejected_review]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mock get_node_execution to return node with node_id (async function)
|
||||||
|
mock_node_exec = Mock()
|
||||||
|
mock_node_exec.node_id = "test_node_def_789"
|
||||||
|
mocker.patch(
|
||||||
|
"backend.data.execution.get_node_execution",
|
||||||
|
new=AsyncMock(return_value=mock_node_exec),
|
||||||
|
)
|
||||||
|
|
||||||
result = await process_all_reviews_for_execution(
|
result = await process_all_reviews_for_execution(
|
||||||
user_id="test-user-123",
|
user_id="test-user-123",
|
||||||
review_decisions={
|
review_decisions={
|
||||||
@@ -340,3 +375,5 @@ async def test_process_all_reviews_mixed_approval_rejection(
|
|||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert "test_node_123" in result
|
assert "test_node_123" in result
|
||||||
assert "test_node_456" in result
|
assert "test_node_456" in result
|
||||||
|
assert result["test_node_123"].node_id == "test_node_def_789"
|
||||||
|
assert result["test_node_456"].node_id == "test_node_def_789"
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ FrontendOnboardingStep = Literal[
|
|||||||
OnboardingStep.AGENT_NEW_RUN,
|
OnboardingStep.AGENT_NEW_RUN,
|
||||||
OnboardingStep.AGENT_INPUT,
|
OnboardingStep.AGENT_INPUT,
|
||||||
OnboardingStep.CONGRATS,
|
OnboardingStep.CONGRATS,
|
||||||
|
OnboardingStep.VISIT_COPILOT,
|
||||||
OnboardingStep.MARKETPLACE_VISIT,
|
OnboardingStep.MARKETPLACE_VISIT,
|
||||||
OnboardingStep.BUILDER_OPEN,
|
OnboardingStep.BUILDER_OPEN,
|
||||||
]
|
]
|
||||||
@@ -122,6 +123,9 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
|||||||
async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
|
async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
|
||||||
reward = 0
|
reward = 0
|
||||||
match step:
|
match step:
|
||||||
|
# Welcome bonus for visiting copilot ($5 = 500 credits)
|
||||||
|
case OnboardingStep.VISIT_COPILOT:
|
||||||
|
reward = 500
|
||||||
# Reward user when they clicked New Run during onboarding
|
# Reward user when they clicked New Run during onboarding
|
||||||
# This is because they need credits before scheduling a run (next step)
|
# This is because they need credits before scheduling a run (next step)
|
||||||
# This is seen as a reward for the GET_RESULTS step in the wallet
|
# This is seen as a reward for the GET_RESULTS step in the wallet
|
||||||
|
|||||||
285
autogpt_platform/backend/backend/data/workspace.py
Normal file
285
autogpt_platform/backend/backend/data/workspace.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""
|
||||||
|
Database CRUD operations for User Workspace.
|
||||||
|
|
||||||
|
This module provides functions for managing user workspaces and workspace files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from prisma.enums import WorkspaceFileSource
|
||||||
|
from prisma.models import UserWorkspace, UserWorkspaceFile
|
||||||
|
|
||||||
|
from backend.util.json import SafeJson
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
|
||||||
|
"""
|
||||||
|
Get user's workspace, creating one if it doesn't exist.
|
||||||
|
|
||||||
|
Uses upsert to handle race conditions when multiple concurrent requests
|
||||||
|
attempt to create a workspace for the same user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspace instance
|
||||||
|
"""
|
||||||
|
workspace = await UserWorkspace.prisma().upsert(
|
||||||
|
where={"userId": user_id},
|
||||||
|
data={
|
||||||
|
"create": {"userId": user_id},
|
||||||
|
"update": {}, # No updates needed if exists
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return workspace
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
|
||||||
|
"""
|
||||||
|
Get user's workspace if it exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspace instance or None
|
||||||
|
"""
|
||||||
|
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def create_workspace_file(
|
||||||
|
workspace_id: str,
|
||||||
|
file_id: str,
|
||||||
|
name: str,
|
||||||
|
path: str,
|
||||||
|
storage_path: str,
|
||||||
|
mime_type: str,
|
||||||
|
size_bytes: int,
|
||||||
|
checksum: Optional[str] = None,
|
||||||
|
source: WorkspaceFileSource = WorkspaceFileSource.UPLOAD,
|
||||||
|
source_exec_id: Optional[str] = None,
|
||||||
|
source_session_id: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
) -> UserWorkspaceFile:
|
||||||
|
"""
|
||||||
|
Create a new workspace file record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
file_id: The file ID (same as used in storage path for consistency)
|
||||||
|
name: User-visible filename
|
||||||
|
path: Virtual path (e.g., "/documents/report.pdf")
|
||||||
|
storage_path: Actual storage path (GCS or local)
|
||||||
|
mime_type: MIME type of the file
|
||||||
|
size_bytes: File size in bytes
|
||||||
|
checksum: Optional SHA256 checksum
|
||||||
|
source: How the file was created
|
||||||
|
source_exec_id: Graph execution ID if from execution
|
||||||
|
source_session_id: Chat session ID if from CoPilot
|
||||||
|
metadata: Optional additional metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created UserWorkspaceFile instance
|
||||||
|
"""
|
||||||
|
# Normalize path to start with /
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
|
||||||
|
file = await UserWorkspaceFile.prisma().create(
|
||||||
|
data={
|
||||||
|
"id": file_id,
|
||||||
|
"workspaceId": workspace_id,
|
||||||
|
"name": name,
|
||||||
|
"path": path,
|
||||||
|
"storagePath": storage_path,
|
||||||
|
"mimeType": mime_type,
|
||||||
|
"sizeBytes": size_bytes,
|
||||||
|
"checksum": checksum,
|
||||||
|
"source": source,
|
||||||
|
"sourceExecId": source_exec_id,
|
||||||
|
"sourceSessionId": source_session_id,
|
||||||
|
"metadata": SafeJson(metadata or {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created workspace file {file.id} at path {path} "
|
||||||
|
f"in workspace {workspace_id}"
|
||||||
|
)
|
||||||
|
return file
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_file(
|
||||||
|
file_id: str,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get a workspace file by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file ID
|
||||||
|
workspace_id: Optional workspace ID for validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
where_clause: dict = {"id": file_id, "isDeleted": False}
|
||||||
|
if workspace_id:
|
||||||
|
where_clause["workspaceId"] = workspace_id
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_file_by_path(
|
||||||
|
workspace_id: str,
|
||||||
|
path: str,
|
||||||
|
) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get a workspace file by its virtual path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
path: Virtual path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
# Normalize path
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().find_first(
|
||||||
|
where={
|
||||||
|
"workspaceId": workspace_id,
|
||||||
|
"path": path,
|
||||||
|
"isDeleted": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_workspace_files(
|
||||||
|
workspace_id: str,
|
||||||
|
path_prefix: Optional[str] = None,
|
||||||
|
include_deleted: bool = False,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
List files in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
path_prefix: Optional path prefix to filter (e.g., "/documents/")
|
||||||
|
include_deleted: Whether to include soft-deleted files
|
||||||
|
limit: Maximum number of files to return
|
||||||
|
offset: Number of files to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of UserWorkspaceFile instances
|
||||||
|
"""
|
||||||
|
where_clause: dict = {"workspaceId": workspace_id}
|
||||||
|
|
||||||
|
if not include_deleted:
|
||||||
|
where_clause["isDeleted"] = False
|
||||||
|
|
||||||
|
if path_prefix:
|
||||||
|
# Normalize prefix
|
||||||
|
if not path_prefix.startswith("/"):
|
||||||
|
path_prefix = f"/{path_prefix}"
|
||||||
|
where_clause["path"] = {"startswith": path_prefix}
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().find_many(
|
||||||
|
where=where_clause,
|
||||||
|
order={"createdAt": "desc"},
|
||||||
|
take=limit,
|
||||||
|
skip=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def count_workspace_files(
|
||||||
|
workspace_id: str,
|
||||||
|
path_prefix: Optional[str] = None,
|
||||||
|
include_deleted: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count files in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
|
||||||
|
include_deleted: Whether to include soft-deleted files
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files
|
||||||
|
"""
|
||||||
|
where_clause: dict = {"workspaceId": workspace_id}
|
||||||
|
if not include_deleted:
|
||||||
|
where_clause["isDeleted"] = False
|
||||||
|
|
||||||
|
if path_prefix:
|
||||||
|
# Normalize prefix
|
||||||
|
if not path_prefix.startswith("/"):
|
||||||
|
path_prefix = f"/{path_prefix}"
|
||||||
|
where_clause["path"] = {"startswith": path_prefix}
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().count(where=where_clause)
|
||||||
|
|
||||||
|
|
||||||
|
async def soft_delete_workspace_file(
|
||||||
|
file_id: str,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Soft-delete a workspace file.
|
||||||
|
|
||||||
|
The path is modified to include a deletion timestamp to free up the original
|
||||||
|
path for new files while preserving the record for potential recovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file ID
|
||||||
|
workspace_id: Optional workspace ID for validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated UserWorkspaceFile instance or None if not found
|
||||||
|
"""
|
||||||
|
# First verify the file exists and belongs to workspace
|
||||||
|
file = await get_workspace_file(file_id, workspace_id)
|
||||||
|
if file is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
deleted_at = datetime.now(timezone.utc)
|
||||||
|
# Modify path to free up the unique constraint for new files at original path
|
||||||
|
# Format: {original_path}__deleted__{timestamp}
|
||||||
|
deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
|
||||||
|
|
||||||
|
updated = await UserWorkspaceFile.prisma().update(
|
||||||
|
where={"id": file_id},
|
||||||
|
data={
|
||||||
|
"isDeleted": True,
|
||||||
|
"deletedAt": deleted_at,
|
||||||
|
"path": deleted_path,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Soft-deleted workspace file {file_id}")
|
||||||
|
return updated
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_total_size(workspace_id: str) -> int:
|
||||||
|
"""
|
||||||
|
Get the total size of all files in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total size in bytes
|
||||||
|
"""
|
||||||
|
files = await list_workspace_files(workspace_id)
|
||||||
|
return sum(file.sizeBytes for file in files)
|
||||||
@@ -50,6 +50,8 @@ from backend.data.graph import (
|
|||||||
validate_graph_execution_permissions,
|
validate_graph_execution_permissions,
|
||||||
)
|
)
|
||||||
from backend.data.human_review import (
|
from backend.data.human_review import (
|
||||||
|
cancel_pending_reviews_for_execution,
|
||||||
|
check_approval,
|
||||||
get_or_create_human_review,
|
get_or_create_human_review,
|
||||||
has_pending_reviews_for_graph_exec,
|
has_pending_reviews_for_graph_exec,
|
||||||
update_review_processed_status,
|
update_review_processed_status,
|
||||||
@@ -190,6 +192,8 @@ class DatabaseManager(AppService):
|
|||||||
get_user_notification_preference = _(get_user_notification_preference)
|
get_user_notification_preference = _(get_user_notification_preference)
|
||||||
|
|
||||||
# Human In The Loop
|
# Human In The Loop
|
||||||
|
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||||
|
check_approval = _(check_approval)
|
||||||
get_or_create_human_review = _(get_or_create_human_review)
|
get_or_create_human_review = _(get_or_create_human_review)
|
||||||
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
||||||
update_review_processed_status = _(update_review_processed_status)
|
update_review_processed_status = _(update_review_processed_status)
|
||||||
@@ -313,6 +317,8 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
set_execution_kv_data = d.set_execution_kv_data
|
set_execution_kv_data = d.set_execution_kv_data
|
||||||
|
|
||||||
# Human In The Loop
|
# Human In The Loop
|
||||||
|
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||||
|
check_approval = d.check_approval
|
||||||
get_or_create_human_review = d.get_or_create_human_review
|
get_or_create_human_review = d.get_or_create_human_review
|
||||||
update_review_processed_status = d.update_review_processed_status
|
update_review_processed_status = d.update_review_processed_status
|
||||||
|
|
||||||
|
|||||||
@@ -236,7 +236,14 @@ async def execute_node(
|
|||||||
input_size = len(input_data_str)
|
input_size = len(input_data_str)
|
||||||
log_metadata.debug("Executed node with input", input=input_data_str)
|
log_metadata.debug("Executed node with input", input=input_data_str)
|
||||||
|
|
||||||
|
# Create node-specific execution context to avoid race conditions
|
||||||
|
# (multiple nodes can execute concurrently and would otherwise mutate shared state)
|
||||||
|
execution_context = execution_context.model_copy(
|
||||||
|
update={"node_id": node_id, "node_exec_id": node_exec_id}
|
||||||
|
)
|
||||||
|
|
||||||
# Inject extra execution arguments for the blocks via kwargs
|
# Inject extra execution arguments for the blocks via kwargs
|
||||||
|
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||||
extra_exec_kwargs: dict = {
|
extra_exec_kwargs: dict = {
|
||||||
"graph_id": graph_id,
|
"graph_id": graph_id,
|
||||||
"graph_version": graph_version,
|
"graph_version": graph_version,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
|
|||||||
|
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
|
from backend.data import human_review as human_review_db
|
||||||
from backend.data import onboarding as onboarding_db
|
from backend.data import onboarding as onboarding_db
|
||||||
from backend.data import user as user_db
|
from backend.data import user as user_db
|
||||||
from backend.data.block import (
|
from backend.data.block import (
|
||||||
@@ -749,9 +750,27 @@ async def stop_graph_execution(
|
|||||||
if graph_exec.status in [
|
if graph_exec.status in [
|
||||||
ExecutionStatus.QUEUED,
|
ExecutionStatus.QUEUED,
|
||||||
ExecutionStatus.INCOMPLETE,
|
ExecutionStatus.INCOMPLETE,
|
||||||
|
ExecutionStatus.REVIEW,
|
||||||
]:
|
]:
|
||||||
# If the graph is still on the queue, we can prevent them from being executed
|
# If the graph is queued/incomplete/paused for review, terminate immediately
|
||||||
# by setting the status to TERMINATED.
|
# No need to wait for executor since it's not actively running
|
||||||
|
|
||||||
|
# If graph is in REVIEW status, clean up pending reviews before terminating
|
||||||
|
if graph_exec.status == ExecutionStatus.REVIEW:
|
||||||
|
# Use human_review_db if Prisma connected, else database manager
|
||||||
|
review_db = (
|
||||||
|
human_review_db
|
||||||
|
if prisma.is_connected()
|
||||||
|
else get_database_manager_async_client()
|
||||||
|
)
|
||||||
|
# Mark all pending reviews as rejected/cancelled
|
||||||
|
cancelled_count = await review_db.cancel_pending_reviews_for_execution(
|
||||||
|
graph_exec_id, user_id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Cancelled {cancelled_count} pending review(s) for stopped execution {graph_exec_id}"
|
||||||
|
)
|
||||||
|
|
||||||
graph_exec.status = ExecutionStatus.TERMINATED
|
graph_exec.status = ExecutionStatus.TERMINATED
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
@@ -873,11 +892,19 @@ async def add_graph_execution(
|
|||||||
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
||||||
|
|
||||||
execution_context = ExecutionContext(
|
execution_context = ExecutionContext(
|
||||||
|
# Execution identity
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_exec_id=graph_exec.id,
|
||||||
|
graph_version=graph_exec.graph_version,
|
||||||
|
# Safety settings
|
||||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||||
|
# User settings
|
||||||
user_timezone=(
|
user_timezone=(
|
||||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||||
),
|
),
|
||||||
|
# Execution hierarchy
|
||||||
root_execution_id=graph_exec.id,
|
root_execution_id=graph_exec.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -887,9 +914,28 @@ async def add_graph_execution(
|
|||||||
nodes_to_skip=nodes_to_skip,
|
nodes_to_skip=nodes_to_skip,
|
||||||
execution_context=execution_context,
|
execution_context=execution_context,
|
||||||
)
|
)
|
||||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
logger.info(f"Queueing execution {graph_exec.id}")
|
||||||
|
|
||||||
|
# Update execution status to QUEUED BEFORE publishing to prevent race condition
|
||||||
|
# where two concurrent requests could both publish the same execution
|
||||||
|
updated_exec = await edb.update_graph_execution_stats(
|
||||||
|
graph_exec_id=graph_exec.id,
|
||||||
|
status=ExecutionStatus.QUEUED,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the status update succeeded (prevents duplicate queueing in race conditions)
|
||||||
|
# If another request already updated the status, this execution will not be QUEUED
|
||||||
|
if not updated_exec or updated_exec.status != ExecutionStatus.QUEUED:
|
||||||
|
logger.warning(
|
||||||
|
f"Skipping queue publish for execution {graph_exec.id} - "
|
||||||
|
f"status update failed or execution already queued by another request"
|
||||||
|
)
|
||||||
|
return graph_exec
|
||||||
|
|
||||||
|
graph_exec.status = ExecutionStatus.QUEUED
|
||||||
|
|
||||||
# Publish to execution queue for executor to pick up
|
# Publish to execution queue for executor to pick up
|
||||||
|
# This happens AFTER status update to ensure only one request publishes
|
||||||
exec_queue = await get_async_execution_queue()
|
exec_queue = await get_async_execution_queue()
|
||||||
await exec_queue.publish_message(
|
await exec_queue.publish_message(
|
||||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||||
@@ -897,13 +943,6 @@ async def add_graph_execution(
|
|||||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||||
)
|
)
|
||||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||||
|
|
||||||
# Update execution status to QUEUED
|
|
||||||
graph_exec.status = ExecutionStatus.QUEUED
|
|
||||||
await edb.update_graph_execution_stats(
|
|
||||||
graph_exec_id=graph_exec.id,
|
|
||||||
status=graph_exec.status,
|
|
||||||
)
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
err = str(e) or type(e).__name__
|
err = str(e) or type(e).__name__
|
||||||
if not graph_exec:
|
if not graph_exec:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import pytest
|
|||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
|
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
|
||||||
|
from backend.data.execution import ExecutionStatus
|
||||||
from backend.util.mock import MockObject
|
from backend.util.mock import MockObject
|
||||||
|
|
||||||
|
|
||||||
@@ -346,6 +347,8 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
|||||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||||
mock_graph_exec.id = "execution-id-123"
|
mock_graph_exec.id = "execution-id-123"
|
||||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||||
|
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||||
|
mock_graph_exec.graph_version = graph_version
|
||||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||||
|
|
||||||
# Mock the queue and event bus
|
# Mock the queue and event bus
|
||||||
@@ -432,6 +435,9 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
|||||||
# Create a second mock execution for the sanity check
|
# Create a second mock execution for the sanity check
|
||||||
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||||
mock_graph_exec_2.id = "execution-id-456"
|
mock_graph_exec_2.id = "execution-id-456"
|
||||||
|
mock_graph_exec_2.node_executions = []
|
||||||
|
mock_graph_exec_2.status = ExecutionStatus.QUEUED
|
||||||
|
mock_graph_exec_2.graph_version = graph_version
|
||||||
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
|
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||||
|
|
||||||
# Reset mocks and set up for second call
|
# Reset mocks and set up for second call
|
||||||
@@ -611,6 +617,8 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
|||||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||||
mock_graph_exec.id = "execution-id-123"
|
mock_graph_exec.id = "execution-id-123"
|
||||||
mock_graph_exec.node_executions = []
|
mock_graph_exec.node_executions = []
|
||||||
|
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||||
|
mock_graph_exec.graph_version = graph_version
|
||||||
|
|
||||||
# Track what's passed to to_graph_execution_entry
|
# Track what's passed to to_graph_execution_entry
|
||||||
captured_kwargs = {}
|
captured_kwargs = {}
|
||||||
@@ -670,3 +678,232 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
|||||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
||||||
assert "nodes_to_skip" in captured_kwargs
|
assert "nodes_to_skip" in captured_kwargs
|
||||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_graph_execution_in_review_status_cancels_pending_reviews(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""Test that stopping an execution in REVIEW status cancels pending reviews."""
|
||||||
|
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||||
|
from backend.executor.utils import stop_graph_execution
|
||||||
|
|
||||||
|
user_id = "test-user"
|
||||||
|
graph_exec_id = "test-exec-123"
|
||||||
|
|
||||||
|
# Mock graph execution in REVIEW status
|
||||||
|
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||||
|
mock_graph_exec.id = graph_exec_id
|
||||||
|
mock_graph_exec.status = ExecutionStatus.REVIEW
|
||||||
|
|
||||||
|
# Mock dependencies
|
||||||
|
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||||
|
mock_queue_client = mocker.AsyncMock()
|
||||||
|
mock_get_queue.return_value = mock_queue_client
|
||||||
|
|
||||||
|
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||||
|
mock_prisma.is_connected.return_value = True
|
||||||
|
|
||||||
|
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
|
||||||
|
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||||
|
return_value=2 # 2 reviews cancelled
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
|
||||||
|
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
|
||||||
|
return_value=mock_graph_exec
|
||||||
|
)
|
||||||
|
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
|
||||||
|
|
||||||
|
mock_get_event_bus = mocker.patch(
|
||||||
|
"backend.executor.utils.get_async_execution_event_bus"
|
||||||
|
)
|
||||||
|
mock_event_bus = mocker.MagicMock()
|
||||||
|
mock_event_bus.publish = mocker.AsyncMock()
|
||||||
|
mock_get_event_bus.return_value = mock_event_bus
|
||||||
|
|
||||||
|
mock_get_child_executions = mocker.patch(
|
||||||
|
"backend.executor.utils._get_child_executions"
|
||||||
|
)
|
||||||
|
mock_get_child_executions.return_value = [] # No children
|
||||||
|
|
||||||
|
# Call stop_graph_execution with timeout to allow status check
|
||||||
|
await stop_graph_execution(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
wait_timeout=1.0, # Wait to allow status check
|
||||||
|
cascade=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify pending reviews were cancelled
|
||||||
|
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||||
|
graph_exec_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify execution status was updated to TERMINATED
|
||||||
|
mock_execution_db.update_graph_execution_stats.assert_called_once()
|
||||||
|
call_kwargs = mock_execution_db.update_graph_execution_stats.call_args[1]
|
||||||
|
assert call_kwargs["graph_exec_id"] == graph_exec_id
|
||||||
|
assert call_kwargs["status"] == ExecutionStatus.TERMINATED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_graph_execution_with_database_manager_when_prisma_disconnected(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""Test that stop uses database manager when Prisma is not connected."""
|
||||||
|
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||||
|
from backend.executor.utils import stop_graph_execution
|
||||||
|
|
||||||
|
user_id = "test-user"
|
||||||
|
graph_exec_id = "test-exec-456"
|
||||||
|
|
||||||
|
# Mock graph execution in REVIEW status
|
||||||
|
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||||
|
mock_graph_exec.id = graph_exec_id
|
||||||
|
mock_graph_exec.status = ExecutionStatus.REVIEW
|
||||||
|
|
||||||
|
# Mock dependencies
|
||||||
|
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||||
|
mock_queue_client = mocker.AsyncMock()
|
||||||
|
mock_get_queue.return_value = mock_queue_client
|
||||||
|
|
||||||
|
# Prisma is NOT connected
|
||||||
|
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||||
|
mock_prisma.is_connected.return_value = False
|
||||||
|
|
||||||
|
# Mock database manager client
|
||||||
|
mock_get_db_manager = mocker.patch(
|
||||||
|
"backend.executor.utils.get_database_manager_async_client"
|
||||||
|
)
|
||||||
|
mock_db_manager = mocker.AsyncMock()
|
||||||
|
mock_db_manager.get_graph_execution_meta = mocker.AsyncMock(
|
||||||
|
return_value=mock_graph_exec
|
||||||
|
)
|
||||||
|
mock_db_manager.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||||
|
return_value=3 # 3 reviews cancelled
|
||||||
|
)
|
||||||
|
mock_db_manager.update_graph_execution_stats = mocker.AsyncMock()
|
||||||
|
mock_get_db_manager.return_value = mock_db_manager
|
||||||
|
|
||||||
|
mock_get_event_bus = mocker.patch(
|
||||||
|
"backend.executor.utils.get_async_execution_event_bus"
|
||||||
|
)
|
||||||
|
mock_event_bus = mocker.MagicMock()
|
||||||
|
mock_event_bus.publish = mocker.AsyncMock()
|
||||||
|
mock_get_event_bus.return_value = mock_event_bus
|
||||||
|
|
||||||
|
mock_get_child_executions = mocker.patch(
|
||||||
|
"backend.executor.utils._get_child_executions"
|
||||||
|
)
|
||||||
|
mock_get_child_executions.return_value = [] # No children
|
||||||
|
|
||||||
|
# Call stop_graph_execution with timeout
|
||||||
|
await stop_graph_execution(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
wait_timeout=1.0,
|
||||||
|
cascade=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify database manager was used for cancel_pending_reviews
|
||||||
|
mock_db_manager.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||||
|
graph_exec_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify execution status was updated via database manager
|
||||||
|
mock_db_manager.update_graph_execution_stats.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_graph_execution_cascades_to_child_with_reviews(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
):
|
||||||
|
"""Test that stopping parent execution cascades to children and cancels their reviews."""
|
||||||
|
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||||
|
from backend.executor.utils import stop_graph_execution
|
||||||
|
|
||||||
|
user_id = "test-user"
|
||||||
|
parent_exec_id = "parent-exec"
|
||||||
|
child_exec_id = "child-exec"
|
||||||
|
|
||||||
|
# Mock parent execution in RUNNING status
|
||||||
|
mock_parent_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||||
|
mock_parent_exec.id = parent_exec_id
|
||||||
|
mock_parent_exec.status = ExecutionStatus.RUNNING
|
||||||
|
|
||||||
|
# Mock child execution in REVIEW status
|
||||||
|
mock_child_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||||
|
mock_child_exec.id = child_exec_id
|
||||||
|
mock_child_exec.status = ExecutionStatus.REVIEW
|
||||||
|
|
||||||
|
# Mock dependencies
|
||||||
|
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||||
|
mock_queue_client = mocker.AsyncMock()
|
||||||
|
mock_get_queue.return_value = mock_queue_client
|
||||||
|
|
||||||
|
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||||
|
mock_prisma.is_connected.return_value = True
|
||||||
|
|
||||||
|
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
|
||||||
|
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||||
|
return_value=1 # 1 child review cancelled
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock execution_db to return different status based on which execution is queried
|
||||||
|
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
|
||||||
|
|
||||||
|
# Track call count to simulate status transition
|
||||||
|
call_count = {"count": 0}
|
||||||
|
|
||||||
|
async def get_exec_meta_side_effect(execution_id, user_id):
|
||||||
|
call_count["count"] += 1
|
||||||
|
if execution_id == parent_exec_id:
|
||||||
|
# After a few calls (child processing happens), transition parent to TERMINATED
|
||||||
|
# This simulates the executor service processing the stop request
|
||||||
|
if call_count["count"] > 3:
|
||||||
|
mock_parent_exec.status = ExecutionStatus.TERMINATED
|
||||||
|
return mock_parent_exec
|
||||||
|
elif execution_id == child_exec_id:
|
||||||
|
return mock_child_exec
|
||||||
|
return None
|
||||||
|
|
||||||
|
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
|
||||||
|
side_effect=get_exec_meta_side_effect
|
||||||
|
)
|
||||||
|
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
|
||||||
|
|
||||||
|
mock_get_event_bus = mocker.patch(
|
||||||
|
"backend.executor.utils.get_async_execution_event_bus"
|
||||||
|
)
|
||||||
|
mock_event_bus = mocker.MagicMock()
|
||||||
|
mock_event_bus.publish = mocker.AsyncMock()
|
||||||
|
mock_get_event_bus.return_value = mock_event_bus
|
||||||
|
|
||||||
|
# Mock _get_child_executions to return the child
|
||||||
|
mock_get_child_executions = mocker.patch(
|
||||||
|
"backend.executor.utils._get_child_executions"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_children_side_effect(parent_id):
|
||||||
|
if parent_id == parent_exec_id:
|
||||||
|
return [mock_child_exec]
|
||||||
|
return []
|
||||||
|
|
||||||
|
mock_get_child_executions.side_effect = get_children_side_effect
|
||||||
|
|
||||||
|
# Call stop_graph_execution on parent with cascade=True
|
||||||
|
await stop_graph_execution(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=parent_exec_id,
|
||||||
|
wait_timeout=1.0,
|
||||||
|
cascade=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify child reviews were cancelled
|
||||||
|
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||||
|
child_exec_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify both parent and child status updates
|
||||||
|
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import aiohttp
|
|||||||
from gcloud.aio import storage as async_gcs_storage
|
from gcloud.aio import storage as async_gcs_storage
|
||||||
from google.cloud import storage as gcs_storage
|
from google.cloud import storage as gcs_storage
|
||||||
|
|
||||||
|
from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -251,7 +252,7 @@ class CloudStorageHandler:
|
|||||||
f"in_task: {current_task is not None}"
|
f"in_task: {current_task is not None}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse bucket and blob name from path
|
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
|
||||||
parts = path.split("/", 1)
|
parts = path.split("/", 1)
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
raise ValueError(f"Invalid GCS path: {path}")
|
raise ValueError(f"Invalid GCS path: {path}")
|
||||||
@@ -261,50 +262,19 @@ class CloudStorageHandler:
|
|||||||
# Authorization check
|
# Authorization check
|
||||||
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
||||||
|
|
||||||
# Use a fresh client for each download to avoid session issues
|
|
||||||
# This is less efficient but more reliable with the executor's event loop
|
|
||||||
logger.info("[CloudStorage] Creating fresh GCS client for download")
|
|
||||||
|
|
||||||
# Create a new session specifically for this download
|
|
||||||
session = aiohttp.ClientSession(
|
|
||||||
connector=aiohttp.TCPConnector(limit=10, force_close=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
async_client = None
|
|
||||||
try:
|
|
||||||
# Create a new GCS client with the fresh session
|
|
||||||
async_client = async_gcs_storage.Storage(session=session)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
|
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Download content using the fresh client
|
try:
|
||||||
content = await async_client.download(bucket_name, blob_name)
|
content = await download_with_fresh_session(bucket_name, blob_name)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
|
f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clean up
|
|
||||||
await async_client.close()
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Always try to clean up
|
|
||||||
if async_client is not None:
|
|
||||||
try:
|
|
||||||
await async_client.close()
|
|
||||||
except Exception as cleanup_error:
|
|
||||||
logger.warning(
|
|
||||||
f"[CloudStorage] Error closing GCS client: {cleanup_error}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await session.close()
|
|
||||||
except Exception as cleanup_error:
|
|
||||||
logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}")
|
|
||||||
|
|
||||||
# Log the specific error for debugging
|
# Log the specific error for debugging
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[CloudStorage] GCS download failed - error: {str(e)}, "
|
f"[CloudStorage] GCS download failed - error: {str(e)}, "
|
||||||
@@ -319,10 +289,6 @@ class CloudStorageHandler:
|
|||||||
f"current_task: {current_task}, "
|
f"current_task: {current_task}, "
|
||||||
f"bucket: {bucket_name}, blob: redacted for privacy"
|
f"bucket: {bucket_name}, blob: redacted for privacy"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert gcloud-aio exceptions to standard ones
|
|
||||||
if "404" in str(e) or "Not Found" in str(e):
|
|
||||||
raise FileNotFoundError(f"File not found: gcs://{path}")
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _validate_file_access(
|
def _validate_file_access(
|
||||||
@@ -445,8 +411,7 @@ class CloudStorageHandler:
|
|||||||
graph_exec_id: str | None = None,
|
graph_exec_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate signed URL for GCS with authorization."""
|
"""Generate signed URL for GCS with authorization."""
|
||||||
|
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
|
||||||
# Parse bucket and blob name from path
|
|
||||||
parts = path.split("/", 1)
|
parts = path.split("/", 1)
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
raise ValueError(f"Invalid GCS path: {path}")
|
raise ValueError(f"Invalid GCS path: {path}")
|
||||||
@@ -456,21 +421,11 @@ class CloudStorageHandler:
|
|||||||
# Authorization check
|
# Authorization check
|
||||||
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
||||||
|
|
||||||
# Use sync client for signed URLs since gcloud-aio doesn't support them
|
|
||||||
sync_client = self._get_sync_gcs_client()
|
sync_client = self._get_sync_gcs_client()
|
||||||
bucket = sync_client.bucket(bucket_name)
|
return await generate_signed_url(
|
||||||
blob = bucket.blob(blob_name)
|
sync_client, bucket_name, blob_name, expiration_hours * 3600
|
||||||
|
|
||||||
# Generate signed URL asynchronously using sync client
|
|
||||||
url = await asyncio.to_thread(
|
|
||||||
blob.generate_signed_url,
|
|
||||||
version="v4",
|
|
||||||
expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours),
|
|
||||||
method="GET",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return url
|
|
||||||
|
|
||||||
async def delete_expired_files(self, provider: str = "gcs") -> int:
|
async def delete_expired_files(self, provider: str = "gcs") -> int:
|
||||||
"""
|
"""
|
||||||
Delete files that have passed their expiration time.
|
Delete files that have passed their expiration time.
|
||||||
|
|||||||
@@ -5,13 +5,28 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Literal
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from prisma.enums import WorkspaceFileSource
|
||||||
|
|
||||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
from backend.util.settings import Config
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
from backend.util.virus_scanner import scan_content_safe
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
|
||||||
|
# Return format options for store_media_file
|
||||||
|
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||||
|
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
||||||
|
# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
|
||||||
|
MediaReturnFormat = Literal[
|
||||||
|
"for_local_processing", "for_external_api", "for_block_output"
|
||||||
|
]
|
||||||
|
|
||||||
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
|
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
|
||||||
|
|
||||||
# Maximum filename length (conservative limit for most filesystems)
|
# Maximum filename length (conservative limit for most filesystems)
|
||||||
@@ -67,42 +82,56 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def store_media_file(
|
async def store_media_file(
|
||||||
graph_exec_id: str,
|
|
||||||
file: MediaFileType,
|
file: MediaFileType,
|
||||||
user_id: str,
|
execution_context: "ExecutionContext",
|
||||||
return_content: bool = False,
|
*,
|
||||||
|
return_format: MediaReturnFormat,
|
||||||
) -> MediaFileType:
|
) -> MediaFileType:
|
||||||
"""
|
"""
|
||||||
Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
|
Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path
|
||||||
placing or verifying it under:
|
relative to {temp}/exec_file/{exec_id}), placing or verifying it under:
|
||||||
{tempdir}/exec_file/{exec_id}/...
|
{tempdir}/exec_file/{exec_id}/...
|
||||||
|
|
||||||
If 'return_content=True', return a data URI (data:<mime>;base64,<content>).
|
For each MediaFileType input:
|
||||||
Otherwise, returns the file media path relative to the exec_id folder.
|
- Data URI: decode and store locally
|
||||||
|
- URL: download and store locally
|
||||||
|
- workspace:// reference: read from workspace, store locally
|
||||||
|
- Local path: verify it exists in exec_file directory
|
||||||
|
|
||||||
For each MediaFileType type:
|
Return format options:
|
||||||
- Data URI:
|
- "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||||
-> decode and store in a new random file in that folder
|
- "for_external_api": Returns data URI (base64) - use when sending to external APIs
|
||||||
- URL:
|
- "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
|
||||||
-> download and store in that folder
|
|
||||||
- Local path:
|
|
||||||
-> interpret as relative to that folder; verify it exists
|
|
||||||
(no copying, as it's presumably already there).
|
|
||||||
We realpath-check so no symlink or '..' can escape the folder.
|
|
||||||
|
|
||||||
|
:param file: Data URI, URL, workspace://, or local (relative) path.
|
||||||
:param graph_exec_id: The unique ID of the graph execution.
|
:param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id.
|
||||||
:param file: Data URI, URL, or local (relative) path.
|
:param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output".
|
||||||
:param return_content: If True, return a data URI of the file content.
|
:return: The requested result based on return_format.
|
||||||
If False, return the *relative* path inside the exec_id folder.
|
|
||||||
:return: The requested result: data URI or relative path of the media.
|
|
||||||
"""
|
"""
|
||||||
|
# Extract values from execution_context
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
user_id = execution_context.user_id
|
||||||
|
|
||||||
|
if not graph_exec_id:
|
||||||
|
raise ValueError("execution_context.graph_exec_id is required")
|
||||||
|
if not user_id:
|
||||||
|
raise ValueError("execution_context.user_id is required")
|
||||||
|
|
||||||
|
# Create workspace_manager if we have workspace_id (with session scoping)
|
||||||
|
# Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py)
|
||||||
|
from backend.util.workspace import WorkspaceManager
|
||||||
|
|
||||||
|
workspace_manager: WorkspaceManager | None = None
|
||||||
|
if execution_context.workspace_id:
|
||||||
|
workspace_manager = WorkspaceManager(
|
||||||
|
user_id, execution_context.workspace_id, execution_context.session_id
|
||||||
|
)
|
||||||
# Build base path
|
# Build base path
|
||||||
base_path = Path(get_exec_file_path(graph_exec_id, ""))
|
base_path = Path(get_exec_file_path(graph_exec_id, ""))
|
||||||
base_path.mkdir(parents=True, exist_ok=True)
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Security fix: Add disk space limits to prevent DoS
|
# Security fix: Add disk space limits to prevent DoS
|
||||||
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB per file
|
MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024
|
||||||
MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory
|
MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory
|
||||||
|
|
||||||
# Check total disk usage in base_path
|
# Check total disk usage in base_path
|
||||||
@@ -142,9 +171,57 @@ async def store_media_file(
|
|||||||
"""
|
"""
|
||||||
return str(absolute_path.relative_to(base))
|
return str(absolute_path.relative_to(base))
|
||||||
|
|
||||||
# Check if this is a cloud storage path
|
# Get cloud storage handler for checking cloud paths
|
||||||
cloud_storage = await get_cloud_storage_handler()
|
cloud_storage = await get_cloud_storage_handler()
|
||||||
if cloud_storage.is_cloud_path(file):
|
|
||||||
|
# Track if the input came from workspace (don't re-save it)
|
||||||
|
is_from_workspace = file.startswith("workspace://")
|
||||||
|
|
||||||
|
# Check if this is a workspace file reference
|
||||||
|
if is_from_workspace:
|
||||||
|
if workspace_manager is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Workspace file reference requires workspace context. "
|
||||||
|
"This file type is only available in CoPilot sessions."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse workspace reference
|
||||||
|
# workspace://abc123 - by file ID
|
||||||
|
# workspace:///path/to/file.txt - by virtual path
|
||||||
|
file_ref = file[12:] # Remove "workspace://"
|
||||||
|
|
||||||
|
if file_ref.startswith("/"):
|
||||||
|
# Path reference
|
||||||
|
workspace_content = await workspace_manager.read_file(file_ref)
|
||||||
|
file_info = await workspace_manager.get_file_info_by_path(file_ref)
|
||||||
|
filename = sanitize_filename(
|
||||||
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# ID reference
|
||||||
|
workspace_content = await workspace_manager.read_file_by_id(file_ref)
|
||||||
|
file_info = await workspace_manager.get_file_info(file_ref)
|
||||||
|
filename = sanitize_filename(
|
||||||
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||||
|
except OSError as e:
|
||||||
|
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||||
|
|
||||||
|
# Check file size limit
|
||||||
|
if len(workspace_content) > MAX_FILE_SIZE_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Virus scan the workspace content before writing locally
|
||||||
|
await scan_content_safe(workspace_content, filename=filename)
|
||||||
|
target_path.write_bytes(workspace_content)
|
||||||
|
|
||||||
|
# Check if this is a cloud storage path
|
||||||
|
elif cloud_storage.is_cloud_path(file):
|
||||||
# Download from cloud storage and store locally
|
# Download from cloud storage and store locally
|
||||||
cloud_content = await cloud_storage.retrieve_file(
|
cloud_content = await cloud_storage.retrieve_file(
|
||||||
file, user_id=user_id, graph_exec_id=graph_exec_id
|
file, user_id=user_id, graph_exec_id=graph_exec_id
|
||||||
@@ -159,9 +236,9 @@ async def store_media_file(
|
|||||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||||
|
|
||||||
# Check file size limit
|
# Check file size limit
|
||||||
if len(cloud_content) > MAX_FILE_SIZE:
|
if len(cloud_content) > MAX_FILE_SIZE_BYTES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
|
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Virus scan the cloud content before writing locally
|
# Virus scan the cloud content before writing locally
|
||||||
@@ -189,9 +266,9 @@ async def store_media_file(
|
|||||||
content = base64.b64decode(b64_content)
|
content = base64.b64decode(b64_content)
|
||||||
|
|
||||||
# Check file size limit
|
# Check file size limit
|
||||||
if len(content) > MAX_FILE_SIZE:
|
if len(content) > MAX_FILE_SIZE_BYTES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
|
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Virus scan the base64 content before writing
|
# Virus scan the base64 content before writing
|
||||||
@@ -199,23 +276,31 @@ async def store_media_file(
|
|||||||
target_path.write_bytes(content)
|
target_path.write_bytes(content)
|
||||||
|
|
||||||
elif file.startswith(("http://", "https://")):
|
elif file.startswith(("http://", "https://")):
|
||||||
# URL
|
# URL - download first to get Content-Type header
|
||||||
|
resp = await Requests().get(file)
|
||||||
|
|
||||||
|
# Check file size limit
|
||||||
|
if len(resp.content) > MAX_FILE_SIZE_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract filename from URL path
|
||||||
parsed_url = urlparse(file)
|
parsed_url = urlparse(file)
|
||||||
filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
|
filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
|
||||||
|
|
||||||
|
# If filename lacks extension, add one from Content-Type header
|
||||||
|
if "." not in filename:
|
||||||
|
content_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
|
||||||
|
if content_type:
|
||||||
|
ext = _extension_from_mime(content_type)
|
||||||
|
filename = f"{filename}{ext}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||||
|
|
||||||
# Download and save
|
|
||||||
resp = await Requests().get(file)
|
|
||||||
|
|
||||||
# Check file size limit
|
|
||||||
if len(resp.content) > MAX_FILE_SIZE:
|
|
||||||
raise ValueError(
|
|
||||||
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Virus scan the downloaded content before writing
|
# Virus scan the downloaded content before writing
|
||||||
await scan_content_safe(resp.content, filename=filename)
|
await scan_content_safe(resp.content, filename=filename)
|
||||||
target_path.write_bytes(resp.content)
|
target_path.write_bytes(resp.content)
|
||||||
@@ -230,12 +315,46 @@ async def store_media_file(
|
|||||||
if not target_path.is_file():
|
if not target_path.is_file():
|
||||||
raise ValueError(f"Local file does not exist: {target_path}")
|
raise ValueError(f"Local file does not exist: {target_path}")
|
||||||
|
|
||||||
# Return result
|
# Return based on requested format
|
||||||
if return_content:
|
if return_format == "for_local_processing":
|
||||||
return MediaFileType(_file_to_data_uri(target_path))
|
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
||||||
else:
|
# Returns: relative path in exec_file directory (e.g., "image.png")
|
||||||
return MediaFileType(_strip_base_prefix(target_path, base_path))
|
return MediaFileType(_strip_base_prefix(target_path, base_path))
|
||||||
|
|
||||||
|
elif return_format == "for_external_api":
|
||||||
|
# Use when sending content to external APIs that need base64
|
||||||
|
# Returns: data URI (e.g., "...")
|
||||||
|
return MediaFileType(_file_to_data_uri(target_path))
|
||||||
|
|
||||||
|
elif return_format == "for_block_output":
|
||||||
|
# Use when returning output from a block to user/next block
|
||||||
|
# Returns: workspace:// ref (CoPilot) or data URI (graph execution)
|
||||||
|
if workspace_manager is None:
|
||||||
|
# No workspace available (graph execution without CoPilot)
|
||||||
|
# Fallback to data URI so the content can still be used/displayed
|
||||||
|
return MediaFileType(_file_to_data_uri(target_path))
|
||||||
|
|
||||||
|
# Don't re-save if input was already from workspace
|
||||||
|
if is_from_workspace:
|
||||||
|
# Return original workspace reference
|
||||||
|
return MediaFileType(file)
|
||||||
|
|
||||||
|
# Save new content to workspace
|
||||||
|
content = target_path.read_bytes()
|
||||||
|
filename = target_path.name
|
||||||
|
|
||||||
|
file_record = await workspace_manager.write_file(
|
||||||
|
content=content,
|
||||||
|
filename=filename,
|
||||||
|
source=WorkspaceFileSource.COPILOT,
|
||||||
|
source_session_id=execution_context.session_id,
|
||||||
|
overwrite=True,
|
||||||
|
)
|
||||||
|
return MediaFileType(f"workspace://{file_record.id}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid return_format: {return_format}")
|
||||||
|
|
||||||
|
|
||||||
def get_dir_size(path: Path) -> int:
|
def get_dir_size(path: Path) -> int:
|
||||||
"""Get total size of directory."""
|
"""Get total size of directory."""
|
||||||
|
|||||||
@@ -7,10 +7,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_context(
|
||||||
|
graph_exec_id: str = "test-exec-123",
|
||||||
|
user_id: str = "test-user-123",
|
||||||
|
) -> ExecutionContext:
|
||||||
|
"""Helper to create test ExecutionContext."""
|
||||||
|
return ExecutionContext(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestFileCloudIntegration:
|
class TestFileCloudIntegration:
|
||||||
"""Test cases for cloud storage integration in file utilities."""
|
"""Test cases for cloud storage integration in file utilities."""
|
||||||
|
|
||||||
@@ -70,10 +82,9 @@ class TestFileCloudIntegration:
|
|||||||
mock_path_class.side_effect = path_constructor
|
mock_path_class.side_effect = path_constructor
|
||||||
|
|
||||||
result = await store_media_file(
|
result = await store_media_file(
|
||||||
graph_exec_id,
|
file=MediaFileType(cloud_path),
|
||||||
MediaFileType(cloud_path),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
"test-user-123",
|
return_format="for_local_processing",
|
||||||
return_content=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify cloud storage operations
|
# Verify cloud storage operations
|
||||||
@@ -144,10 +155,9 @@ class TestFileCloudIntegration:
|
|||||||
mock_path_obj.name = "image.png"
|
mock_path_obj.name = "image.png"
|
||||||
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
||||||
result = await store_media_file(
|
result = await store_media_file(
|
||||||
graph_exec_id,
|
file=MediaFileType(cloud_path),
|
||||||
MediaFileType(cloud_path),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
"test-user-123",
|
return_format="for_external_api",
|
||||||
return_content=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify result is a data URI
|
# Verify result is a data URI
|
||||||
@@ -198,10 +208,9 @@ class TestFileCloudIntegration:
|
|||||||
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
||||||
|
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id,
|
file=MediaFileType(data_uri),
|
||||||
MediaFileType(data_uri),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
"test-user-123",
|
return_format="for_local_processing",
|
||||||
return_content=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify cloud handler was checked but not used for retrieval
|
# Verify cloud handler was checked but not used for retrieval
|
||||||
@@ -234,5 +243,7 @@ class TestFileCloudIntegration:
|
|||||||
FileNotFoundError, match="File not found in cloud storage"
|
FileNotFoundError, match="File not found in cloud storage"
|
||||||
):
|
):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
|
file=MediaFileType(cloud_path),
|
||||||
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|||||||
160
autogpt_platform/backend/backend/util/gcs_utils.py
Normal file
160
autogpt_platform/backend/backend/util/gcs_utils.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
"""
|
||||||
|
Shared GCS utilities for workspace and cloud storage backends.
|
||||||
|
|
||||||
|
This module provides common functionality for working with Google Cloud Storage,
|
||||||
|
including path parsing, client management, and signed URL generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from gcloud.aio import storage as async_gcs_storage
|
||||||
|
from google.cloud import storage as gcs_storage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_gcs_path(path: str) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: GCS path string (e.g., "gcs://my-bucket/path/to/file")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (bucket_name, blob_name)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the path format is invalid
|
||||||
|
"""
|
||||||
|
if not path.startswith("gcs://"):
|
||||||
|
raise ValueError(f"Invalid GCS path: {path}")
|
||||||
|
|
||||||
|
path_without_prefix = path[6:] # Remove "gcs://"
|
||||||
|
parts = path_without_prefix.split("/", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError(f"Invalid GCS path format: {path}")
|
||||||
|
|
||||||
|
return parts[0], parts[1]
|
||||||
|
|
||||||
|
|
||||||
|
class GCSClientManager:
|
||||||
|
"""
|
||||||
|
Manages async and sync GCS clients with lazy initialization.
|
||||||
|
|
||||||
|
This class provides a unified way to manage GCS client lifecycle,
|
||||||
|
supporting both async operations (uploads, downloads) and sync
|
||||||
|
operations that require service account credentials (signed URLs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._async_client: Optional[async_gcs_storage.Storage] = None
|
||||||
|
self._sync_client: Optional[gcs_storage.Client] = None
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
|
||||||
|
async def get_async_client(self) -> async_gcs_storage.Storage:
|
||||||
|
"""
|
||||||
|
Get or create async GCS client.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Async GCS storage client
|
||||||
|
"""
|
||||||
|
if self._async_client is None:
|
||||||
|
self._session = aiohttp.ClientSession(
|
||||||
|
connector=aiohttp.TCPConnector(limit=100, force_close=False)
|
||||||
|
)
|
||||||
|
self._async_client = async_gcs_storage.Storage(session=self._session)
|
||||||
|
return self._async_client
|
||||||
|
|
||||||
|
def get_sync_client(self) -> gcs_storage.Client:
|
||||||
|
"""
|
||||||
|
Get or create sync GCS client (used for signed URLs).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sync GCS storage client
|
||||||
|
"""
|
||||||
|
if self._sync_client is None:
|
||||||
|
self._sync_client = gcs_storage.Client()
|
||||||
|
return self._sync_client
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close all client connections."""
|
||||||
|
if self._async_client is not None:
|
||||||
|
try:
|
||||||
|
await self._async_client.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing GCS client: {e}")
|
||||||
|
self._async_client = None
|
||||||
|
|
||||||
|
if self._session is not None:
|
||||||
|
try:
|
||||||
|
await self._session.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing session: {e}")
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
|
||||||
|
async def download_with_fresh_session(bucket: str, blob: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Download file content using a fresh session.
|
||||||
|
|
||||||
|
This approach avoids event loop issues that can occur when reusing
|
||||||
|
sessions across different async contexts (e.g., in executors).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bucket: GCS bucket name
|
||||||
|
blob: Blob path within the bucket
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the file doesn't exist
|
||||||
|
"""
|
||||||
|
session = aiohttp.ClientSession(
|
||||||
|
connector=aiohttp.TCPConnector(limit=10, force_close=True)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
client = async_gcs_storage.Storage(session=session)
|
||||||
|
content = await client.download(bucket, blob)
|
||||||
|
await client.close()
|
||||||
|
return content
|
||||||
|
except Exception as e:
|
||||||
|
if "404" in str(e) or "Not Found" in str(e):
|
||||||
|
raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_signed_url(
|
||||||
|
sync_client: gcs_storage.Client,
|
||||||
|
bucket_name: str,
|
||||||
|
blob_name: str,
|
||||||
|
expires_in: int,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate a signed URL for temporary access to a GCS file.
|
||||||
|
|
||||||
|
Uses asyncio.to_thread() to run the sync operation without blocking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sync_client: Sync GCS client with service account credentials
|
||||||
|
bucket_name: GCS bucket name
|
||||||
|
blob_name: Blob path within the bucket
|
||||||
|
expires_in: URL expiration time in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Signed URL string
|
||||||
|
"""
|
||||||
|
bucket = sync_client.bucket(bucket_name)
|
||||||
|
blob = bucket.blob(blob_name)
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
blob.generate_signed_url,
|
||||||
|
version="v4",
|
||||||
|
expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in),
|
||||||
|
method="GET",
|
||||||
|
)
|
||||||
@@ -263,6 +263,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
description="The name of the Google Cloud Storage bucket for media files",
|
description="The name of the Google Cloud Storage bucket for media files",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
workspace_storage_dir: str = Field(
|
||||||
|
default="",
|
||||||
|
description="Local directory for workspace file storage when GCS is not configured. "
|
||||||
|
"If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.",
|
||||||
|
)
|
||||||
|
|
||||||
reddit_user_agent: str = Field(
|
reddit_user_agent: str = Field(
|
||||||
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
|
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
|
||||||
description="The user agent for the Reddit API",
|
description="The user agent for the Reddit API",
|
||||||
@@ -350,6 +356,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
description="Whether to mark failed scans as clean or not",
|
description="Whether to mark failed scans as clean or not",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agentgenerator_host: str = Field(
|
||||||
|
default="",
|
||||||
|
description="The host for the Agent Generator service (empty to use built-in)",
|
||||||
|
)
|
||||||
|
agentgenerator_port: int = Field(
|
||||||
|
default=8000,
|
||||||
|
description="The port for the Agent Generator service",
|
||||||
|
)
|
||||||
|
agentgenerator_timeout: int = Field(
|
||||||
|
default=600,
|
||||||
|
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
|
||||||
|
)
|
||||||
|
|
||||||
enable_example_blocks: bool = Field(
|
enable_example_blocks: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to enable example blocks in production",
|
description="Whether to enable example blocks in production",
|
||||||
@@ -376,6 +395,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
description="Maximum file size in MB for file uploads (1-1024 MB)",
|
description="Maximum file size in MB for file uploads (1-1024 MB)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
max_file_size_mb: int = Field(
|
||||||
|
default=100,
|
||||||
|
ge=1,
|
||||||
|
le=1024,
|
||||||
|
description="Maximum file size in MB for workspace files (1-1024 MB)",
|
||||||
|
)
|
||||||
|
|
||||||
# AutoMod configuration
|
# AutoMod configuration
|
||||||
automod_enabled: bool = Field(
|
automod_enabled: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
@@ -666,6 +692,12 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|||||||
default="https://cloud.langfuse.com", description="Langfuse host URL"
|
default="https://cloud.langfuse.com", description="Langfuse host URL"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# PostHog analytics
|
||||||
|
posthog_api_key: str = Field(default="", description="PostHog API key")
|
||||||
|
posthog_host: str = Field(
|
||||||
|
default="https://eu.i.posthog.com", description="PostHog host URL"
|
||||||
|
)
|
||||||
|
|
||||||
# Add more secret fields as needed
|
# Add more secret fields as needed
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env",
|
env_file=".env",
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@@ -58,6 +59,11 @@ class SpinTestServer:
|
|||||||
self.db_api.__exit__(exc_type, exc_val, exc_tb)
|
self.db_api.__exit__(exc_type, exc_val, exc_tb)
|
||||||
self.notif_manager.__exit__(exc_type, exc_val, exc_tb)
|
self.notif_manager.__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
|
# Give services time to fully shut down
|
||||||
|
# This prevents event loop issues where services haven't fully cleaned up
|
||||||
|
# before the next test starts
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
def setup_dependency_overrides(self):
|
def setup_dependency_overrides(self):
|
||||||
# Override get_user_id for testing
|
# Override get_user_id for testing
|
||||||
self.agent_server.set_test_dependency_overrides(
|
self.agent_server.set_test_dependency_overrides(
|
||||||
@@ -134,14 +140,29 @@ async def execute_block_test(block: Block):
|
|||||||
setattr(block, mock_name, mock_obj)
|
setattr(block, mock_name, mock_obj)
|
||||||
|
|
||||||
# Populate credentials argument(s)
|
# Populate credentials argument(s)
|
||||||
|
# Generate IDs for execution context
|
||||||
|
graph_id = str(uuid.uuid4())
|
||||||
|
node_id = str(uuid.uuid4())
|
||||||
|
graph_exec_id = str(uuid.uuid4())
|
||||||
|
node_exec_id = str(uuid.uuid4())
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
graph_version = 1 # Default version for tests
|
||||||
|
|
||||||
extra_exec_kwargs: dict = {
|
extra_exec_kwargs: dict = {
|
||||||
"graph_id": str(uuid.uuid4()),
|
"graph_id": graph_id,
|
||||||
"node_id": str(uuid.uuid4()),
|
"node_id": node_id,
|
||||||
"graph_exec_id": str(uuid.uuid4()),
|
"graph_exec_id": graph_exec_id,
|
||||||
"node_exec_id": str(uuid.uuid4()),
|
"node_exec_id": node_exec_id,
|
||||||
"user_id": str(uuid.uuid4()),
|
"user_id": user_id,
|
||||||
"graph_version": 1, # Default version for tests
|
"graph_version": graph_version,
|
||||||
"execution_context": ExecutionContext(),
|
"execution_context": ExecutionContext(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
graph_version=graph_version,
|
||||||
|
node_id=node_id,
|
||||||
|
node_exec_id=node_exec_id,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
input_model = cast(type[BlockSchema], block.input_schema)
|
input_model = cast(type[BlockSchema], block.input_schema)
|
||||||
|
|
||||||
|
|||||||
432
autogpt_platform/backend/backend/util/workspace.py
Normal file
432
autogpt_platform/backend/backend/util/workspace.py
Normal file
@@ -0,0 +1,432 @@
|
|||||||
|
"""
|
||||||
|
WorkspaceManager for managing user workspace file operations.
|
||||||
|
|
||||||
|
This module provides a high-level interface for workspace file operations,
|
||||||
|
combining the storage backend and database layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from prisma.enums import WorkspaceFileSource
|
||||||
|
from prisma.errors import UniqueViolationError
|
||||||
|
from prisma.models import UserWorkspaceFile
|
||||||
|
|
||||||
|
from backend.data.workspace import (
|
||||||
|
count_workspace_files,
|
||||||
|
create_workspace_file,
|
||||||
|
get_workspace_file,
|
||||||
|
get_workspace_file_by_path,
|
||||||
|
list_workspace_files,
|
||||||
|
soft_delete_workspace_file,
|
||||||
|
)
|
||||||
|
from backend.util.settings import Config
|
||||||
|
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceManager:
|
||||||
|
"""
|
||||||
|
Manages workspace file operations.
|
||||||
|
|
||||||
|
Combines storage backend operations with database record management.
|
||||||
|
Supports session-scoped file segmentation where files are stored in
|
||||||
|
session-specific virtual paths: /sessions/{session_id}/{filename}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, user_id: str, workspace_id: str, session_id: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize WorkspaceManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
session_id: Optional session ID for session-scoped file access
|
||||||
|
"""
|
||||||
|
self.user_id = user_id
|
||||||
|
self.workspace_id = workspace_id
|
||||||
|
self.session_id = session_id
|
||||||
|
# Session path prefix for file isolation
|
||||||
|
self.session_path = f"/sessions/{session_id}" if session_id else ""
|
||||||
|
|
||||||
|
def _resolve_path(self, path: str) -> str:
|
||||||
|
"""
|
||||||
|
Resolve a path, defaulting to session folder if session_id is set.
|
||||||
|
|
||||||
|
Cross-session access is allowed by explicitly using /sessions/other-session-id/...
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved path with session prefix if applicable
|
||||||
|
"""
|
||||||
|
# If path explicitly references a session folder, use it as-is
|
||||||
|
if path.startswith("/sessions/"):
|
||||||
|
return path
|
||||||
|
|
||||||
|
# If we have a session context, prepend session path
|
||||||
|
if self.session_path:
|
||||||
|
# Normalize the path
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
return f"{self.session_path}{path}"
|
||||||
|
|
||||||
|
# No session context, use path as-is
|
||||||
|
return path if path.startswith("/") else f"/{path}"
|
||||||
|
|
||||||
|
def _get_effective_path(
|
||||||
|
self, path: Optional[str], include_all_sessions: bool
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get effective path for list/count operations based on session context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Optional path prefix to filter
|
||||||
|
include_all_sessions: If True, don't apply session scoping
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Effective path prefix for database query
|
||||||
|
"""
|
||||||
|
if include_all_sessions:
|
||||||
|
# Normalize path to ensure leading slash (stored paths are normalized)
|
||||||
|
if path is not None and not path.startswith("/"):
|
||||||
|
return f"/{path}"
|
||||||
|
return path
|
||||||
|
elif path is not None:
|
||||||
|
# Resolve the provided path with session scoping
|
||||||
|
return self._resolve_path(path)
|
||||||
|
elif self.session_path:
|
||||||
|
# Default to session folder with trailing slash to prevent prefix collisions
|
||||||
|
# e.g., "/sessions/abc" should not match "/sessions/abc123"
|
||||||
|
return self.session_path.rstrip("/") + "/"
|
||||||
|
else:
|
||||||
|
# No session context, use path as-is
|
||||||
|
return path
|
||||||
|
|
||||||
|
async def read_file(self, path: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Read file from workspace by virtual path.
|
||||||
|
|
||||||
|
When session_id is set, paths are resolved relative to the session folder
|
||||||
|
unless they explicitly reference /sessions/...
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Virtual path (e.g., "/documents/report.pdf")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If file doesn't exist
|
||||||
|
"""
|
||||||
|
resolved_path = self._resolve_path(path)
|
||||||
|
file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||||
|
if file is None:
|
||||||
|
raise FileNotFoundError(f"File not found at path: {resolved_path}")
|
||||||
|
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
return await storage.retrieve(file.storagePath)
|
||||||
|
|
||||||
|
async def read_file_by_id(self, file_id: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Read file from workspace by file ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If file doesn't exist
|
||||||
|
"""
|
||||||
|
file = await get_workspace_file(file_id, self.workspace_id)
|
||||||
|
if file is None:
|
||||||
|
raise FileNotFoundError(f"File not found: {file_id}")
|
||||||
|
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
return await storage.retrieve(file.storagePath)
|
||||||
|
|
||||||
|
async def write_file(
|
||||||
|
self,
|
||||||
|
content: bytes,
|
||||||
|
filename: str,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
mime_type: Optional[str] = None,
|
||||||
|
source: WorkspaceFileSource = WorkspaceFileSource.UPLOAD,
|
||||||
|
source_exec_id: Optional[str] = None,
|
||||||
|
source_session_id: Optional[str] = None,
|
||||||
|
overwrite: bool = False,
|
||||||
|
) -> UserWorkspaceFile:
|
||||||
|
"""
|
||||||
|
Write file to workspace.
|
||||||
|
|
||||||
|
When session_id is set, files are written to /sessions/{session_id}/...
|
||||||
|
by default. Use explicit /sessions/... paths for cross-session access.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: File content as bytes
|
||||||
|
filename: Filename for the file
|
||||||
|
path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
|
||||||
|
mime_type: MIME type (auto-detected if not provided)
|
||||||
|
source: How the file was created
|
||||||
|
source_exec_id: Graph execution ID if from execution
|
||||||
|
source_session_id: Chat session ID if from CoPilot
|
||||||
|
overwrite: Whether to overwrite existing file at path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created UserWorkspaceFile instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If file exceeds size limit or path already exists
|
||||||
|
"""
|
||||||
|
# Enforce file size limit
|
||||||
|
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||||
|
if len(content) > max_file_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large: {len(content)} bytes exceeds "
|
||||||
|
f"{Config().max_file_size_mb}MB limit"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine path with session scoping
|
||||||
|
if path is None:
|
||||||
|
path = f"/{filename}"
|
||||||
|
elif not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
|
||||||
|
# Resolve path with session prefix
|
||||||
|
path = self._resolve_path(path)
|
||||||
|
|
||||||
|
# Check if file exists at path
|
||||||
|
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
||||||
|
if existing is not None:
|
||||||
|
if overwrite:
|
||||||
|
# Delete existing file first
|
||||||
|
await self.delete_file(existing.id)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"File already exists at path: {path}")
|
||||||
|
|
||||||
|
# Auto-detect MIME type if not provided
|
||||||
|
if mime_type is None:
|
||||||
|
mime_type, _ = mimetypes.guess_type(filename)
|
||||||
|
mime_type = mime_type or "application/octet-stream"
|
||||||
|
|
||||||
|
# Compute checksum
|
||||||
|
checksum = compute_file_checksum(content)
|
||||||
|
|
||||||
|
# Generate unique file ID for storage
|
||||||
|
file_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Store file in storage backend
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
storage_path = await storage.store(
|
||||||
|
workspace_id=self.workspace_id,
|
||||||
|
file_id=file_id,
|
||||||
|
filename=filename,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create database record - handle race condition where another request
|
||||||
|
# created a file at the same path between our check and create
|
||||||
|
try:
|
||||||
|
file = await create_workspace_file(
|
||||||
|
workspace_id=self.workspace_id,
|
||||||
|
file_id=file_id,
|
||||||
|
name=filename,
|
||||||
|
path=path,
|
||||||
|
storage_path=storage_path,
|
||||||
|
mime_type=mime_type,
|
||||||
|
size_bytes=len(content),
|
||||||
|
checksum=checksum,
|
||||||
|
source=source,
|
||||||
|
source_exec_id=source_exec_id,
|
||||||
|
source_session_id=source_session_id,
|
||||||
|
)
|
||||||
|
except UniqueViolationError:
|
||||||
|
# Race condition: another request created a file at this path
|
||||||
|
if overwrite:
|
||||||
|
# Re-fetch and delete the conflicting file, then retry
|
||||||
|
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
||||||
|
if existing:
|
||||||
|
await self.delete_file(existing.id)
|
||||||
|
# Retry the create - if this also fails, clean up storage file
|
||||||
|
try:
|
||||||
|
file = await create_workspace_file(
|
||||||
|
workspace_id=self.workspace_id,
|
||||||
|
file_id=file_id,
|
||||||
|
name=filename,
|
||||||
|
path=path,
|
||||||
|
storage_path=storage_path,
|
||||||
|
mime_type=mime_type,
|
||||||
|
size_bytes=len(content),
|
||||||
|
checksum=checksum,
|
||||||
|
source=source,
|
||||||
|
source_exec_id=source_exec_id,
|
||||||
|
source_session_id=source_session_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Clean up orphaned storage file on retry failure
|
||||||
|
try:
|
||||||
|
await storage.delete(storage_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# Clean up the orphaned storage file before raising
|
||||||
|
try:
|
||||||
|
await storage.delete(storage_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||||
|
raise ValueError(f"File already exists at path: {path}")
|
||||||
|
except Exception:
|
||||||
|
# Any other database error (connection, validation, etc.) - clean up storage
|
||||||
|
try:
|
||||||
|
await storage.delete(storage_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} "
|
||||||
|
f"at path {path}, size={len(content)} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
return file
|
||||||
|
|
||||||
|
async def list_files(
|
||||||
|
self,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
include_all_sessions: bool = False,
|
||||||
|
) -> list[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
List files in workspace.
|
||||||
|
|
||||||
|
When session_id is set and include_all_sessions is False (default),
|
||||||
|
only files in the current session's folder are listed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Optional path prefix to filter (e.g., "/documents/")
|
||||||
|
limit: Maximum number of files to return
|
||||||
|
offset: Number of files to skip
|
||||||
|
include_all_sessions: If True, list files from all sessions.
|
||||||
|
If False (default), only list current session's files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of UserWorkspaceFile instances
|
||||||
|
"""
|
||||||
|
effective_path = self._get_effective_path(path, include_all_sessions)
|
||||||
|
|
||||||
|
return await list_workspace_files(
|
||||||
|
workspace_id=self.workspace_id,
|
||||||
|
path_prefix=effective_path,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_file(self, file_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a file (soft-delete).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found
|
||||||
|
"""
|
||||||
|
file = await get_workspace_file(file_id, self.workspace_id)
|
||||||
|
if file is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Delete from storage
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
try:
|
||||||
|
await storage.delete(file.storagePath)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete file from storage: {e}")
|
||||||
|
# Continue with database soft-delete even if storage delete fails
|
||||||
|
|
||||||
|
# Soft-delete database record
|
||||||
|
result = await soft_delete_workspace_file(file_id, self.workspace_id)
|
||||||
|
return result is not None
|
||||||
|
|
||||||
|
async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
Get download URL for a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file's ID
|
||||||
|
expires_in: URL expiration in seconds (default 1 hour)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Download URL (signed URL for GCS, API endpoint for local)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If file doesn't exist
|
||||||
|
"""
|
||||||
|
file = await get_workspace_file(file_id, self.workspace_id)
|
||||||
|
if file is None:
|
||||||
|
raise FileNotFoundError(f"File not found: {file_id}")
|
||||||
|
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
return await storage.get_download_url(file.storagePath, expires_in)
|
||||||
|
|
||||||
|
async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get file metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
return await get_workspace_file(file_id, self.workspace_id)
|
||||||
|
|
||||||
|
async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get file metadata by path.
|
||||||
|
|
||||||
|
When session_id is set, paths are resolved relative to the session folder
|
||||||
|
unless they explicitly reference /sessions/...
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Virtual path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
resolved_path = self._resolve_path(path)
|
||||||
|
return await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||||
|
|
||||||
|
async def get_file_count(
|
||||||
|
self,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
include_all_sessions: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Get number of files in workspace.
|
||||||
|
|
||||||
|
When session_id is set and include_all_sessions is False (default),
|
||||||
|
only counts files in the current session's folder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Optional path prefix to filter (e.g., "/documents/")
|
||||||
|
include_all_sessions: If True, count all files in workspace.
|
||||||
|
If False (default), only count current session's files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files
|
||||||
|
"""
|
||||||
|
effective_path = self._get_effective_path(path, include_all_sessions)
|
||||||
|
|
||||||
|
return await count_workspace_files(
|
||||||
|
self.workspace_id, path_prefix=effective_path
|
||||||
|
)
|
||||||
398
autogpt_platform/backend/backend/util/workspace_storage.py
Normal file
398
autogpt_platform/backend/backend/util/workspace_storage.py
Normal file
@@ -0,0 +1,398 @@
|
|||||||
|
"""
|
||||||
|
Workspace storage backend abstraction for supporting both cloud and local deployments.
|
||||||
|
|
||||||
|
This module provides a unified interface for storing workspace files, with implementations
|
||||||
|
for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import aiohttp
|
||||||
|
from gcloud.aio import storage as async_gcs_storage
|
||||||
|
from google.cloud import storage as gcs_storage
|
||||||
|
|
||||||
|
from backend.util.data import get_data_path
|
||||||
|
from backend.util.gcs_utils import (
|
||||||
|
download_with_fresh_session,
|
||||||
|
generate_signed_url,
|
||||||
|
parse_gcs_path,
|
||||||
|
)
|
||||||
|
from backend.util.settings import Config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceStorageBackend(ABC):
|
||||||
|
"""Abstract interface for workspace file storage."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def store(
|
||||||
|
self,
|
||||||
|
workspace_id: str,
|
||||||
|
file_id: str,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Store file content, return storage path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
file_id: Unique file ID for storage
|
||||||
|
filename: Original filename
|
||||||
|
content: File content as bytes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path string (cloud path or local path)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def retrieve(self, storage_path: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Retrieve file content from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: The storage path returned from store()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete(self, storage_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete file from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: The storage path to delete
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
Get URL for downloading the file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: The storage path
|
||||||
|
expires_in: URL expiration time in seconds (default 1 hour)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Download URL (signed URL for GCS, direct API path for local)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GCSWorkspaceStorage(WorkspaceStorageBackend):
|
||||||
|
"""Google Cloud Storage implementation for workspace storage."""
|
||||||
|
|
||||||
|
def __init__(self, bucket_name: str):
|
||||||
|
self.bucket_name = bucket_name
|
||||||
|
self._async_client: Optional[async_gcs_storage.Storage] = None
|
||||||
|
self._sync_client: Optional[gcs_storage.Client] = None
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
|
||||||
|
async def _get_async_client(self) -> async_gcs_storage.Storage:
|
||||||
|
"""Get or create async GCS client."""
|
||||||
|
if self._async_client is None:
|
||||||
|
self._session = aiohttp.ClientSession(
|
||||||
|
connector=aiohttp.TCPConnector(limit=100, force_close=False)
|
||||||
|
)
|
||||||
|
self._async_client = async_gcs_storage.Storage(session=self._session)
|
||||||
|
return self._async_client
|
||||||
|
|
||||||
|
def _get_sync_client(self) -> gcs_storage.Client:
|
||||||
|
"""Get or create sync GCS client (for signed URLs)."""
|
||||||
|
if self._sync_client is None:
|
||||||
|
self._sync_client = gcs_storage.Client()
|
||||||
|
return self._sync_client
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close all client connections."""
|
||||||
|
if self._async_client is not None:
|
||||||
|
try:
|
||||||
|
await self._async_client.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing GCS client: {e}")
|
||||||
|
self._async_client = None
|
||||||
|
|
||||||
|
if self._session is not None:
|
||||||
|
try:
|
||||||
|
await self._session.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing session: {e}")
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str:
|
||||||
|
"""Build the blob path for workspace files."""
|
||||||
|
return f"workspaces/{workspace_id}/{file_id}/{filename}"
|
||||||
|
|
||||||
|
async def store(
|
||||||
|
self,
|
||||||
|
workspace_id: str,
|
||||||
|
file_id: str,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""Store file in GCS."""
|
||||||
|
client = await self._get_async_client()
|
||||||
|
blob_name = self._build_blob_name(workspace_id, file_id, filename)
|
||||||
|
|
||||||
|
# Upload with metadata
|
||||||
|
upload_time = datetime.now(timezone.utc)
|
||||||
|
await client.upload(
|
||||||
|
self.bucket_name,
|
||||||
|
blob_name,
|
||||||
|
content,
|
||||||
|
metadata={
|
||||||
|
"uploaded_at": upload_time.isoformat(),
|
||||||
|
"workspace_id": workspace_id,
|
||||||
|
"file_id": file_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"gcs://{self.bucket_name}/{blob_name}"
|
||||||
|
|
||||||
|
async def retrieve(self, storage_path: str) -> bytes:
|
||||||
|
"""Retrieve file from GCS."""
|
||||||
|
bucket_name, blob_name = parse_gcs_path(storage_path)
|
||||||
|
return await download_with_fresh_session(bucket_name, blob_name)
|
||||||
|
|
||||||
|
async def delete(self, storage_path: str) -> None:
|
||||||
|
"""Delete file from GCS."""
|
||||||
|
bucket_name, blob_name = parse_gcs_path(storage_path)
|
||||||
|
client = await self._get_async_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.delete(bucket_name, blob_name)
|
||||||
|
except Exception as e:
|
||||||
|
if "404" not in str(e) and "Not Found" not in str(e):
|
||||||
|
raise
|
||||||
|
# File already deleted, that's fine
|
||||||
|
|
||||||
|
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
Generate download URL for GCS file.
|
||||||
|
|
||||||
|
Attempts to generate a signed URL if running with service account credentials.
|
||||||
|
Falls back to an API proxy endpoint if signed URL generation fails
|
||||||
|
(e.g., when running locally with user OAuth credentials).
|
||||||
|
"""
|
||||||
|
bucket_name, blob_name = parse_gcs_path(storage_path)
|
||||||
|
|
||||||
|
# Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename}
|
||||||
|
blob_parts = blob_name.split("/")
|
||||||
|
file_id = blob_parts[2] if len(blob_parts) >= 3 else None
|
||||||
|
|
||||||
|
# Try to generate signed URL (requires service account credentials)
|
||||||
|
try:
|
||||||
|
sync_client = self._get_sync_client()
|
||||||
|
return await generate_signed_url(
|
||||||
|
sync_client, bucket_name, blob_name, expires_in
|
||||||
|
)
|
||||||
|
except AttributeError as e:
|
||||||
|
# Signed URL generation requires service account with private key.
|
||||||
|
# When running with user OAuth credentials, fall back to API proxy.
|
||||||
|
if "private key" in str(e) and file_id:
|
||||||
|
logger.debug(
|
||||||
|
"Cannot generate signed URL (no service account credentials), "
|
||||||
|
"falling back to API proxy endpoint"
|
||||||
|
)
|
||||||
|
return f"/api/workspace/files/{file_id}/download"
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class LocalWorkspaceStorage(WorkspaceStorageBackend):
|
||||||
|
"""Local filesystem implementation for workspace storage (self-hosted deployments)."""
|
||||||
|
|
||||||
|
def __init__(self, base_dir: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize local storage backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_dir: Base directory for workspace storage.
|
||||||
|
If None, defaults to {app_data}/workspaces
|
||||||
|
"""
|
||||||
|
if base_dir:
|
||||||
|
self.base_dir = Path(base_dir)
|
||||||
|
else:
|
||||||
|
self.base_dir = Path(get_data_path()) / "workspaces"
|
||||||
|
|
||||||
|
# Ensure base directory exists
|
||||||
|
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path:
|
||||||
|
"""Build the local file path with path traversal protection."""
|
||||||
|
# Import here to avoid circular import
|
||||||
|
# (file.py imports workspace.py which imports workspace_storage.py)
|
||||||
|
from backend.util.file import sanitize_filename
|
||||||
|
|
||||||
|
# Sanitize filename to prevent path traversal (removes / and \ among others)
|
||||||
|
safe_filename = sanitize_filename(filename)
|
||||||
|
file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve()
|
||||||
|
|
||||||
|
# Verify the resolved path is still under base_dir
|
||||||
|
if not file_path.is_relative_to(self.base_dir.resolve()):
|
||||||
|
raise ValueError("Invalid filename: path traversal detected")
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
def _parse_storage_path(self, storage_path: str) -> Path:
|
||||||
|
"""Parse local storage path to filesystem path."""
|
||||||
|
if storage_path.startswith("local://"):
|
||||||
|
relative_path = storage_path[8:] # Remove "local://"
|
||||||
|
else:
|
||||||
|
relative_path = storage_path
|
||||||
|
|
||||||
|
full_path = (self.base_dir / relative_path).resolve()
|
||||||
|
|
||||||
|
# Security check: ensure path is under base_dir
|
||||||
|
# Use is_relative_to() for robust path containment check
|
||||||
|
# (handles case-insensitive filesystems and edge cases)
|
||||||
|
if not full_path.is_relative_to(self.base_dir.resolve()):
|
||||||
|
raise ValueError("Invalid storage path: path traversal detected")
|
||||||
|
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
async def store(
|
||||||
|
self,
|
||||||
|
workspace_id: str,
|
||||||
|
file_id: str,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""Store file locally."""
|
||||||
|
file_path = self._build_file_path(workspace_id, file_id, filename)
|
||||||
|
|
||||||
|
# Create parent directories
|
||||||
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Write file asynchronously
|
||||||
|
async with aiofiles.open(file_path, "wb") as f:
|
||||||
|
await f.write(content)
|
||||||
|
|
||||||
|
# Return relative path as storage path
|
||||||
|
relative_path = file_path.relative_to(self.base_dir)
|
||||||
|
return f"local://{relative_path}"
|
||||||
|
|
||||||
|
async def retrieve(self, storage_path: str) -> bytes:
|
||||||
|
"""Retrieve file from local storage."""
|
||||||
|
file_path = self._parse_storage_path(storage_path)
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {storage_path}")
|
||||||
|
|
||||||
|
async with aiofiles.open(file_path, "rb") as f:
|
||||||
|
return await f.read()
|
||||||
|
|
||||||
|
async def delete(self, storage_path: str) -> None:
|
||||||
|
"""Delete file from local storage."""
|
||||||
|
file_path = self._parse_storage_path(storage_path)
|
||||||
|
|
||||||
|
if file_path.exists():
|
||||||
|
# Remove file
|
||||||
|
file_path.unlink()
|
||||||
|
|
||||||
|
# Clean up empty parent directories
|
||||||
|
parent = file_path.parent
|
||||||
|
while parent != self.base_dir:
|
||||||
|
try:
|
||||||
|
if parent.exists() and not any(parent.iterdir()):
|
||||||
|
parent.rmdir()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
except OSError:
|
||||||
|
break
|
||||||
|
parent = parent.parent
|
||||||
|
|
||||||
|
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
Get download URL for local file.
|
||||||
|
|
||||||
|
For local storage, this returns an API endpoint path.
|
||||||
|
The actual serving is handled by the API layer.
|
||||||
|
"""
|
||||||
|
# Parse the storage path to get the components
|
||||||
|
if storage_path.startswith("local://"):
|
||||||
|
relative_path = storage_path[8:]
|
||||||
|
else:
|
||||||
|
relative_path = storage_path
|
||||||
|
|
||||||
|
# Return the API endpoint for downloading
|
||||||
|
# The file_id is extracted from the path: {workspace_id}/{file_id}/{filename}
|
||||||
|
parts = relative_path.split("/")
|
||||||
|
if len(parts) >= 2:
|
||||||
|
file_id = parts[1] # Second component is file_id
|
||||||
|
return f"/api/workspace/files/{file_id}/download"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid storage path format: {storage_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# Global storage backend instance
|
||||||
|
_workspace_storage: Optional[WorkspaceStorageBackend] = None
|
||||||
|
_storage_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_storage() -> WorkspaceStorageBackend:
|
||||||
|
"""
|
||||||
|
Get the workspace storage backend instance.
|
||||||
|
|
||||||
|
Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
|
||||||
|
"""
|
||||||
|
global _workspace_storage
|
||||||
|
|
||||||
|
if _workspace_storage is None:
|
||||||
|
async with _storage_lock:
|
||||||
|
if _workspace_storage is None:
|
||||||
|
config = Config()
|
||||||
|
|
||||||
|
if config.media_gcs_bucket_name:
|
||||||
|
logger.info(
|
||||||
|
f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
|
||||||
|
)
|
||||||
|
_workspace_storage = GCSWorkspaceStorage(
|
||||||
|
config.media_gcs_bucket_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
storage_dir = (
|
||||||
|
config.workspace_storage_dir
|
||||||
|
if config.workspace_storage_dir
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Using local workspace storage: {storage_dir or 'default'}"
|
||||||
|
)
|
||||||
|
_workspace_storage = LocalWorkspaceStorage(storage_dir)
|
||||||
|
|
||||||
|
return _workspace_storage
|
||||||
|
|
||||||
|
|
||||||
|
async def shutdown_workspace_storage() -> None:
|
||||||
|
"""
|
||||||
|
Properly shutdown the global workspace storage backend.
|
||||||
|
|
||||||
|
Closes aiohttp sessions and other resources for GCS backend.
|
||||||
|
Should be called during application shutdown.
|
||||||
|
"""
|
||||||
|
global _workspace_storage
|
||||||
|
|
||||||
|
if _workspace_storage is not None:
|
||||||
|
async with _storage_lock:
|
||||||
|
if _workspace_storage is not None:
|
||||||
|
if isinstance(_workspace_storage, GCSWorkspaceStorage):
|
||||||
|
await _workspace_storage.close()
|
||||||
|
_workspace_storage = None
|
||||||
|
|
||||||
|
|
||||||
|
def compute_file_checksum(content: bytes) -> str:
|
||||||
|
"""Compute SHA256 checksum of file content."""
|
||||||
|
return hashlib.sha256(content).hexdigest()
|
||||||
@@ -1,12 +1,37 @@
|
|||||||
-- CreateExtension
|
-- CreateExtension
|
||||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||||
-- Creates extension in current schema (determined by search_path from DATABASE_URL ?schema= param)
|
-- Ensures vector extension is in the current schema (from DATABASE_URL ?schema= param)
|
||||||
|
-- If it exists in a different schema (e.g., public), we drop and recreate it in the current schema
|
||||||
-- This ensures vector type is in the same schema as tables, making ::vector work without explicit qualification
|
-- This ensures vector type is in the same schema as tables, making ::vector work without explicit qualification
|
||||||
DO $$
|
DO $$
|
||||||
|
DECLARE
|
||||||
|
current_schema_name text;
|
||||||
|
vector_schema text;
|
||||||
BEGIN
|
BEGIN
|
||||||
CREATE EXTENSION IF NOT EXISTS "vector";
|
-- Get the current schema from search_path
|
||||||
EXCEPTION WHEN OTHERS THEN
|
SELECT current_schema() INTO current_schema_name;
|
||||||
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
|
||||||
|
-- Check if vector extension exists and which schema it's in
|
||||||
|
SELECT n.nspname INTO vector_schema
|
||||||
|
FROM pg_extension e
|
||||||
|
JOIN pg_namespace n ON e.extnamespace = n.oid
|
||||||
|
WHERE e.extname = 'vector';
|
||||||
|
|
||||||
|
-- Handle removal if in wrong schema
|
||||||
|
IF vector_schema IS NOT NULL AND vector_schema != current_schema_name THEN
|
||||||
|
BEGIN
|
||||||
|
-- Vector exists in a different schema, drop it first
|
||||||
|
RAISE WARNING 'pgvector found in schema "%" but need it in "%". Dropping and reinstalling...',
|
||||||
|
vector_schema, current_schema_name;
|
||||||
|
EXECUTE 'DROP EXTENSION IF EXISTS vector CASCADE';
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE EXCEPTION 'Failed to drop pgvector from schema "%": %. You may need to drop it manually.',
|
||||||
|
vector_schema, SQLERRM;
|
||||||
|
END;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Create extension in current schema (let it fail naturally if not available)
|
||||||
|
EXECUTE format('CREATE EXTENSION IF NOT EXISTS vector SCHEMA %I', current_schema_name);
|
||||||
END $$;
|
END $$;
|
||||||
|
|
||||||
-- CreateEnum
|
-- CreateEnum
|
||||||
|
|||||||
@@ -1,71 +0,0 @@
|
|||||||
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
|
||||||
-- These extensions are pre-installed by Supabase in specific schemas
|
|
||||||
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
|
||||||
|
|
||||||
-- Create schemas (safe in both CI and Supabase)
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "extensions";
|
|
||||||
|
|
||||||
-- Extensions that exist in both CI and Supabase
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
-- Supabase-specific extensions (skip gracefully in CI)
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pg_net extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pgjwt extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "graphql";
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'pgsodium extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "vault";
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
|
||||||
END $$;
|
|
||||||
|
|
||||||
|
|
||||||
-- Return to platform
|
|
||||||
CREATE SCHEMA IF NOT EXISTS "platform";
|
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
-- Remove NodeExecution foreign key from PendingHumanReview
|
||||||
|
-- The nodeExecId column remains as the primary key, but we remove the FK constraint
|
||||||
|
-- to AgentNodeExecution since PendingHumanReview records can persist after node
|
||||||
|
-- execution records are deleted.
|
||||||
|
|
||||||
|
-- Drop foreign key constraint that linked PendingHumanReview.nodeExecId to AgentNodeExecution.id
|
||||||
|
ALTER TABLE "PendingHumanReview" DROP CONSTRAINT IF EXISTS "PendingHumanReview_nodeExecId_fkey";
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
-- AlterEnum
|
||||||
|
ALTER TYPE "OnboardingStep" ADD VALUE 'VISIT_COPILOT';
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
-- CreateEnum
|
||||||
|
CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT');
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "UserWorkspace" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
"userId" TEXT NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "UserWorkspaceFile" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
"workspaceId" TEXT NOT NULL,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"path" TEXT NOT NULL,
|
||||||
|
"storagePath" TEXT NOT NULL,
|
||||||
|
"mimeType" TEXT NOT NULL,
|
||||||
|
"sizeBytes" BIGINT NOT NULL,
|
||||||
|
"checksum" TEXT,
|
||||||
|
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
"deletedAt" TIMESTAMP(3),
|
||||||
|
"source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD',
|
||||||
|
"sourceExecId" TEXT,
|
||||||
|
"sourceSessionId" TEXT,
|
||||||
|
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||||
|
|
||||||
|
CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path");
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
12
autogpt_platform/backend/poetry.lock
generated
12
autogpt_platform/backend/poetry.lock
generated
@@ -4204,14 +4204,14 @@ strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""}
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "posthog"
|
name = "posthog"
|
||||||
version = "6.1.1"
|
version = "7.6.0"
|
||||||
description = "Integrate PostHog into any python application."
|
description = "Integrate PostHog into any python application."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.10"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "posthog-6.1.1-py3-none-any.whl", hash = "sha256:329fd3d06b4d54cec925f47235bd8e327c91403c2f9ec38f1deb849535934dba"},
|
{file = "posthog-7.6.0-py3-none-any.whl", hash = "sha256:c4dd78cf77c4fecceb965f86066e5ac37886ef867d68ffe75a1db5d681d7d9ad"},
|
||||||
{file = "posthog-6.1.1.tar.gz", hash = "sha256:b453f54c4a2589da859fd575dd3bf86fcb40580727ec399535f268b1b9f318b8"},
|
{file = "posthog-7.6.0.tar.gz", hash = "sha256:941dfd278ee427c9b14640f09b35b5bb52a71bdf028d7dbb7307e1838fd3002e"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -4225,7 +4225,7 @@ typing-extensions = ">=4.2.0"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
dev = ["django-stubs", "lxml", "mypy", "mypy-baseline", "packaging", "pre-commit", "pydantic", "ruff", "setuptools", "tomli", "tomli_w", "twine", "types-mock", "types-python-dateutil", "types-requests", "types-setuptools", "types-six", "wheel"]
|
dev = ["django-stubs", "lxml", "mypy", "mypy-baseline", "packaging", "pre-commit", "pydantic", "ruff", "setuptools", "tomli", "tomli_w", "twine", "types-mock", "types-python-dateutil", "types-requests", "types-setuptools", "types-six", "wheel"]
|
||||||
langchain = ["langchain (>=0.2.0)"]
|
langchain = ["langchain (>=0.2.0)"]
|
||||||
test = ["anthropic", "coverage", "django", "freezegun (==1.5.1)", "google-genai", "langchain-anthropic (>=0.3.15)", "langchain-community (>=0.3.25)", "langchain-core (>=0.3.65)", "langchain-openai (>=0.3.22)", "langgraph (>=0.4.8)", "mock (>=2.0.0)", "openai", "parameterized (>=0.8.1)", "pydantic", "pytest", "pytest-asyncio", "pytest-timeout"]
|
test = ["anthropic (>=0.72)", "coverage", "django", "freezegun (==1.5.1)", "google-genai", "langchain-anthropic (>=1.0)", "langchain-community (>=0.4)", "langchain-core (>=1.0)", "langchain-openai (>=1.0)", "langgraph (>=1.0)", "mock (>=2.0.0)", "openai (>=2.0)", "parameterized (>=0.8.1)", "pydantic", "pytest", "pytest-asyncio", "pytest-timeout"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "postmarker"
|
name = "postmarker"
|
||||||
@@ -7512,4 +7512,4 @@ cffi = ["cffi (>=1.11)"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10,<3.14"
|
python-versions = ">=3.10,<3.14"
|
||||||
content-hash = "18b92e09596298c82432e4d0a85cb6d80a40b4229bee0a0c15f0529fd6cb21a4"
|
content-hash = "ee5742dc1a9df50dfc06d4b26a1682cbb2b25cab6b79ce5625ec272f93e4f4bf"
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ exa-py = "^1.14.20"
|
|||||||
croniter = "^6.0.0"
|
croniter = "^6.0.0"
|
||||||
stagehand = "^0.5.1"
|
stagehand = "^0.5.1"
|
||||||
gravitas-md2gdocs = "^0.1.0"
|
gravitas-md2gdocs = "^0.1.0"
|
||||||
|
posthog = "^7.6.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
aiohappyeyeballs = "^2.6.1"
|
aiohappyeyeballs = "^2.6.1"
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ model User {
|
|||||||
IntegrationWebhooks IntegrationWebhook[]
|
IntegrationWebhooks IntegrationWebhook[]
|
||||||
NotificationBatches UserNotificationBatch[]
|
NotificationBatches UserNotificationBatch[]
|
||||||
PendingHumanReviews PendingHumanReview[]
|
PendingHumanReviews PendingHumanReview[]
|
||||||
|
Workspace UserWorkspace?
|
||||||
|
|
||||||
// OAuth Provider relations
|
// OAuth Provider relations
|
||||||
OAuthApplications OAuthApplication[]
|
OAuthApplications OAuthApplication[]
|
||||||
@@ -81,6 +82,7 @@ enum OnboardingStep {
|
|||||||
AGENT_INPUT
|
AGENT_INPUT
|
||||||
CONGRATS
|
CONGRATS
|
||||||
// First Wins
|
// First Wins
|
||||||
|
VISIT_COPILOT
|
||||||
GET_RESULTS
|
GET_RESULTS
|
||||||
MARKETPLACE_VISIT
|
MARKETPLACE_VISIT
|
||||||
MARKETPLACE_ADD_AGENT
|
MARKETPLACE_ADD_AGENT
|
||||||
@@ -136,6 +138,66 @@ model CoPilotUnderstanding {
|
|||||||
@@index([userId])
|
@@index([userId])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
//////////////// USER WORKSPACE TABLES /////////////////
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// User's persistent file storage workspace
|
||||||
|
model UserWorkspace {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
|
userId String @unique
|
||||||
|
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
Files UserWorkspaceFile[]
|
||||||
|
|
||||||
|
@@index([userId])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Source of workspace file creation
|
||||||
|
enum WorkspaceFileSource {
|
||||||
|
UPLOAD // Direct user upload
|
||||||
|
EXECUTION // Created by graph execution
|
||||||
|
COPILOT // Created by CoPilot session
|
||||||
|
IMPORT // Imported from external source
|
||||||
|
}
|
||||||
|
|
||||||
|
// Individual files in a user's workspace
|
||||||
|
model UserWorkspaceFile {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
|
workspaceId String
|
||||||
|
Workspace UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
// File metadata
|
||||||
|
name String // User-visible filename
|
||||||
|
path String // Virtual path (e.g., "/documents/report.pdf")
|
||||||
|
storagePath String // Actual GCS or local storage path
|
||||||
|
mimeType String
|
||||||
|
sizeBytes BigInt
|
||||||
|
checksum String? // SHA256 for integrity
|
||||||
|
|
||||||
|
// File state
|
||||||
|
isDeleted Boolean @default(false)
|
||||||
|
deletedAt DateTime?
|
||||||
|
|
||||||
|
// Source tracking
|
||||||
|
source WorkspaceFileSource @default(UPLOAD)
|
||||||
|
sourceExecId String? // graph_exec_id if from execution
|
||||||
|
sourceSessionId String? // chat_session_id if from CoPilot
|
||||||
|
|
||||||
|
metadata Json @default("{}")
|
||||||
|
|
||||||
|
@@unique([workspaceId, path])
|
||||||
|
@@index([workspaceId, isDeleted])
|
||||||
|
}
|
||||||
|
|
||||||
model BuilderSearchHistory {
|
model BuilderSearchHistory {
|
||||||
id String @id @default(uuid())
|
id String @id @default(uuid())
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
@@ -517,8 +579,6 @@ model AgentNodeExecution {
|
|||||||
|
|
||||||
stats Json?
|
stats Json?
|
||||||
|
|
||||||
PendingHumanReview PendingHumanReview?
|
|
||||||
|
|
||||||
@@index([agentGraphExecutionId, agentNodeId, executionStatus])
|
@@index([agentGraphExecutionId, agentNodeId, executionStatus])
|
||||||
@@index([agentNodeId, executionStatus])
|
@@index([agentNodeId, executionStatus])
|
||||||
@@index([addedTime, queuedTime])
|
@@index([addedTime, queuedTime])
|
||||||
@@ -567,6 +627,7 @@ enum ReviewStatus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Pending human reviews for Human-in-the-loop blocks
|
// Pending human reviews for Human-in-the-loop blocks
|
||||||
|
// Also stores auto-approval records with special nodeExecId patterns (e.g., "auto_approve_{graph_exec_id}_{node_id}")
|
||||||
model PendingHumanReview {
|
model PendingHumanReview {
|
||||||
nodeExecId String @id
|
nodeExecId String @id
|
||||||
userId String
|
userId String
|
||||||
@@ -585,7 +646,6 @@ model PendingHumanReview {
|
|||||||
reviewedAt DateTime?
|
reviewedAt DateTime?
|
||||||
|
|
||||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
NodeExecution AgentNodeExecution @relation(fields: [nodeExecId], references: [id], onDelete: Cascade)
|
|
||||||
GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade)
|
GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
@@unique([nodeExecId]) // One pending review per node execution
|
@@unique([nodeExecId]) // One pending review per node execution
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user