Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-28 08:28:00 -05:00)

Compare commits: user-works...claude/fix (19 commits)
| SHA1 |
|---|
| 77453f5f15 |
| 9ffff490b5 |
| 097949b3e7 |
| 7f1a1f636f |
| 4dc4ca4256 |
| 00730496e3 |
| 21c753b971 |
| 732dfcbb63 |
| eebaf7df14 |
| 653aab44b6 |
| f0bc3f2a49 |
| e702d77cdf |
| 38741d2465 |
| 25d9dbac83 |
| fcbecf3502 |
| da9c4a4adf |
| 0ca73004e5 |
| 9a786ed8d9 |
| 0a435e2ffb |
@@ -194,50 +194,6 @@ ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions, check the new_blocks guide in the docs.

**Handling files in blocks with `store_media_file()`:**

When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:

| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |

**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
    file=input_data.video,
    execution_context=execution_context,
    return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc

# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
    file=input_data.image,
    execution_context=execution_context,
    return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API

# OUTPUT: Returning result from block
result_url = await store_media_file(
    file=generated_image_url,
    execution_context=execution_context,
    return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```

**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
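To tie the three formats together, here is a minimal sketch of a block that consumes an input file and returns a processed result. The `GenerateThumbnailBlock` class, its `input_data` fields, and the `_render_thumbnail` helper are hypothetical illustrations; only `store_media_file` and its `return_format` values come from the table above.

```python
from backend.util.file import store_media_file


class GenerateThumbnailBlock(Block):  # hypothetical block, for illustration only
    async def run(self, input_data, *, execution_context, **kwargs):
        # Bring the input video into local storage so local tools can read it
        local_path = await store_media_file(
            file=input_data.video,
            execution_context=execution_context,
            return_format="for_local_processing",
        )

        thumbnail_path = _render_thumbnail(local_path)  # hypothetical local processing

        # Return the result; "for_block_output" adapts to CoPilot vs. graph execution
        thumbnail_url = await store_media_file(
            file=thumbnail_path,
            execution_context=execution_context,
            return_format="for_block_output",
        )
        yield "thumbnail", thumbnail_url
```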
**Modifying the API:**

1. Update route in `/backend/backend/server/routers/`
@@ -33,15 +33,9 @@ class ChatConfig(BaseSettings):
    stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
    max_retries: int = Field(default=3, description="Maximum number of retries")
    max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
    max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
    max_agent_schedules: int = Field(
        default=30, description="Maximum number of agent schedules"
    )

    # Long-running operation configuration
    long_running_operation_ttl: int = Field(
        default=600,
        description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
        default=3, description="Maximum number of agent schedules"
    )

    # Langfuse Prompt Management Configuration
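Since `ChatConfig` extends pydantic's `BaseSettings`, these defaults can normally be overridden through the environment. A minimal sketch, assuming the class declares no custom env prefix (not confirmed by this diff):

```python
import os

os.environ["MAX_AGENT_RUNS"] = "10"  # assumed env name: the field name, case-insensitive

config = ChatConfig()  # BaseSettings re-reads the environment on construction
assert config.max_agent_runs == 10
```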
@@ -247,45 +247,3 @@ async def get_chat_session_message_count(session_id: str) -> int:
    """Get the number of messages in a chat session."""
    count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
    return count


async def update_tool_message_content(
    session_id: str,
    tool_call_id: str,
    new_content: str,
) -> bool:
    """Update the content of a tool message in chat history.

    Used by background tasks to update pending operation messages with final results.

    Args:
        session_id: The chat session ID.
        tool_call_id: The tool call ID to find the message.
        new_content: The new content to set.

    Returns:
        True if a message was updated, False otherwise.
    """
    try:
        result = await PrismaChatMessage.prisma().update_many(
            where={
                "sessionId": session_id,
                "toolCallId": tool_call_id,
            },
            data={
                "content": new_content,
            },
        )
        if result == 0:
            logger.warning(
                f"No message found to update for session {session_id}, "
                f"tool_call_id {tool_call_id}"
            )
            return False
        return True
    except Exception as e:
        logger.error(
            f"Failed to update tool message for session {session_id}, "
            f"tool_call_id {tool_call_id}: {e}"
        )
        return False
@@ -295,21 +295,6 @@ async def cache_chat_session(session: ChatSession) -> None:
    await _cache_session(session)


async def invalidate_session_cache(session_id: str) -> None:
    """Invalidate a chat session from Redis cache.

    Used by background tasks to ensure fresh data is loaded on next access.
    This is best-effort - Redis failures are logged but don't fail the operation.
    """
    try:
        redis_key = _get_session_cache_key(session_id)
        async_redis = await get_redis_async()
        await async_redis.delete(redis_key)
    except Exception as e:
        # Best-effort: log but don't fail - cache will expire naturally
        logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")


async def _get_session_from_db(session_id: str) -> ChatSession | None:
    """Get a chat session from the database."""
    prisma_session = await chat_db.get_chat_session(session_id)
@@ -17,7 +17,6 @@ from openai import (
)
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam

from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
    format_understanding_for_prompt,
    get_business_understanding,
@@ -25,7 +24,6 @@ from backend.data.understanding import (
from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings

from . import db as chat_db
from .config import ChatConfig
from .model import (
    ChatMessage,
@@ -33,7 +31,6 @@ from .model import (
    Usage,
    cache_chat_session,
    get_chat_session,
    invalidate_session_cache,
    update_session_title,
    upsert_chat_session,
)
@@ -51,13 +48,8 @@ from .response_model import (
    StreamToolOutputAvailable,
    StreamUsage,
)
from .tools import execute_tool, get_tool, tools
from .tools.models import (
    ErrorResponse,
    OperationInProgressResponse,
    OperationPendingResponse,
    OperationStartedResponse,
)
from .tools import execute_tool, tools
from .tools.models import ErrorResponse
from .tracking import track_user_message

logger = logging.getLogger(__name__)
@@ -69,126 +61,11 @@ client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
langfuse = get_client()

# Redis key prefix for tracking running long-running operations
# Used for idempotency across Kubernetes pods - prevents duplicate executions on browser refresh
RUNNING_OPERATION_PREFIX = "chat:running_operation:"

# Default system prompt used when Langfuse is not configured
# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11)
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
class LangfuseNotConfiguredError(Exception):
    """Raised when Langfuse is required but not configured."""

Here is everything you know about the current user from previous interactions:

<users_information>
{users_information}
</users_information>

## YOUR CORE MANDATE

You are action-oriented. Your success is measured by:
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
- **Time Saved**: Focus on tangible efficiency gains
- **Quality Output**: Deliver results that meet or exceed expectations

## YOUR WORKFLOW

Adapt flexibly to the conversation context. Not every interaction requires all stages:

1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.

2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.

3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).

4. **Discover or Create Agents**:
   - **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
   - Search the marketplace with `find_agent` for pre-built automations
   - Find reusable components with `find_block`
   - Create custom solutions with `create_agent` if nothing suitable exists
   - Modify existing library agents with `edit_agent`

5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.

6. **Show Results**: Display outputs using `agent_output`.

## AVAILABLE TOOLS

**Understanding & Discovery:**
- `add_understanding`: Create a memory about the user's business or use cases for future sessions
- `search_docs`: Search platform documentation for specific technical information
- `get_doc_page`: Retrieve full text of a specific documentation page

**Agent Discovery:**
- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized)
- `find_agent`: Search the marketplace for pre-built automations
- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks)

**Agent Creation & Editing:**
- `create_agent`: Create a new automation agent
- `edit_agent`: Modify an agent in the user's library

**Execution & Output:**
- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger
- `run_block`: Test or run a specific block independently
- `agent_output`: View results from previous agent runs

## BEHAVIORAL GUIDELINES

**Be Concise:**
- Target 2-5 short lines maximum
- Make every word count—no repetition or filler
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
- Avoid jargon (blocks, slugs, cron) unless the user asks

**Be Proactive:**
- Suggest next steps before being asked
- Anticipate needs based on conversation context and user information
- Look for opportunities to expand scope when relevant
- Reveal capabilities through action, not explanation

**Use Tools Effectively:**
- Select the right tool for each task
- **Always check `find_library_agent` before searching the marketplace**
- Use `add_understanding` to capture valuable business context
- When tool calls fail, try alternative approaches

## CRITICAL REMINDER

You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""

# Module-level set to hold strong references to background tasks.
# This prevents asyncio from garbage collecting tasks before they complete.
# Tasks are automatically removed on completion via done_callback.
_background_tasks: set[asyncio.Task] = set()


async def _mark_operation_started(tool_call_id: str) -> bool:
    """Mark a long-running operation as started (Redis-based).

    Returns True if successfully marked (operation was not already running),
    False if operation was already running (lost race condition).
    Raises exception if Redis is unavailable (fail-closed).
    """
    redis = await get_redis_async()
    key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
    # SETNX with TTL - atomic "set if not exists"
    result = await redis.set(key, "1", ex=config.long_running_operation_ttl, nx=True)
    return result is not None
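Taken together, this helper and `_mark_operation_completed` below form a distributed once-only gate. A minimal sketch of the intended call pattern (the `run_once` wrapper and its body are hypothetical; the helper names come from this file):

```python
async def run_once(tool_call_id: str) -> None:
    # SETNX gate: a second call with the same tool_call_id (browser refresh,
    # duplicate SSE delivery, another pod) loses the race and returns early.
    if not await _mark_operation_started(tool_call_id):
        return
    try:
        ...  # the actual long-running work goes here
    finally:
        # Always release the key; the TTL only covers the pod-died case.
        await _mark_operation_completed(tool_call_id)
```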

async def _mark_operation_completed(tool_call_id: str) -> None:
    """Mark a long-running operation as completed (remove Redis key).

    This is best-effort - if Redis fails, the TTL will eventually clean up.
    """
    try:
        redis = await get_redis_async()
        key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
        await redis.delete(key)
    except Exception as e:
        # Non-critical: TTL will clean up eventually
        logger.warning(f"Failed to delete running operation key {tool_call_id}: {e}")
        pass


def _is_langfuse_configured() -> bool:
@@ -198,30 +75,6 @@ def _is_langfuse_configured() -> bool:
    )


async def _get_system_prompt_template(context: str) -> str:
    """Get the system prompt, trying Langfuse first with fallback to default.

    Args:
        context: The user context/information to compile into the prompt.

    Returns:
        The compiled system prompt string.
    """
    if _is_langfuse_configured():
        try:
            # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
            # Use asyncio.to_thread to avoid blocking the event loop
            prompt = await asyncio.to_thread(
                langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
            )
            return prompt.compile(users_information=context)
        except Exception as e:
            logger.warning(f"Failed to fetch prompt from Langfuse, using default: {e}")

    # Fallback to default prompt
    return DEFAULT_SYSTEM_PROMPT.format(users_information=context)


async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
    """Build the full system prompt including business understanding if available.
@@ -230,8 +83,12 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
    If "default" and this is the user's first session, will use "onboarding" instead.

    Returns:
        Tuple of (compiled prompt string, business understanding object)
        Tuple of (compiled prompt string, Langfuse prompt object for tracing)
    """

    # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
    prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)

    # If user is authenticated, try to fetch their business understanding
    understanding = None
    if user_id:
@@ -240,13 +97,12 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
        except Exception as e:
            logger.warning(f"Failed to fetch business understanding: {e}")
            understanding = None

    if understanding:
        context = format_understanding_for_prompt(understanding)
    else:
        context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"

    compiled = await _get_system_prompt_template(context)
    compiled = prompt.compile(users_information=context)
    return compiled, understanding
@@ -354,6 +210,16 @@ async def stream_chat_completion(
        f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
    )

    # Check if Langfuse is configured - required for chat functionality
    if not _is_langfuse_configured():
        logger.error("Chat request failed: Langfuse is not configured")
        yield StreamError(
            errorText="Chat service is not available. Langfuse must be configured "
            "with LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
        )
        yield StreamFinish()
        return

    # Only fetch from Redis if session not provided (initial call)
    if session is None:
        session = await get_chat_session(session_id, user_id)
@@ -449,7 +315,6 @@ async def stream_chat_completion(
    has_yielded_end = False
    has_yielded_error = False
    has_done_tool_call = False
    has_long_running_tool_call = False  # Track if we had a long-running tool call
    has_received_text = False
    text_streaming_ended = False
    tool_response_messages: list[ChatMessage] = []
@@ -471,6 +336,7 @@ async def stream_chat_completion(
        system_prompt=system_prompt,
        text_block_id=text_block_id,
    ):

        if isinstance(chunk, StreamTextStart):
            # Emit text-start before first text delta
            if not has_received_text:
@@ -528,34 +394,13 @@ async def stream_chat_completion(
                    if isinstance(chunk.output, str)
                    else orjson.dumps(chunk.output).decode("utf-8")
                )
                # Skip saving long-running operation responses - messages already saved in _yield_tool_call
                # Use JSON parsing instead of substring matching to avoid false positives
                is_long_running_response = False
                try:
                    parsed = orjson.loads(result_content)
                    if isinstance(parsed, dict) and parsed.get("type") in (
                        "operation_started",
                        "operation_in_progress",
                    ):
                        is_long_running_response = True
                except (orjson.JSONDecodeError, TypeError):
                    pass  # Not JSON or not a dict - treat as regular response
                if is_long_running_response:
                    # Remove from accumulated_tool_calls since assistant message was already saved
                    accumulated_tool_calls[:] = [
                        tc
                        for tc in accumulated_tool_calls
                        if tc["id"] != chunk.toolCallId
                    ]
                    has_long_running_tool_call = True
                else:
                    tool_response_messages.append(
                        ChatMessage(
                            role="tool",
                            content=result_content,
                            tool_call_id=chunk.toolCallId,
                        )
                tool_response_messages.append(
                    ChatMessage(
                        role="tool",
                        content=result_content,
                        tool_call_id=chunk.toolCallId,
                    )
                )
                has_done_tool_call = True
                # Track if any tool execution failed
                if not chunk.success:
@@ -731,14 +576,7 @@ async def stream_chat_completion(
            logger.info(
                f"Extended session messages, new message_count={len(session.messages)}"
            )
        # Save if there are regular (non-long-running) tool responses or streaming message.
        # Long-running tools save their own state, but we still need to save regular tools
        # that may be in the same response.
        has_regular_tool_responses = len(tool_response_messages) > 0
        if has_regular_tool_responses or (
            not has_long_running_tool_call
            and (messages_to_save or has_appended_streaming_message)
        ):
        if messages_to_save or has_appended_streaming_message:
            await upsert_chat_session(session)
        else:
            logger.info(
@@ -747,9 +585,7 @@ async def stream_chat_completion(
            )

    # If we did a tool call, stream the chat completion again to get the next response
    # Skip only if ALL tools were long-running (they handle their own completion)
    has_regular_tools = len(tool_response_messages) > 0
    if has_done_tool_call and (has_regular_tools or not has_long_running_tool_call):
    if has_done_tool_call:
        logger.info(
            "Tool call executed, streaming chat completion again to get assistant response"
        )
@@ -889,114 +725,6 @@ async def _summarize_messages(
    return summary or "No summary available."


def _ensure_tool_pairs_intact(
    recent_messages: list[dict],
    all_messages: list[dict],
    start_index: int,
) -> list[dict]:
    """
    Ensure tool_call/tool_response pairs stay together after slicing.

    When slicing messages for context compaction, a naive slice can separate
    an assistant message containing tool_calls from its corresponding tool
    response messages. This causes API validation errors (e.g., Anthropic's
    "unexpected tool_use_id found in tool_result blocks").

    This function checks for orphan tool responses in the slice and extends
    backwards to include their corresponding assistant messages.

    Args:
        recent_messages: The sliced messages to validate
        all_messages: The complete message list (for looking up missing assistants)
        start_index: The index in all_messages where recent_messages begins

    Returns:
        A potentially extended list of messages with tool pairs intact
    """
    if not recent_messages:
        return recent_messages

    # Collect all tool_call_ids from assistant messages in the slice
    available_tool_call_ids: set[str] = set()
    for msg in recent_messages:
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            for tc in msg["tool_calls"]:
                tc_id = tc.get("id")
                if tc_id:
                    available_tool_call_ids.add(tc_id)

    # Find orphan tool responses (tool messages whose tool_call_id is missing)
    orphan_tool_call_ids: set[str] = set()
    for msg in recent_messages:
        if msg.get("role") == "tool":
            tc_id = msg.get("tool_call_id")
            if tc_id and tc_id not in available_tool_call_ids:
                orphan_tool_call_ids.add(tc_id)

    if not orphan_tool_call_ids:
        # No orphans, slice is valid
        return recent_messages

    # Find the assistant messages that contain the orphan tool_call_ids
    # Search backwards from start_index in all_messages
    messages_to_prepend: list[dict] = []
    for i in range(start_index - 1, -1, -1):
        msg = all_messages[i]
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
            if msg_tool_ids & orphan_tool_call_ids:
                # This assistant message has tool_calls we need
                # Also collect its contiguous tool responses that follow it
                assistant_and_responses: list[dict] = [msg]

                # Scan forward from this assistant to collect tool responses
                for j in range(i + 1, start_index):
                    following_msg = all_messages[j]
                    if following_msg.get("role") == "tool":
                        tool_id = following_msg.get("tool_call_id")
                        if tool_id and tool_id in msg_tool_ids:
                            assistant_and_responses.append(following_msg)
                    else:
                        # Stop at first non-tool message
                        break

                # Prepend the assistant and its tool responses (maintain order)
                messages_to_prepend = assistant_and_responses + messages_to_prepend
                # Mark these as found
                orphan_tool_call_ids -= msg_tool_ids
                # Also add this assistant's tool_call_ids to available set
                available_tool_call_ids |= msg_tool_ids

                if not orphan_tool_call_ids:
                    # Found all missing assistants
                    break

    if orphan_tool_call_ids:
        # Some tool_call_ids couldn't be resolved - remove those tool responses
        # This shouldn't happen in normal operation but handles edge cases
        logger.warning(
            f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
            "Removing orphan tool responses."
        )
        recent_messages = [
            msg
            for msg in recent_messages
            if not (
                msg.get("role") == "tool"
                and msg.get("tool_call_id") in orphan_tool_call_ids
            )
        ]

    if messages_to_prepend:
        logger.info(
            f"Extended recent messages by {len(messages_to_prepend)} to preserve "
            f"tool_call/tool_response pairs"
        )
        return messages_to_prepend + recent_messages

    return recent_messages
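To make the orphan case concrete, here is a small illustrative call (the message dicts are invented for this example):

```python
# Slicing at index 2 would keep the tool reply for "call_1" but drop the
# assistant message that issued it; the helper walks back and re-includes it.
all_messages = [
    {"role": "user", "content": "run it"},
    {"role": "assistant", "tool_calls": [{"id": "call_1"}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "done"},
    {"role": "assistant", "content": "Finished."},
]
recent = _ensure_tool_pairs_intact(all_messages[2:], all_messages, start_index=2)
# recent now begins with the assistant message carrying "call_1"
```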

async def _stream_chat_chunks(
    session: ChatSession,
    tools: list[ChatCompletionToolParam],
@@ -1088,15 +816,7 @@ async def _stream_chat_chunks(
    # Always attempt mitigation when over limit, even with few messages
    if messages:
        # Split messages based on whether system prompt exists
        # Calculate start index for the slice
        slice_start = max(0, len(messages_dict) - KEEP_RECENT)
        recent_messages = messages_dict[-KEEP_RECENT:]

        # Ensure tool_call/tool_response pairs stay together
        # This prevents API errors from orphan tool responses
        recent_messages = _ensure_tool_pairs_intact(
            recent_messages, messages_dict, slice_start
        )
        recent_messages = messages[-KEEP_RECENT:]

        if has_system_prompt:
            # Keep system prompt separate, summarize everything between system and recent
@@ -1183,13 +903,6 @@ async def _stream_chat_chunks(
                if len(recent_messages) >= keep_count
                else recent_messages
            )
            # Ensure tool pairs stay intact in the reduced slice
            reduced_slice_start = max(
                0, len(recent_messages) - keep_count
            )
            reduced_recent = _ensure_tool_pairs_intact(
                reduced_recent, recent_messages, reduced_slice_start
            )
            if has_system_prompt:
                messages = [
                    system_msg,
@@ -1248,10 +961,7 @@ async def _stream_chat_chunks(

    # Create a base list excluding system prompt to avoid duplication
    # This is the pool of messages we'll slice from in the loop
    # Use messages_dict for type consistency with _ensure_tool_pairs_intact
    base_msgs = (
        messages_dict[1:] if has_system_prompt else messages_dict
    )
    base_msgs = messages[1:] if has_system_prompt else messages

    # Try progressively smaller keep counts
    new_token_count = token_count  # Initialize with current count
@@ -1274,12 +984,6 @@ async def _stream_chat_chunks(
        # Slice from base_msgs to get recent messages (without system prompt)
        recent_messages = base_msgs[-keep_count:]

        # Ensure tool pairs stay intact in the reduced slice
        reduced_slice_start = max(0, len(base_msgs) - keep_count)
        recent_messages = _ensure_tool_pairs_intact(
            recent_messages, base_msgs, reduced_slice_start
        )

        if has_system_prompt:
            messages = [system_msg] + recent_messages
        else:
@@ -1556,19 +1260,17 @@ async def _yield_tool_call(
    """
    Yield a tool call and its execution result.

    For tools marked with `is_long_running=True` (like agent generation), spawns a
    background task so the operation survives SSE disconnections. For other tools,
    yields heartbeat events every 15 seconds to keep the SSE connection alive.
    For long-running tools, yields heartbeat events every 15 seconds to keep
    the SSE connection alive through proxies and load balancers.

    Raises:
        orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
        KeyError: If expected tool call fields are missing
        TypeError: If tool call structure is invalid
    """
    import uuid as uuid_module

    tool_name = tool_calls[yield_idx]["function"]["name"]
    tool_call_id = tool_calls[yield_idx]["id"]
    logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")

    # Parse tool call arguments - handle empty arguments gracefully
    raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
@@ -1583,151 +1285,7 @@ async def _yield_tool_call(
        input=arguments,
    )

    # Check if this tool is long-running (survives SSE disconnection)
    tool = get_tool(tool_name)
    if tool and tool.is_long_running:
        # Atomic check-and-set: returns False if operation already running (lost race)
        if not await _mark_operation_started(tool_call_id):
            logger.info(
                f"Tool call {tool_call_id} already in progress, returning status"
            )
            # Build dynamic message based on tool name
            if tool_name == "create_agent":
                in_progress_msg = "Agent creation already in progress. Please wait..."
            elif tool_name == "edit_agent":
                in_progress_msg = "Agent edit already in progress. Please wait..."
            else:
                in_progress_msg = f"{tool_name} already in progress. Please wait..."

            yield StreamToolOutputAvailable(
                toolCallId=tool_call_id,
                toolName=tool_name,
                output=OperationInProgressResponse(
                    message=in_progress_msg,
                    tool_call_id=tool_call_id,
                ).model_dump_json(),
                success=True,
            )
            return

        # Generate operation ID
        operation_id = str(uuid_module.uuid4())

        # Build a user-friendly message based on tool and arguments
        if tool_name == "create_agent":
            agent_desc = arguments.get("description", "")
            # Truncate long descriptions for the message
            desc_preview = (
                (agent_desc[:100] + "...") if len(agent_desc) > 100 else agent_desc
            )
            pending_msg = (
                f"Creating your agent: {desc_preview}"
                if desc_preview
                else "Creating agent... This may take a few minutes."
            )
            started_msg = (
                "Agent creation started. You can close this tab - "
                "check your library in a few minutes."
            )
        elif tool_name == "edit_agent":
            changes = arguments.get("changes", "")
            changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
            pending_msg = (
                f"Editing agent: {changes_preview}"
                if changes_preview
                else "Editing agent... This may take a few minutes."
            )
            started_msg = (
                "Agent edit started. You can close this tab - "
                "check your library in a few minutes."
            )
        else:
            pending_msg = f"Running {tool_name}... This may take a few minutes."
            started_msg = (
                f"{tool_name} started. You can close this tab - "
                "check back in a few minutes."
            )

        # Track appended messages for rollback on failure
        assistant_message: ChatMessage | None = None
        pending_message: ChatMessage | None = None

        # Wrap session save and task creation in try-except to release lock on failure
        try:
            # Save assistant message with tool_call FIRST (required by LLM)
            assistant_message = ChatMessage(
                role="assistant",
                content="",
                tool_calls=[tool_calls[yield_idx]],
            )
            session.messages.append(assistant_message)

            # Then save pending tool result
            pending_message = ChatMessage(
                role="tool",
                content=OperationPendingResponse(
                    message=pending_msg,
                    operation_id=operation_id,
                    tool_name=tool_name,
                ).model_dump_json(),
                tool_call_id=tool_call_id,
            )
            session.messages.append(pending_message)
            await upsert_chat_session(session)
            logger.info(
                f"Saved pending operation {operation_id} for tool {tool_name} "
                f"in session {session.session_id}"
            )

            # Store task reference in module-level set to prevent GC before completion
            task = asyncio.create_task(
                _execute_long_running_tool(
                    tool_name=tool_name,
                    parameters=arguments,
                    tool_call_id=tool_call_id,
                    operation_id=operation_id,
                    session_id=session.session_id,
                    user_id=session.user_id,
                )
            )
            _background_tasks.add(task)
            task.add_done_callback(_background_tasks.discard)
        except Exception as e:
            # Roll back appended messages to prevent data corruption on subsequent saves
            if (
                pending_message
                and session.messages
                and session.messages[-1] == pending_message
            ):
                session.messages.pop()
            if (
                assistant_message
                and session.messages
                and session.messages[-1] == assistant_message
            ):
                session.messages.pop()

            # Release the Redis lock since the background task won't be spawned
            await _mark_operation_completed(tool_call_id)
            logger.error(
                f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
            )
            raise

        # Return immediately - don't wait for completion
        yield StreamToolOutputAvailable(
            toolCallId=tool_call_id,
            toolName=tool_name,
            output=OperationStartedResponse(
                message=started_msg,
                operation_id=operation_id,
                tool_name=tool_name,
            ).model_dump_json(),
            success=True,
        )
        return

    # Normal flow: Run tool execution in background task with heartbeats
    # Run tool execution in background task with heartbeats to keep connection alive
    tool_task = asyncio.create_task(
        execute_tool(
            tool_name=tool_name,
@@ -1777,190 +1335,3 @@ async def _yield_tool_call(
        )

    yield tool_execution_response


async def _execute_long_running_tool(
    tool_name: str,
    parameters: dict[str, Any],
    tool_call_id: str,
    operation_id: str,
    session_id: str,
    user_id: str | None,
) -> None:
    """Execute a long-running tool in background and update chat history with result.

    This function runs independently of the SSE connection, so the operation
    survives if the user closes their browser tab.
    """
    try:
        # Load fresh session (not stale reference)
        session = await get_chat_session(session_id, user_id)
        if not session:
            logger.error(f"Session {session_id} not found for background tool")
            return

        # Execute the actual tool
        result = await execute_tool(
            tool_name=tool_name,
            parameters=parameters,
            tool_call_id=tool_call_id,
            user_id=user_id,
            session=session,
        )

        # Update the pending message with result
        await _update_pending_operation(
            session_id=session_id,
            tool_call_id=tool_call_id,
            result=(
                result.output
                if isinstance(result.output, str)
                else orjson.dumps(result.output).decode("utf-8")
            ),
        )

        logger.info(f"Background tool {tool_name} completed for session {session_id}")

        # Generate LLM continuation so user sees response when they poll/refresh
        await _generate_llm_continuation(session_id=session_id, user_id=user_id)

    except Exception as e:
        logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
        error_response = ErrorResponse(
            message=f"Tool {tool_name} failed: {str(e)}",
        )
        await _update_pending_operation(
            session_id=session_id,
            tool_call_id=tool_call_id,
            result=error_response.model_dump_json(),
        )
    finally:
        await _mark_operation_completed(tool_call_id)


async def _update_pending_operation(
    session_id: str,
    tool_call_id: str,
    result: str,
) -> None:
    """Update the pending tool message with final result.

    This is called by background tasks when long-running operations complete.
    """
    # Update the message in database
    updated = await chat_db.update_tool_message_content(
        session_id=session_id,
        tool_call_id=tool_call_id,
        new_content=result,
    )

    if updated:
        # Invalidate Redis cache so next load gets fresh data
        # Wrap in try/except to prevent cache failures from triggering error handling
        # that would overwrite our successful DB update
        try:
            await invalidate_session_cache(session_id)
        except Exception as e:
            # Non-critical: cache will eventually be refreshed on next load
            logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
        logger.info(
            f"Updated pending operation for tool_call_id {tool_call_id} "
            f"in session {session_id}"
        )
    else:
        logger.warning(
            f"Failed to update pending operation for tool_call_id {tool_call_id} "
            f"in session {session_id}"
        )


async def _generate_llm_continuation(
    session_id: str,
    user_id: str | None,
) -> None:
    """Generate an LLM response after a long-running tool completes.

    This is called by background tasks to continue the conversation
    after a tool result is saved. The response is saved to the database
    so users see it when they refresh or poll.
    """
    try:
        # Load fresh session from DB (bypass cache to get the updated tool result)
        await invalidate_session_cache(session_id)
        session = await get_chat_session(session_id, user_id)
        if not session:
            logger.error(f"Session {session_id} not found for LLM continuation")
            return

        # Build system prompt
        system_prompt, _ = await _build_system_prompt(user_id)

        # Build messages in OpenAI format
        messages = session.to_openai_messages()
        if system_prompt:
            from openai.types.chat import ChatCompletionSystemMessageParam

            system_message = ChatCompletionSystemMessageParam(
                role="system",
                content=system_prompt,
            )
            messages = [system_message] + messages

        # Build extra_body for tracing
        extra_body: dict[str, Any] = {
            "posthogProperties": {
                "environment": settings.config.app_env.value,
            },
        }
        if user_id:
            extra_body["user"] = user_id[:128]
            extra_body["posthogDistinctId"] = user_id
        if session_id:
            extra_body["session_id"] = session_id[:128]

        # Make non-streaming LLM call (no tools - just text response)
        from typing import cast

        from openai.types.chat import ChatCompletionMessageParam

        # No tools parameter = text-only response (no tool calls)
        response = await client.chat.completions.create(
            model=config.model,
            messages=cast(list[ChatCompletionMessageParam], messages),
            extra_body=extra_body,
        )

        if response.choices and response.choices[0].message.content:
            assistant_content = response.choices[0].message.content

            # Reload session from DB to avoid race condition with user messages
            # that may have been sent while we were generating the LLM response
            fresh_session = await get_chat_session(session_id, user_id)
            if not fresh_session:
                logger.error(
                    f"Session {session_id} disappeared during LLM continuation"
                )
                return

            # Save assistant message to database
            assistant_message = ChatMessage(
                role="assistant",
                content=assistant_content,
            )
            fresh_session.messages.append(assistant_message)

            # Save to database (not cache) to persist the response
            await upsert_chat_session(fresh_session)

            # Invalidate cache so next poll/refresh gets fresh data
            await invalidate_session_cache(session_id)

            logger.info(
                f"Generated LLM continuation for session {session_id}, "
                f"response length: {len(assistant_content)}"
            )
        else:
            logger.warning(f"LLM continuation returned empty response for {session_id}")

    except Exception as e:
        logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
@@ -18,12 +18,6 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .workspace_tools import (
    DeleteWorkspaceFileTool,
    ListWorkspaceFilesTool,
    ReadWorkspaceFileTool,
    WriteWorkspaceFileTool,
)

if TYPE_CHECKING:
    from backend.api.features.chat.response_model import StreamToolOutputAvailable
@@ -43,11 +37,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
    "view_agent_output": AgentOutputTool(),
    "search_docs": SearchDocsTool(),
    "get_doc_page": GetDocPageTool(),
    # Workspace tools for CoPilot file operations
    "list_workspace_files": ListWorkspaceFilesTool(),
    "read_workspace_file": ReadWorkspaceFileTool(),
    "write_workspace_file": WriteWorkspaceFileTool(),
    "delete_workspace_file": DeleteWorkspaceFileTool(),
}

# Export individual tool instances for backwards compatibility
@@ -60,11 +49,6 @@ tools: list[ChatCompletionToolParam] = [
]


def get_tool(tool_name: str) -> BaseTool | None:
    """Get a tool instance by name."""
    return TOOL_REGISTRY.get(tool_name)


async def execute_tool(
    tool_name: str,
    parameters: dict[str, Any],
@@ -73,7 +57,7 @@ async def execute_tool(
    tool_call_id: str,
) -> "StreamToolOutputAvailable":
    """Execute a tool by name."""
    tool = get_tool(tool_name)
    tool = TOOL_REGISTRY.get(tool_name)
    if not tool:
        raise ValueError(f"Tool {tool_name} not found")
@@ -36,16 +36,6 @@ class BaseTool:
        """Whether this tool requires authentication."""
        return False

    @property
    def is_long_running(self) -> bool:
        """Whether this tool is long-running and should execute in background.

        Long-running tools (like agent generation) are executed via background
        tasks to survive SSE disconnections. The result is persisted to chat
        history and visible when the user refreshes.
        """
        return False

    def as_openai_tool(self) -> ChatCompletionToolParam:
        """Convert to OpenAI tool format."""
        return ChatCompletionToolParam(
@@ -42,10 +42,6 @@ class CreateAgentTool(BaseTool):
    def requires_auth(self) -> bool:
        return True

    @property
    def is_long_running(self) -> bool:
        return True

    @property
    def parameters(self) -> dict[str, Any]:
        return {
@@ -42,10 +42,6 @@ class EditAgentTool(BaseTool):
    def requires_auth(self) -> bool:
        return True

    @property
    def is_long_running(self) -> bool:
        return True

    @property
    def parameters(self) -> dict[str, Any]:
        return {
@@ -28,16 +28,6 @@ class ResponseType(str, Enum):
    BLOCK_OUTPUT = "block_output"
    DOC_SEARCH_RESULTS = "doc_search_results"
    DOC_PAGE = "doc_page"
    # Workspace response types
    WORKSPACE_FILE_LIST = "workspace_file_list"
    WORKSPACE_FILE_CONTENT = "workspace_file_content"
    WORKSPACE_FILE_METADATA = "workspace_file_metadata"
    WORKSPACE_FILE_WRITTEN = "workspace_file_written"
    WORKSPACE_FILE_DELETED = "workspace_file_deleted"
    # Long-running operation types
    OPERATION_STARTED = "operation_started"
    OPERATION_PENDING = "operation_pending"
    OPERATION_IN_PROGRESS = "operation_in_progress"


# Base response model
@@ -344,39 +334,3 @@ class BlockOutputResponse(ToolResponseBase):
    block_name: str
    outputs: dict[str, list[Any]]
    success: bool = True


# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
    """Response when a long-running operation has been started in the background.

    This is returned immediately to the client while the operation continues
    to execute. The user can close the tab and check back later.
    """

    type: ResponseType = ResponseType.OPERATION_STARTED
    operation_id: str
    tool_name: str


class OperationPendingResponse(ToolResponseBase):
    """Response stored in chat history while a long-running operation is executing.

    This is persisted to the database so users see a pending state when they
    refresh before the operation completes.
    """

    type: ResponseType = ResponseType.OPERATION_PENDING
    operation_id: str
    tool_name: str


class OperationInProgressResponse(ToolResponseBase):
    """Response when an operation is already in progress.

    Returned for idempotency when the same tool_call_id is requested again
    while the background task is still running.
    """

    type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
    tool_call_id: str
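The `type` field on these models is the discriminator that the JSON check in `stream_chat_completion` keys on. A small illustrative sketch (the IDs are hypothetical values, not from the diff):

```python
resp = OperationStartedResponse(
    message="Agent creation started.",
    operation_id="op-123",        # hypothetical value
    tool_name="create_agent",
)
payload = resp.model_dump_json()
# payload includes "type": "operation_started", which the orjson-based check
# in the streaming loop matches to skip re-saving the tool response
```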
@@ -1,7 +1,6 @@
"""Tool for executing blocks directly."""

import logging
import uuid
from collections import defaultdict
from typing import Any
@@ -9,7 +8,6 @@ from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
@@ -225,48 +223,11 @@ class RunBlockTool(BaseTool):
            )

        try:
            # Get or create user's workspace for CoPilot file operations
            workspace = await get_or_create_workspace(user_id)

            # Generate synthetic IDs for CoPilot context
            # Each chat session is treated as its own agent with one continuous run
            # This means:
            # - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
            # - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
            # - node_exec_id = unique per block execution
            synthetic_graph_id = f"copilot-session-{session.session_id}"
            synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
            synthetic_node_id = f"copilot-node-{block_id}"
            synthetic_node_exec_id = (
                f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
            )

            # Create unified execution context with all required fields
            execution_context = ExecutionContext(
                # Execution identity
                user_id=user_id,
                graph_id=synthetic_graph_id,
                graph_exec_id=synthetic_graph_exec_id,
                graph_version=1,  # Versions are 1-indexed
                node_id=synthetic_node_id,
                node_exec_id=synthetic_node_exec_id,
                # Workspace with session scoping
                workspace_id=workspace.id,
                session_id=session.session_id,
            )

            # Prepare kwargs for block execution
            # Keep individual kwargs for backwards compatibility with existing blocks
            # Fetch actual credentials and prepare kwargs for block execution
            # Create execution context with defaults (blocks may require it)
            exec_kwargs: dict[str, Any] = {
                "user_id": user_id,
                "execution_context": execution_context,
                # Legacy: individual kwargs for blocks not yet using execution_context
                "workspace_id": workspace.id,
                "graph_exec_id": synthetic_graph_exec_id,
                "node_exec_id": synthetic_node_exec_id,
                "node_id": synthetic_node_id,
                "graph_version": 1,  # Versions are 1-indexed
                "graph_id": synthetic_graph_id,
                "execution_context": ExecutionContext(),
            }

            for field_name, cred_meta in matched_credentials.items():
@@ -1,625 +0,0 @@
|
||||
"""CoPilot tools for workspace file operations."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from prisma.enums import WorkspaceFileSource
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkspaceFileInfoData(BaseModel):
|
||||
"""Data model for workspace file information (not a response itself)."""
|
||||
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
source: str
|
||||
|
||||
|
||||
class WorkspaceFileListResponse(ToolResponseBase):
|
||||
"""Response containing list of workspace files."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
||||
files: list[WorkspaceFileInfoData]
|
||||
total_count: int
|
||||
|
||||
|
||||
class WorkspaceFileContentResponse(ToolResponseBase):
|
||||
"""Response containing workspace file content (legacy, for small text files)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
content_base64: str
|
||||
|
||||
|
||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
download_url: str
|
||||
preview: str | None = None # First 500 chars for text files
|
||||
|
||||
|
||||
class WorkspaceWriteResponse(ToolResponseBase):
|
||||
"""Response after writing a file to workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||
"""Response after deleting a file from workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
||||
file_id: str
|
||||
success: bool
|
||||
|
||||
|
||||
class ListWorkspaceFilesTool(BaseTool):
|
||||
"""Tool for listing files in user's workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_workspace_files"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List files in the user's workspace. "
|
||||
"Returns file names, paths, sizes, and metadata. "
|
||||
"Optionally filter by path prefix."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path_prefix": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional path prefix to filter files "
|
||||
"(e.g., '/documents/' to list only files in documents folder). "
|
||||
"By default, only files from the current session are listed."
|
||||
),
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of files to return (default 50, max 100)",
|
||||
"minimum": 1,
|
||||
"maximum": 100,
|
||||
},
|
||||
"include_all_sessions": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, list files from all sessions. "
|
||||
"Default is false (only current session's files)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||
limit = min(kwargs.get("limit", 50), 100)
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
files = await manager.list_files(
|
||||
path=path_prefix,
|
||||
limit=limit,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
total = await manager.get_file_count(
|
||||
path=path_prefix,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
|
||||
file_infos = [
|
||||
WorkspaceFileInfoData(
|
||||
file_id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mimeType,
|
||||
size_bytes=f.sizeBytes,
|
||||
source=f.source,
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
|
||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||
return WorkspaceFileListResponse(
|
||||
files=file_infos,
|
||||
total_count=total,
|
||||
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list workspace files: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
# Size threshold for returning full content vs metadata+URL
|
||||
# Files larger than this return metadata with download URL to prevent context bloat
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
# Preview size for text files
|
||||
PREVIEW_SIZE = 500
|
||||

    @property
    def name(self) -> str:
        return "read_workspace_file"

    @property
    def description(self) -> str:
        return (
            "Read a file from the user's workspace. "
            "Specify either file_id or path to identify the file. "
            "For small text files, returns content directly. "
            "For large or binary files, returns metadata and a download URL. "
            "Paths are scoped to the current session by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "file_id": {
                    "type": "string",
                    "description": "The file's unique ID (from list_workspace_files)",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "The virtual file path (e.g., '/documents/report.pdf'). "
                        "Scoped to current session by default."
                    ),
                },
                "force_download_url": {
                    "type": "boolean",
                    "description": (
                        "If true, always return metadata+URL instead of inline content. "
                        "Default is false (auto-selects based on file size/type)."
                    ),
                },
            },
            "required": [],  # At least one must be provided
        }

    @property
    def requires_auth(self) -> bool:
        return True

    def _is_text_mime_type(self, mime_type: str) -> bool:
        """Check if the MIME type is a text-based type."""
        text_types = [
            "text/",
            "application/json",
            "application/xml",
            "application/javascript",
            "application/x-python",
            "application/x-sh",
        ]
        return any(mime_type.startswith(t) for t in text_types)
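    # For example: "text/plain" and "application/json" match the prefixes
    # above and are treated as text, while "image/png" or "application/pdf"
    # fall through to the metadata+URL path.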

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        file_id: Optional[str] = kwargs.get("file_id")
        path: Optional[str] = kwargs.get("path")
        force_download_url: bool = kwargs.get("force_download_url", False)

        if not file_id and not path:
            return ErrorResponse(
                message="Please provide either file_id or path",
                session_id=session_id,
            )

        try:
            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            # Get file info
            if file_id:
                file_info = await manager.get_file_info(file_id)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found: {file_id}",
                        session_id=session_id,
                    )
                target_file_id = file_id
            else:
                # path is guaranteed to be non-None here due to the check above
                assert path is not None
                file_info = await manager.get_file_info_by_path(path)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found at path: {path}",
                        session_id=session_id,
                    )
                target_file_id = file_info.id

            # Decide whether to return inline content or metadata+URL
            is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
            is_text_file = self._is_text_mime_type(file_info.mimeType)

            # Return inline content for small text files (unless force_download_url)
            if is_small_file and is_text_file and not force_download_url:
                content = await manager.read_file_by_id(target_file_id)
                content_b64 = base64.b64encode(content).decode("utf-8")

                return WorkspaceFileContentResponse(
                    file_id=file_info.id,
                    name=file_info.name,
                    path=file_info.path,
                    mime_type=file_info.mimeType,
                    content_base64=content_b64,
                    message=f"Successfully read file: {file_info.name}",
                    session_id=session_id,
                )

            # Return metadata + workspace:// reference for large or binary files
            # This prevents context bloat (100KB file = ~133KB as base64)
            # Use workspace:// format so frontend urlTransform can add proxy prefix
            download_url = f"workspace://{target_file_id}"
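            # Illustrative: a file with id "abc123" becomes "workspace://abc123",
            # which the frontend can resolve via the download route shown later
            # in this diff (/files/{file_id}/download).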

            # Generate preview for text files
            preview: str | None = None
            if is_text_file:
                try:
                    content = await manager.read_file_by_id(target_file_id)
                    preview_text = content[: self.PREVIEW_SIZE].decode(
                        "utf-8", errors="replace"
                    )
                    if len(content) > self.PREVIEW_SIZE:
                        preview_text += "..."
                    preview = preview_text
                except Exception:
                    pass  # Preview is optional

            return WorkspaceFileMetadataResponse(
                file_id=file_info.id,
                name=file_info.name,
                path=file_info.path,
                mime_type=file_info.mimeType,
                size_bytes=file_info.sizeBytes,
                download_url=download_url,
                preview=preview,
                message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
                session_id=session_id,
            )

        except FileNotFoundError as e:
            return ErrorResponse(
                message=str(e),
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error reading workspace file: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to read workspace file: {str(e)}",
                error=str(e),
                session_id=session_id,
            )


class WriteWorkspaceFileTool(BaseTool):
    """Tool for writing files to workspace."""

    @property
    def name(self) -> str:
        return "write_workspace_file"

    @property
    def description(self) -> str:
        return (
            "Write or create a file in the user's workspace. "
            "Provide the content as a base64-encoded string. "
            f"Maximum file size is {Config().max_file_size_mb}MB. "
            "Files are saved to the current session's folder by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "filename": {
                    "type": "string",
                    "description": "Name for the file (e.g., 'report.pdf')",
                },
                "content_base64": {
                    "type": "string",
                    "description": "Base64-encoded file content",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "Optional virtual path where to save the file "
                        "(e.g., '/documents/report.pdf'). "
                        "Defaults to '/{filename}'. Scoped to current session."
                    ),
                },
                "mime_type": {
                    "type": "string",
                    "description": (
                        "Optional MIME type of the file. "
                        "Auto-detected from filename if not provided."
                    ),
                },
                "overwrite": {
                    "type": "boolean",
                    "description": "Whether to overwrite if file exists at path (default: false)",
                },
            },
            "required": ["filename", "content_base64"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        filename: str = kwargs.get("filename", "")
        content_b64: str = kwargs.get("content_base64", "")
        path: Optional[str] = kwargs.get("path")
        mime_type: Optional[str] = kwargs.get("mime_type")
        overwrite: bool = kwargs.get("overwrite", False)

        if not filename:
            return ErrorResponse(
                message="Please provide a filename",
                session_id=session_id,
            )

        if not content_b64:
            return ErrorResponse(
                message="Please provide content_base64",
                session_id=session_id,
            )

        # Decode content
        try:
            content = base64.b64decode(content_b64)
        except Exception:
            return ErrorResponse(
                message="Invalid base64-encoded content",
                session_id=session_id,
            )

        # Check size
        max_file_size = Config().max_file_size_mb * 1024 * 1024
        if len(content) > max_file_size:
            return ErrorResponse(
                message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
                session_id=session_id,
            )
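        # For example, with max_file_size_mb = 10 the cap works out to
        # 10 * 1024 * 1024 = 10,485,760 bytes of decoded content.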

        try:
            # Virus scan
            await scan_content_safe(content, filename=filename)

            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            file_record = await manager.write_file(
                content=content,
                filename=filename,
                path=path,
                mime_type=mime_type,
                source=WorkspaceFileSource.COPILOT,
                source_session_id=session.session_id,
                overwrite=overwrite,
            )

            return WorkspaceWriteResponse(
                file_id=file_record.id,
                name=file_record.name,
                path=file_record.path,
                size_bytes=file_record.sizeBytes,
                message=f"Successfully wrote file: {file_record.name}",
                session_id=session_id,
            )

        except ValueError as e:
            return ErrorResponse(
                message=str(e),
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error writing workspace file: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to write workspace file: {str(e)}",
                error=str(e),
                session_id=session_id,
            )


class DeleteWorkspaceFileTool(BaseTool):
    """Tool for deleting files from workspace."""

    @property
    def name(self) -> str:
        return "delete_workspace_file"

    @property
    def description(self) -> str:
        return (
            "Delete a file from the user's workspace. "
            "Specify either file_id or path to identify the file. "
            "Paths are scoped to the current session by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "file_id": {
                    "type": "string",
                    "description": "The file's unique ID (from list_workspace_files)",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "The virtual file path (e.g., '/documents/report.pdf'). "
                        "Scoped to current session by default."
                    ),
                },
            },
            "required": [],  # At least one must be provided
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        file_id: Optional[str] = kwargs.get("file_id")
        path: Optional[str] = kwargs.get("path")

        if not file_id and not path:
            return ErrorResponse(
                message="Please provide either file_id or path",
                session_id=session_id,
            )

        try:
            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            # Determine the file_id to delete
            target_file_id: str
            if file_id:
                target_file_id = file_id
            else:
                # path is guaranteed to be non-None here due to the check above
                assert path is not None
                file_info = await manager.get_file_info_by_path(path)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found at path: {path}",
                        session_id=session_id,
                    )
                target_file_id = file_info.id

            success = await manager.delete_file(target_file_id)

            if not success:
                return ErrorResponse(
                    message=f"File not found: {target_file_id}",
                    session_id=session_id,
                )

            return WorkspaceDeleteResponse(
                file_id=target_file_id,
                success=True,
                message="File deleted successfully",
                session_id=session_id,
            )

        except Exception as e:
            logger.error(f"Error deleting workspace file: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to delete workspace file: {str(e)}",
                error=str(e),
                session_id=session_id,
            )
@@ -265,13 +265,9 @@ async def get_onboarding_agents(
    "/onboarding/enabled",
    summary="Is onboarding enabled",
    tags=["onboarding", "public"],
    dependencies=[Security(requires_user)],
)
async def is_onboarding_enabled(
    user_id: Annotated[str, Security(get_user_id)],
) -> bool:
    # If chat is enabled for user, skip legacy onboarding
    if await is_feature_enabled(Flag.CHAT, user_id, False):
        return False
async def is_onboarding_enabled() -> bool:
    return await onboarding_enabled()


@@ -1 +0,0 @@
# Workspace API feature module
@@ -1,122 +0,0 @@
"""
Workspace API routes for managing user file storage.
"""

import logging
import re
from typing import Annotated
from urllib.parse import quote

import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response

from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage


def _sanitize_filename_for_header(filename: str) -> str:
    """
    Sanitize filename for Content-Disposition header to prevent header injection.

    Removes/replaces characters that could break the header or inject new headers.
    Uses RFC5987 encoding for non-ASCII characters.
    """
    # Remove CR, LF, and null bytes (header injection prevention)
    sanitized = re.sub(r"[\r\n\x00]", "", filename)
    # Escape quotes
    sanitized = sanitized.replace('"', '\\"')
    # For non-ASCII, use RFC5987 filename* parameter
    # Check if filename has non-ASCII characters
    try:
        sanitized.encode("ascii")
        return f'attachment; filename="{sanitized}"'
    except UnicodeEncodeError:
        # Use RFC5987 encoding for UTF-8 filenames
        encoded = quote(sanitized, safe="")
        return f"attachment; filename*=UTF-8''{encoded}"
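# For example (derived from the logic above):
#   _sanitize_filename_for_header("report.pdf")
#       -> 'attachment; filename="report.pdf"'
#   _sanitize_filename_for_header("résumé.pdf")
#       -> "attachment; filename*=UTF-8''r%C3%A9sum%C3%A9.pdf"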


logger = logging.getLogger(__name__)

router = fastapi.APIRouter(
    dependencies=[fastapi.Security(requires_user)],
)


def _create_streaming_response(content: bytes, file) -> Response:
    """Create a streaming response for file content."""
    return Response(
        content=content,
        media_type=file.mimeType,
        headers={
            "Content-Disposition": _sanitize_filename_for_header(file.name),
            "Content-Length": str(len(content)),
        },
    )

async def _create_file_download_response(file) -> Response:
    """
    Create a download response for a workspace file.

    Handles both local storage (direct streaming) and GCS (signed URL redirect
    with fallback to streaming).
    """
    storage = await get_workspace_storage()

    # For local storage, stream the file directly
    if file.storagePath.startswith("local://"):
        content = await storage.retrieve(file.storagePath)
        return _create_streaming_response(content, file)

    # For GCS, try to redirect to signed URL, fall back to streaming
    try:
        url = await storage.get_download_url(file.storagePath, expires_in=300)
        # If we got back an API path (fallback), stream directly instead
        if url.startswith("/api/"):
            content = await storage.retrieve(file.storagePath)
            return _create_streaming_response(content, file)
        return fastapi.responses.RedirectResponse(url=url, status_code=302)
    except Exception as e:
        # Log the signed URL failure with context
        logger.error(
            f"Failed to get signed URL for file {file.id} "
            f"(storagePath={file.storagePath}): {e}",
            exc_info=True,
        )
        # Fall back to streaming directly from GCS
        try:
            content = await storage.retrieve(file.storagePath)
            return _create_streaming_response(content, file)
        except Exception as fallback_error:
            logger.error(
                f"Fallback streaming also failed for file {file.id} "
                f"(storagePath={file.storagePath}): {fallback_error}",
                exc_info=True,
            )
            raise


@router.get(
    "/files/{file_id}/download",
    summary="Download file by ID",
)
async def download_file(
    user_id: Annotated[str, fastapi.Security(get_user_id)],
    file_id: str,
) -> Response:
    """
    Download a file by its ID.

    Returns the file content directly or redirects to a signed URL for GCS.
    """
    workspace = await get_workspace(user_id)
    if workspace is None:
        raise fastapi.HTTPException(status_code=404, detail="Workspace not found")

    file = await get_workspace_file(file_id, workspace.id)
    if file is None:
        raise fastapi.HTTPException(status_code=404, detail="File not found")

    return await _create_file_download_response(file)
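# Illustrative request, assuming the /api/workspace mount shown in the
# include_router hunk below:
#   GET /api/workspace/files/{file_id}/download
#   -> 200 with the file bytes, or 302 redirect to a GCS signed URL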
@@ -32,7 +32,6 @@ import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
import backend.api.features.v1
import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
@@ -53,7 +52,6 @@ from backend.util.exceptions import (
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
from backend.util.workspace_storage import shutdown_workspace_storage

from .external.fastapi_app import external_api
from .features.analytics import router as analytics_router
@@ -126,11 +124,6 @@ async def lifespan_context(app: fastapi.FastAPI):
    except Exception as e:
        logger.warning(f"Error shutting down cloud storage handler: {e}")

    try:
        await shutdown_workspace_storage()
    except Exception as e:
        logger.warning(f"Error shutting down workspace storage: {e}")

    await backend.data.db.disconnect()


@@ -322,11 +315,6 @@ app.include_router(
    tags=["v2", "chat"],
    prefix="/api/chat",
)
app.include_router(
    workspace_routes.router,
    tags=["v2", "workspace"],
    prefix="/api/workspace",
)
app.include_router(
    backend.api.features.oauth.router,
    tags=["oauth"],

@@ -13,7 +13,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
@@ -118,13 +117,11 @@ class AIImageCustomizerBlock(Block):
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                # Output will be a workspace ref or data URI depending on context
                ("image_url", lambda x: x.startswith(("workspace://", "data:"))),
                ("image_url", "https://replicate.delivery/generated-image.jpg"),
            ],
            test_mock={
                # Use data URI to avoid HTTP requests during tests
                "run_model": lambda *args, **kwargs: MediaFileType(
                    "data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q=="
                    "https://replicate.delivery/generated-image.jpg"
                ),
            },
            test_credentials=TEST_CREDENTIALS,
@@ -135,7 +132,8 @@ class AIImageCustomizerBlock(Block):
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        try:
@@ -143,9 +141,10 @@ class AIImageCustomizerBlock(Block):
            processed_images = await asyncio.gather(
                *(
                    store_media_file(
                        graph_exec_id=graph_exec_id,
                        file=img,
                        execution_context=execution_context,
                        return_format="for_external_api",  # Get content for Replicate API
                        user_id=user_id,
                        return_content=True,
                    )
                    for img in input_data.images
                )
@@ -159,14 +158,7 @@ class AIImageCustomizerBlock(Block):
                aspect_ratio=input_data.aspect_ratio.value,
                output_format=input_data.output_format.value,
            )

            # Store the generated image to the user's workspace for persistence
            stored_url = await store_media_file(
                file=result,
                execution_context=execution_context,
                return_format="for_block_output",
            )
            yield "image_url", stored_url
            yield "image_url", result
        except Exception as e:
            yield "error", str(e)


@@ -6,7 +6,6 @@ from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput

from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.execution import ExecutionContext
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
@@ -14,8 +13,6 @@ from backend.data.model import (
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


class ImageSize(str, Enum):
@@ -168,13 +165,11 @@ class AIImageGeneratorBlock(Block):
            test_output=[
                (
                    "image_url",
                    # Test output is a data URI since we now store images
                    lambda x: x.startswith("data:image/"),
                    "https://replicate.delivery/generated-image.webp",
                ),
            ],
            test_mock={
                # Return a data URI directly so store_media_file doesn't need to download
                "_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
                "_run_client": lambda *args, **kwargs: "https://replicate.delivery/generated-image.webp"
            },
        )

@@ -323,24 +318,11 @@ class AIImageGeneratorBlock(Block):
        style_text = style_map.get(style, "")
        return f"{style_text} of" if style_text else ""

    async def run(
        self,
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        **kwargs,
    ):
    async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
        try:
            url = await self.generate_image(input_data, credentials)
            if url:
                # Store the generated image to the user's workspace/execution folder
                stored_url = await store_media_file(
                    file=MediaFileType(url),
                    execution_context=execution_context,
                    return_format="for_block_output",
                )
                yield "image_url", stored_url
                yield "image_url", url
            else:
                yield "error", "Image generation returned an empty result."
        except Exception as e:

@@ -13,7 +13,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
@@ -22,9 +21,7 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
@@ -274,10 +271,7 @@ class AIShortformVideoCreatorBlock(Block):
                "voice": Voice.LILY,
                "video_style": VisualMediaType.STOCK_VIDEOS,
            },
            test_output=(
                "video_url",
                lambda x: x.startswith(("workspace://", "data:")),
            ),
            test_output=("video_url", "https://example.com/video.mp4"),
            test_mock={
                "create_webhook": lambda *args, **kwargs: (
                    "test_uuid",
@@ -286,21 +280,15 @@ class AIShortformVideoCreatorBlock(Block):
                "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
                "check_video_status": lambda *args, **kwargs: {
                    "status": "ready",
                    "videoUrl": "data:video/mp4;base64,AAAA",
                    "videoUrl": "https://example.com/video.mp4",
                },
                # Use data URI to avoid HTTP requests during tests
                "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
                "wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
            },
            test_credentials=TEST_CREDENTIALS,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        **kwargs,
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Create a new Webhook.site URL
        webhook_token, webhook_url = await self.create_webhook()
@@ -352,13 +340,7 @@ class AIShortformVideoCreatorBlock(Block):
        )
        video_url = await self.wait_for_video(credentials.api_key, pid)
        logger.debug(f"Video ready: {video_url}")
        # Store the generated video to the user's workspace for persistence
        stored_url = await store_media_file(
            file=MediaFileType(video_url),
            execution_context=execution_context,
            return_format="for_block_output",
        )
        yield "video_url", stored_url
        yield "video_url", video_url


class AIAdMakerVideoCreatorBlock(Block):
@@ -465,10 +447,7 @@ class AIAdMakerVideoCreatorBlock(Block):
                    "https://cdn.revid.ai/uploads/1747076315114-image.png",
                ],
            },
            test_output=(
                "video_url",
                lambda x: x.startswith(("workspace://", "data:")),
            ),
            test_output=("video_url", "https://example.com/ad.mp4"),
            test_mock={
                "create_webhook": lambda *args, **kwargs: (
                    "test_uuid",
@@ -477,21 +456,14 @@ class AIAdMakerVideoCreatorBlock(Block):
                "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
                "check_video_status": lambda *args, **kwargs: {
                    "status": "ready",
                    "videoUrl": "data:video/mp4;base64,AAAA",
                    "videoUrl": "https://example.com/ad.mp4",
                },
                "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
                "wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
            },
            test_credentials=TEST_CREDENTIALS,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        **kwargs,
    ):
    async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
        webhook_token, webhook_url = await self.create_webhook()

        payload = {
@@ -559,13 +531,7 @@ class AIAdMakerVideoCreatorBlock(Block):
            raise RuntimeError("Failed to create video: No project ID returned")

        video_url = await self.wait_for_video(credentials.api_key, pid)
        # Store the generated video to the user's workspace for persistence
        stored_url = await store_media_file(
            file=MediaFileType(video_url),
            execution_context=execution_context,
            return_format="for_block_output",
        )
        yield "video_url", stored_url
        yield "video_url", video_url


class AIScreenshotToVideoAdBlock(Block):
@@ -660,10 +626,7 @@ class AIScreenshotToVideoAdBlock(Block):
                "script": "Amazing numbers!",
                "screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
            },
            test_output=(
                "video_url",
                lambda x: x.startswith(("workspace://", "data:")),
            ),
            test_output=("video_url", "https://example.com/screenshot.mp4"),
            test_mock={
                "create_webhook": lambda *args, **kwargs: (
                    "test_uuid",
@@ -672,21 +635,14 @@ class AIScreenshotToVideoAdBlock(Block):
                "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
                "check_video_status": lambda *args, **kwargs: {
                    "status": "ready",
                    "videoUrl": "data:video/mp4;base64,AAAA",
                    "videoUrl": "https://example.com/screenshot.mp4",
                },
                "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
                "wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
            },
            test_credentials=TEST_CREDENTIALS,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        **kwargs,
    ):
    async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
        webhook_token, webhook_url = await self.create_webhook()

        payload = {
@@ -754,10 +710,4 @@ class AIScreenshotToVideoAdBlock(Block):
            raise RuntimeError("Failed to create video: No project ID returned")

        video_url = await self.wait_for_video(credentials.api_key, pid)
        # Store the generated video to the user's workspace for persistence
        stored_url = await store_media_file(
            file=MediaFileType(video_url),
            execution_context=execution_context,
            return_format="for_block_output",
        )
        yield "video_url", stored_url
        yield "video_url", video_url

@@ -6,7 +6,6 @@ if TYPE_CHECKING:

from pydantic import SecretStr

from backend.data.execution import ExecutionContext
from backend.sdk import (
    APIKeyCredentials,
    Block,
@@ -18,8 +17,6 @@ from backend.sdk import (
    Requests,
    SchemaField,
)
from backend.util.file import store_media_file
from backend.util.type import MediaFileType

from ._config import bannerbear

@@ -138,17 +135,15 @@ class BannerbearTextOverlayBlock(Block):
            },
            test_output=[
                ("success", True),
                # Output will be a workspace ref or data URI depending on context
                ("image_url", lambda x: x.startswith(("workspace://", "data:"))),
                ("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
                ("uid", "test-uid-123"),
                ("status", "completed"),
            ],
            test_mock={
                # Use data URI to avoid HTTP requests during tests
                "_make_api_request": lambda *args, **kwargs: {
                    "uid": "test-uid-123",
                    "status": "completed",
                    "image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAABAAEBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+v//Z",
                    "image_url": "https://cdn.bannerbear.com/test-image.jpg",
                }
            },
            test_credentials=TEST_CREDENTIALS,
@@ -182,12 +177,7 @@ class BannerbearTextOverlayBlock(Block):
        raise Exception(error_msg)

    async def run(
        self,
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        **kwargs,
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Build the modifications array
        modifications = []
@@ -244,18 +234,6 @@ class BannerbearTextOverlayBlock(Block):

        # Synchronous request - image should be ready
        yield "success", True

        # Store the generated image to workspace for persistence
        image_url = data.get("image_url", "")
        if image_url:
            stored_url = await store_media_file(
                file=MediaFileType(image_url),
                execution_context=execution_context,
                return_format="for_block_output",
            )
            yield "image_url", stored_url
        else:
            yield "image_url", ""

        yield "image_url", data.get("image_url", "")
        yield "uid", data.get("uid", "")
        yield "status", data.get("status", "completed")

@@ -9,7 +9,6 @@ from backend.data.block import (
    BlockSchemaOutput,
    BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.type import MediaFileType, convert
@@ -18,10 +17,10 @@ from backend.util.type import MediaFileType, convert
class FileStoreBlock(Block):
    class Input(BlockSchemaInput):
        file_in: MediaFileType = SchemaField(
            description="The file to download and store. Can be a URL (https://...), data URI, or local path."
            description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
        )
        base_64: bool = SchemaField(
            description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
            description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
            default=False,
            advanced=True,
            title="Produce Base64 Output",
@@ -29,18 +28,13 @@ class FileStoreBlock(Block):

    class Output(BlockSchemaOutput):
        file_out: MediaFileType = SchemaField(
            description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
            description="The relative path to the stored file in the temporary directory."
        )

    def __init__(self):
        super().__init__(
            id="cbb50872-625b-42f0-8203-a2ae78242d8a",
            description=(
                "Downloads and stores a file from a URL, data URI, or local path. "
                "Use this to fetch images, documents, or other files for processing. "
                "In CoPilot: saves to workspace (use list_workspace_files to see it). "
                "In graphs: outputs a data URI to pass to other blocks."
            ),
            description="Stores the input file in the temporary directory.",
            categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
            input_schema=FileStoreBlock.Input,
            output_schema=FileStoreBlock.Output,
@@ -51,18 +45,15 @@ class FileStoreBlock(Block):
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Determine return format based on user preference
        # for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
        # for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
        return_format = "for_external_api" if input_data.base_64 else "for_block_output"

        yield "file_out", await store_media_file(
            graph_exec_id=graph_exec_id,
            file=input_data.file_in,
            execution_context=execution_context,
            return_format=return_format,
            user_id=user_id,
            return_content=input_data.base_64,
        )


@@ -15,7 +15,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
@@ -667,7 +666,8 @@ class SendDiscordFileBlock(Block):
        file: MediaFileType,
        filename: str,
        message_content: str,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
    ) -> dict:
        intents = discord.Intents.default()
        intents.guilds = True
@@ -731,9 +731,10 @@ class SendDiscordFileBlock(Block):
            # Local file path - read from stored media file
            # This would be a path from a previous block's output
            stored_file = await store_media_file(
                graph_exec_id=graph_exec_id,
                file=file,
                execution_context=execution_context,
                return_format="for_external_api",  # Get content to send to Discord
                user_id=user_id,
                return_content=True,  # Get as data URI
            )
            # Now process as data URI
            header, encoded = stored_file.split(",", 1)
@@ -780,7 +781,8 @@ class SendDiscordFileBlock(Block):
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        try:
@@ -791,7 +793,8 @@ class SendDiscordFileBlock(Block):
                file=input_data.file,
                filename=input_data.filename,
                message_content=input_data.message_content,
                execution_context=execution_context,
                graph_exec_id=graph_exec_id,
                user_id=user_id,
            )

            yield "status", result.get("status", "Unknown error")

@@ -17,11 +17,8 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.request import ClientResponseError, Requests
from backend.util.type import MediaFileType

logger = logging.getLogger(__name__)

@@ -67,13 +64,9 @@ class AIVideoGeneratorBlock(Block):
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                # Output will be a workspace ref or data URI depending on context
                ("video_url", lambda x: x.startswith(("workspace://", "data:"))),
            ],
            test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
            test_mock={
                # Use data URI to avoid HTTP requests during tests
                "generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
                "generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
            },
        )

@@ -215,22 +208,11 @@ class AIVideoGeneratorBlock(Block):
            raise RuntimeError(f"API request failed: {str(e)}")

    async def run(
        self,
        input_data: Input,
        *,
        credentials: FalCredentials,
        execution_context: ExecutionContext,
        **kwargs,
        self, input_data: Input, *, credentials: FalCredentials, **kwargs
    ) -> BlockOutput:
        try:
            video_url = await self.generate_video(input_data, credentials)
            # Store the generated video to the user's workspace for persistence
            stored_url = await store_media_file(
                file=MediaFileType(video_url),
                execution_context=execution_context,
                return_format="for_block_output",
            )
            yield "video_url", stored_url
            yield "video_url", video_url
        except Exception as e:
            error_message = str(e)
            yield "error", error_message

@@ -12,7 +12,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
@@ -122,12 +121,10 @@ class AIImageEditorBlock(Block):
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                # Output will be a workspace ref or data URI depending on context
                ("output_image", lambda x: x.startswith(("workspace://", "data:"))),
                ("output_image", "https://replicate.com/output/edited-image.png"),
            ],
            test_mock={
                # Use data URI to avoid HTTP requests during tests
                "run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
                "run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
            },
            test_credentials=TEST_CREDENTIALS,
        )
@@ -137,7 +134,8 @@ class AIImageEditorBlock(Block):
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        result = await self.run_model(
@@ -146,25 +144,20 @@ class AIImageEditorBlock(Block):
            prompt=input_data.prompt,
            input_image_b64=(
                await store_media_file(
                    graph_exec_id=graph_exec_id,
                    file=input_data.input_image,
                    execution_context=execution_context,
                    return_format="for_external_api",  # Get content for Replicate API
                    user_id=user_id,
                    return_content=True,
                )
                if input_data.input_image
                else None
            ),
            aspect_ratio=input_data.aspect_ratio.value,
            seed=input_data.seed,
            user_id=execution_context.user_id or "",
            graph_exec_id=execution_context.graph_exec_id or "",
            user_id=user_id,
            graph_exec_id=graph_exec_id,
        )
        # Store the generated image to the user's workspace for persistence
        stored_url = await store_media_file(
            file=result,
            execution_context=execution_context,
            return_format="for_block_output",
        )
        yield "output_image", stored_url
        yield "output_image", result

    async def run_model(
        self,

@@ -21,7 +21,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
from backend.util.settings import Settings
@@ -96,7 +95,8 @@ def _make_mime_text(

async def create_mime_message(
    input_data,
    execution_context: ExecutionContext,
    graph_exec_id: str,
    user_id: str,
) -> str:
    """Create a MIME message with attachments and return base64-encoded raw message."""

@@ -117,12 +117,12 @@ async def create_mime_message(
    if input_data.attachments:
        for attach in input_data.attachments:
            local_path = await store_media_file(
                user_id=user_id,
                graph_exec_id=graph_exec_id,
                file=attach,
                execution_context=execution_context,
                return_format="for_local_processing",
                return_content=False,
            )
            assert execution_context.graph_exec_id  # Validated by store_media_file
            abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
            abs_path = get_exec_file_path(graph_exec_id, local_path)
            part = MIMEBase("application", "octet-stream")
            with open(abs_path, "rb") as f:
                part.set_payload(f.read())
@@ -582,25 +582,27 @@ class GmailSendBlock(GmailBase):
        input_data: Input,
        *,
        credentials: GoogleCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        service = self._build_service(credentials, **kwargs)
        result = await self._send_email(
            service,
            input_data,
            execution_context,
            graph_exec_id,
            user_id,
        )
        yield "result", result

    async def _send_email(
        self, service, input_data: Input, execution_context: ExecutionContext
        self, service, input_data: Input, graph_exec_id: str, user_id: str
    ) -> dict:
        if not input_data.to or not input_data.subject or not input_data.body:
            raise ValueError(
                "At least one recipient, subject, and body are required for sending an email"
            )
        raw_message = await create_mime_message(input_data, execution_context)
        raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
        sent_message = await asyncio.to_thread(
            lambda: service.users()
            .messages()
@@ -690,28 +692,30 @@ class GmailCreateDraftBlock(GmailBase):
        input_data: Input,
        *,
        credentials: GoogleCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        service = self._build_service(credentials, **kwargs)
        result = await self._create_draft(
            service,
            input_data,
            execution_context,
            graph_exec_id,
            user_id,
        )
        yield "result", GmailDraftResult(
            id=result["id"], message_id=result["message"]["id"], status="draft_created"
        )

    async def _create_draft(
        self, service, input_data: Input, execution_context: ExecutionContext
        self, service, input_data: Input, graph_exec_id: str, user_id: str
    ) -> dict:
        if not input_data.to or not input_data.subject:
            raise ValueError(
                "At least one recipient and subject are required for creating a draft"
            )

        raw_message = await create_mime_message(input_data, execution_context)
        raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
        draft = await asyncio.to_thread(
            lambda: service.users()
            .drafts()
@@ -1096,7 +1100,7 @@ class GmailGetThreadBlock(GmailBase):


async def _build_reply_message(
    service, input_data, execution_context: ExecutionContext
    service, input_data, graph_exec_id: str, user_id: str
) -> tuple[str, str]:
    """
    Builds a reply MIME message for Gmail threads.
@@ -1186,12 +1190,12 @@ async def _build_reply_message(
    # Handle attachments
    for attach in input_data.attachments:
        local_path = await store_media_file(
            user_id=user_id,
            graph_exec_id=graph_exec_id,
            file=attach,
            execution_context=execution_context,
            return_format="for_local_processing",
            return_content=False,
        )
        assert execution_context.graph_exec_id  # Validated by store_media_file
        abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
        abs_path = get_exec_file_path(graph_exec_id, local_path)
        part = MIMEBase("application", "octet-stream")
        with open(abs_path, "rb") as f:
            part.set_payload(f.read())
@@ -1307,14 +1311,16 @@ class GmailReplyBlock(GmailBase):
        input_data: Input,
        *,
        credentials: GoogleCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        service = self._build_service(credentials, **kwargs)
        message = await self._reply(
            service,
            input_data,
            execution_context,
            graph_exec_id,
            user_id,
        )
        yield "messageId", message["id"]
        yield "threadId", message.get("threadId", input_data.threadId)
@@ -1337,11 +1343,11 @@ class GmailReplyBlock(GmailBase):
        yield "email", email

    async def _reply(
        self, service, input_data: Input, execution_context: ExecutionContext
        self, service, input_data: Input, graph_exec_id: str, user_id: str
    ) -> dict:
        # Build the reply message using the shared helper
        raw, thread_id = await _build_reply_message(
            service, input_data, execution_context
            service, input_data, graph_exec_id, user_id
        )

        # Send the message
@@ -1435,14 +1441,16 @@ class GmailDraftReplyBlock(GmailBase):
        input_data: Input,
        *,
        credentials: GoogleCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        service = self._build_service(credentials, **kwargs)
        draft = await self._create_draft_reply(
            service,
            input_data,
            execution_context,
            graph_exec_id,
            user_id,
        )
        yield "draftId", draft["id"]
        yield "messageId", draft["message"]["id"]
@@ -1450,11 +1458,11 @@ class GmailDraftReplyBlock(GmailBase):
        yield "status", "draft_created"

    async def _create_draft_reply(
        self, service, input_data: Input, execution_context: ExecutionContext
        self, service, input_data: Input, graph_exec_id: str, user_id: str
    ) -> dict:
        # Build the reply message using the shared helper
        raw, thread_id = await _build_reply_message(
            service, input_data, execution_context
            service, input_data, graph_exec_id, user_id
        )

        # Create draft with proper thread association
@@ -1621,21 +1629,23 @@ class GmailForwardBlock(GmailBase):
        input_data: Input,
        *,
        credentials: GoogleCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        service = self._build_service(credentials, **kwargs)
        result = await self._forward_message(
            service,
            input_data,
            execution_context,
            graph_exec_id,
            user_id,
        )
        yield "messageId", result["id"]
        yield "threadId", result.get("threadId", "")
        yield "status", "forwarded"

    async def _forward_message(
        self, service, input_data: Input, execution_context: ExecutionContext
        self, service, input_data: Input, graph_exec_id: str, user_id: str
    ) -> dict:
        if not input_data.to:
            raise ValueError("At least one recipient is required for forwarding")
@@ -1717,12 +1727,12 @@ To: {original_to}
        # Add any additional attachments
        for attach in input_data.additionalAttachments:
            local_path = await store_media_file(
                user_id=user_id,
                graph_exec_id=graph_exec_id,
                file=attach,
                execution_context=execution_context,
                return_format="for_local_processing",
                return_content=False,
            )
            assert execution_context.graph_exec_id  # Validated by store_media_file
            abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
            abs_path = get_exec_file_path(graph_exec_id, local_path)
            part = MIMEBase("application", "octet-stream")
            with open(abs_path, "rb") as f:
                part.set_payload(f.read())

@@ -15,7 +15,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
    CredentialsField,
    CredentialsMetaInput,
@@ -117,9 +116,10 @@ class SendWebRequestBlock(Block):

    @staticmethod
    async def _prepare_files(
        execution_context: ExecutionContext,
        graph_exec_id: str,
        files_name: str,
        files: list[MediaFileType],
        user_id: str,
    ) -> list[tuple[str, tuple[str, BytesIO, str]]]:
        """
        Prepare files for the request by storing them and reading their content.
@@ -127,16 +127,11 @@ class SendWebRequestBlock(Block):
            (files_name, (filename, BytesIO, mime_type))
        """
        files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
        graph_exec_id = execution_context.graph_exec_id
        if graph_exec_id is None:
            raise ValueError("graph_exec_id is required for file operations")

        for media in files:
            # Normalise to a list so we can repeat the same key
            rel_path = await store_media_file(
                file=media,
                execution_context=execution_context,
                return_format="for_local_processing",
                graph_exec_id, media, user_id, return_content=False
            )
            abs_path = get_exec_file_path(graph_exec_id, rel_path)
            async with aiofiles.open(abs_path, "rb") as f:
@@ -148,7 +143,7 @@ class SendWebRequestBlock(Block):
        return files_payload

    async def run(
        self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
        self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
    ) -> BlockOutput:
        # ─── Parse/normalise body ────────────────────────────────────
        body = input_data.body
@@ -179,7 +174,7 @@ class SendWebRequestBlock(Block):
        files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
        if use_files:
            files_payload = await self._prepare_files(
                execution_context, input_data.files_name, input_data.files
                graph_exec_id, input_data.files_name, input_data.files, user_id
            )

        # Enforce body format rules
@@ -243,8 +238,9 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        credentials: HostScopedCredentials,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Create SendWebRequestBlock.Input from our input (removing credentials field)
@@ -275,6 +271,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):

        # Use parent class run method
        async for output_name, output_data in super().run(
            base_input, execution_context=execution_context, **kwargs
            base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
        ):
            yield output_name, output_data

@@ -12,7 +12,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.mock import MockObject
@@ -463,23 +462,18 @@ class AgentFileInputBlock(AgentInputBlock):
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        if not input_data.value:
            return

        # Determine return format based on user preference
        # for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
        # for_local_processing: returns local file path
        return_format = (
            "for_external_api" if input_data.base_64 else "for_local_processing"
        )

        yield "result", await store_media_file(
            graph_exec_id=graph_exec_id,
            file=input_data.value,
            execution_context=execution_context,
            return_format=return_format,
            user_id=user_id,
            return_content=input_data.base_64,
        )


@@ -1,6 +1,6 @@
import os
import tempfile
from typing import Optional
from typing import Literal, Optional

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop

@@ -13,7 +13,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file

@@ -47,19 +46,18 @@ class MediaDurationBlock(Block):
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        # 1) Store the input media locally
        local_media_path = await store_media_file(
            graph_exec_id=graph_exec_id,
            file=input_data.media_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        assert execution_context.graph_exec_id is not None
        media_abspath = get_exec_file_path(
            execution_context.graph_exec_id, local_media_path
            user_id=user_id,
            return_content=False,
        )
        media_abspath = get_exec_file_path(graph_exec_id, local_media_path)

        # 2) Load the clip
        if input_data.is_video:

@@ -90,6 +88,10 @@ class LoopVideoBlock(Block):
            default=None,
            ge=1,
        )
        output_return_type: Literal["file_path", "data_uri"] = SchemaField(
            description="How to return the output video. Either a relative path or base64 data URI.",
            default="file_path",
        )

    class Output(BlockSchemaOutput):
        video_out: str = SchemaField(

@@ -109,19 +111,17 @@ class LoopVideoBlock(Block):
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the input video locally
        local_video_path = await store_media_file(
            graph_exec_id=graph_exec_id,
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
            user_id=user_id,
            return_content=False,
        )
        input_abspath = get_exec_file_path(graph_exec_id, local_video_path)

@@ -149,11 +149,12 @@ class LoopVideoBlock(Block):
        looped_clip = looped_clip.with_audio(clip.audio)
        looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

        # Return output - for_block_output returns workspace:// if available, else data URI
        # Return as data URI
        video_out = await store_media_file(
            graph_exec_id=graph_exec_id,
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
            user_id=user_id,
            return_content=input_data.output_return_type == "data_uri",
        )

        yield "video_out", video_out

@@ -176,6 +177,10 @@ class AddAudioToVideoBlock(Block):
            description="Volume scale for the newly attached audio track (1.0 = original).",
            default=1.0,
        )
        output_return_type: Literal["file_path", "data_uri"] = SchemaField(
            description="Return the final output as a relative path or base64 data URI.",
            default="file_path",
        )

    class Output(BlockSchemaOutput):
        video_out: MediaFileType = SchemaField(

@@ -195,24 +200,23 @@ class AddAudioToVideoBlock(Block):
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        assert execution_context.graph_exec_id is not None
        assert execution_context.node_exec_id is not None
        graph_exec_id = execution_context.graph_exec_id
        node_exec_id = execution_context.node_exec_id

        # 1) Store the inputs locally
        local_video_path = await store_media_file(
            graph_exec_id=graph_exec_id,
            file=input_data.video_in,
            execution_context=execution_context,
            return_format="for_local_processing",
            user_id=user_id,
            return_content=False,
        )
        local_audio_path = await store_media_file(
            graph_exec_id=graph_exec_id,
            file=input_data.audio_in,
            execution_context=execution_context,
            return_format="for_local_processing",
            user_id=user_id,
            return_content=False,
        )

        abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)

@@ -236,11 +240,12 @@ class AddAudioToVideoBlock(Block):
        output_abspath = os.path.join(abs_temp_dir, output_filename)
        final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

        # 5) Return output - for_block_output returns workspace:// if available, else data URI
        # 5) Return either path or data URI
        video_out = await store_media_file(
            graph_exec_id=graph_exec_id,
            file=output_filename,
            execution_context=execution_context,
            return_format="for_block_output",
            user_id=user_id,
            return_content=input_data.output_return_type == "data_uri",
        )

        yield "video_out", video_out
@@ -11,7 +11,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,

@@ -113,7 +112,8 @@ class ScreenshotWebPageBlock(Block):
    @staticmethod
    async def take_screenshot(
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        url: str,
        viewport_width: int,
        viewport_height: int,

@@ -155,11 +155,12 @@ class ScreenshotWebPageBlock(Block):

        return {
            "image": await store_media_file(
                graph_exec_id=graph_exec_id,
                file=MediaFileType(
                    f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
                ),
                execution_context=execution_context,
                return_format="for_block_output",
                user_id=user_id,
                return_content=True,
            )
        }

@@ -168,13 +169,15 @@ class ScreenshotWebPageBlock(Block):
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        try:
            screenshot_data = await self.take_screenshot(
                credentials=credentials,
                execution_context=execution_context,
                graph_exec_id=graph_exec_id,
                user_id=user_id,
                url=input_data.url,
                viewport_width=input_data.viewport_width,
                viewport_height=input_data.viewport_height,
@@ -7,7 +7,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, SchemaField
from backend.util.file import get_exec_file_path, store_media_file
from backend.util.type import MediaFileType

@@ -99,7 +98,7 @@ class ReadSpreadsheetBlock(Block):
    )

    async def run(
        self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
        self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
    ) -> BlockOutput:
        import csv
        from io import StringIO

@@ -107,16 +106,14 @@ class ReadSpreadsheetBlock(Block):
        # Determine data source - prefer file_input if provided, otherwise use contents
        if input_data.file_input:
            stored_file_path = await store_media_file(
                user_id=user_id,
                graph_exec_id=graph_exec_id,
                file=input_data.file_input,
                execution_context=execution_context,
                return_format="for_local_processing",
                return_content=False,
            )

            # Get full file path
            assert execution_context.graph_exec_id  # Validated by store_media_file
            file_path = get_exec_file_path(
                execution_context.graph_exec_id, stored_file_path
            )
            file_path = get_exec_file_path(graph_exec_id, stored_file_path)
            if not Path(file_path).exists():
                raise ValueError(f"File does not exist: {file_path}")
@@ -10,7 +10,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,

@@ -18,9 +17,7 @@ from backend.data.model import (
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",

@@ -105,7 +102,7 @@ class CreateTalkingAvatarVideoBlock(Block):
            test_output=[
                (
                    "video_url",
                    lambda x: x.startswith(("workspace://", "data:")),
                    "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
                ),
            ],
            test_mock={

@@ -113,10 +110,9 @@ class CreateTalkingAvatarVideoBlock(Block):
                    "id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
                    "status": "created",
                },
                # Use data URI to avoid HTTP requests during tests
                "get_clip_status": lambda *args, **kwargs: {
                    "status": "done",
                    "result_url": "data:video/mp4;base64,AAAA",
                    "result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
                },
            },
            test_credentials=TEST_CREDENTIALS,

@@ -142,12 +138,7 @@ class CreateTalkingAvatarVideoBlock(Block):
        return response.json()

    async def run(
        self,
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        execution_context: ExecutionContext,
        **kwargs,
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Create the clip
        payload = {

@@ -174,14 +165,7 @@ class CreateTalkingAvatarVideoBlock(Block):
        for _ in range(input_data.max_polling_attempts):
            status_response = await self.get_clip_status(credentials.api_key, clip_id)
            if status_response["status"] == "done":
                # Store the generated video to the user's workspace for persistence
                video_url = status_response["result_url"]
                stored_url = await store_media_file(
                    file=MediaFileType(video_url),
                    execution_context=execution_context,
                    return_format="for_block_output",
                )
                yield "video_url", stored_url
                yield "video_url", status_response["result_url"]
                return
            elif status_response["status"] == "error":
                raise RuntimeError(
@@ -12,7 +12,6 @@ from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType

@@ -234,12 +233,9 @@ class TestStoreMediaFileSecurity:

        with pytest.raises(ValueError, match="File too large"):
            await store_media_file(
                graph_exec_id="test",
                file=MediaFileType(large_data_uri),
                execution_context=ExecutionContext(
                    user_id="test_user",
                    graph_exec_id="test",
                ),
                return_format="for_local_processing",
                user_id="test_user",
            )

    @patch("backend.util.file.Path")

@@ -274,12 +270,9 @@ class TestStoreMediaFileSecurity:
        # Should raise an error when directory size exceeds limit
        with pytest.raises(ValueError, match="Disk usage limit exceeded"):
            await store_media_file(
                graph_exec_id="test",
                file=MediaFileType(
                    "data:text/plain;base64,dGVzdA=="
                ),  # Small test file
                execution_context=ExecutionContext(
                    user_id="test_user",
                    graph_exec_id="test",
                ),
                return_format="for_local_processing",
                user_id="test_user",
            )
@@ -11,22 +11,10 @@ from backend.blocks.http import (
    HttpMethod,
    SendAuthenticatedWebRequestBlock,
)
from backend.data.execution import ExecutionContext
from backend.data.model import HostScopedCredentials
from backend.util.request import Response


def make_test_context(
    graph_exec_id: str = "test-exec-id",
    user_id: str = "test-user-id",
) -> ExecutionContext:
    """Helper to create test ExecutionContext."""
    return ExecutionContext(
        user_id=user_id,
        graph_exec_id=graph_exec_id,
    )


class TestHttpBlockWithHostScopedCredentials:
    """Test suite for HTTP block integration with HostScopedCredentials."""

@@ -117,7 +105,8 @@ class TestHttpBlockWithHostScopedCredentials:
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=exact_match_credentials,
            execution_context=make_test_context(),
            graph_exec_id="test-exec-id",
            user_id="test-user-id",
        ):
            result.append((output_name, output_data))

@@ -172,7 +161,8 @@ class TestHttpBlockWithHostScopedCredentials:
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=wildcard_credentials,
            execution_context=make_test_context(),
            graph_exec_id="test-exec-id",
            user_id="test-user-id",
        ):
            result.append((output_name, output_data))

@@ -218,7 +208,8 @@ class TestHttpBlockWithHostScopedCredentials:
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=non_matching_credentials,
            execution_context=make_test_context(),
            graph_exec_id="test-exec-id",
            user_id="test-user-id",
        ):
            result.append((output_name, output_data))

@@ -267,7 +258,8 @@ class TestHttpBlockWithHostScopedCredentials:
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=exact_match_credentials,
            execution_context=make_test_context(),
            graph_exec_id="test-exec-id",
            user_id="test-user-id",
        ):
            result.append((output_name, output_data))

@@ -326,7 +318,8 @@ class TestHttpBlockWithHostScopedCredentials:
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=auto_discovered_creds,  # Execution manager found these
            execution_context=make_test_context(),
            graph_exec_id="test-exec-id",
            user_id="test-user-id",
        ):
            result.append((output_name, output_data))

@@ -389,7 +382,8 @@ class TestHttpBlockWithHostScopedCredentials:
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=multi_header_creds,
            execution_context=make_test_context(),
            graph_exec_id="test-exec-id",
            user_id="test-user-id",
        ):
            result.append((output_name, output_data))

@@ -477,7 +471,8 @@ class TestHttpBlockWithHostScopedCredentials:
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=test_creds,
            execution_context=make_test_context(),
            graph_exec_id="test-exec-id",
            user_id="test-user-id",
        ):
            result.append((output_name, output_data))
@@ -11,7 +11,6 @@ from backend.data.block import (
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util import json, text
from backend.util.file import get_exec_file_path, store_media_file

@@ -445,21 +444,18 @@ class FileReadBlock(Block):
    )

    async def run(
        self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
        self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
    ) -> BlockOutput:
        # Store the media file properly (handles URLs, data URIs, etc.)
        stored_file_path = await store_media_file(
            user_id=user_id,
            graph_exec_id=graph_exec_id,
            file=input_data.file_input,
            execution_context=execution_context,
            return_format="for_local_processing",
            return_content=False,
        )

        # Get full file path (graph_exec_id validated by store_media_file above)
        if not execution_context.graph_exec_id:
            raise ValueError("execution_context.graph_exec_id is required")
        file_path = get_exec_file_path(
            execution_context.graph_exec_id, stored_file_path
        )
        # Get full file path
        file_path = get_exec_file_path(graph_exec_id, stored_file_path)

        if not Path(file_path).exists():
            raise ValueError(f"File does not exist: {file_path}")
@@ -26,6 +26,31 @@ def add_param(url: str, key: str, value: str) -> str:

DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://localhost:5432")

# Extract the application schema from DATABASE_URL for use in queries
_parsed = urlparse(DATABASE_URL)
_query_params = dict(parse_qsl(_parsed.query))
_app_schema = _query_params.get("schema", "public")

# Build search_path that includes app schema and extension schemas where pgvector may live.
# This is used both in connection options (may be ignored by PgBouncer) and in SET LOCAL
# statements before raw queries (guaranteed to work).
SEARCH_PATH = (
    f"{_app_schema},extensions,public"
    if _app_schema != "public"
    else "public,extensions"
)

# Try to set search_path via PostgreSQL options parameter at connection time.
# NOTE: This may be ignored by PgBouncer in transaction pooling mode.
# As a fallback, we also SET LOCAL search_path before raw queries.
if "options" in _query_params:
    _query_params["options"] = (
        _query_params["options"] + f" -c search_path={SEARCH_PATH}"
    )
else:
    _query_params["options"] = f"-c search_path={SEARCH_PATH}"
DATABASE_URL = urlunparse(_parsed._replace(query=urlencode(_query_params)))

CONN_LIMIT = os.getenv("DB_CONNECTION_LIMIT")
if CONN_LIMIT:
    DATABASE_URL = add_param(DATABASE_URL, "connection_limit", CONN_LIMIT)

@@ -108,6 +133,70 @@ def get_database_schema() -> str:
    return query_params.get("schema", "public")


def get_pod_info() -> dict:
    """Get information about the current pod/host.

    Returns dict with: hostname, pod_name (from HOSTNAME env var in k8s),
    pod_namespace, pod_ip if available.
    """
    import socket

    return {
        "hostname": socket.gethostname(),
        "pod_name": os.getenv("HOSTNAME", "unknown"),
        "pod_namespace": os.getenv("POD_NAMESPACE", "unknown"),
        "pod_ip": os.getenv("POD_IP", "unknown"),
    }


async def get_connection_debug_info(tx=None) -> dict:
    """Get diagnostic info about the current database connection and pod.

    Useful for debugging "table does not exist" or "type does not exist" errors
    that may indicate connections going to different database instances or pods.

    Args:
        tx: Optional transaction client to use for the query (ensures same connection)

    Returns dict with: search_path, current_schema, server_version, pg_backend_pid,
    pgvector_installed, pgvector_schema, plus pod info
    """
    import prisma as prisma_module

    pod_info = get_pod_info()
    db_client = tx if tx else prisma_module.get_client()

    try:
        # Get connection info and check for pgvector in a single query
        result = await db_client.query_raw(
            """
            SELECT
                current_setting('search_path') as search_path,
                current_schema() as current_schema,
                current_database() as current_database,
                inet_server_addr() as server_addr,
                inet_server_port() as server_port,
                pg_backend_pid() as backend_pid,
                version() as server_version,
                (SELECT EXISTS(
                    SELECT 1 FROM pg_extension WHERE extname = 'vector'
                )) as pgvector_installed,
                (SELECT nspname FROM pg_extension e
                 JOIN pg_namespace n ON e.extnamespace = n.oid
                 WHERE e.extname = 'vector'
                 LIMIT 1) as pgvector_schema,
                (SELECT string_agg(extname || ' in ' || nspname, ', ')
                 FROM pg_extension e
                 JOIN pg_namespace n ON e.extnamespace = n.oid
                ) as all_extensions
            """
        )
        db_info = result[0] if result else {}
        return {**pod_info, **db_info}
    except Exception as e:
        return {**pod_info, "db_error": str(e)}


async def _raw_with_schema(
    query_template: str,
    *args,

@@ -124,8 +213,9 @@ async def _raw_with_schema(

    Note on pgvector types:
        Use unqualified ::vector and <=> operator in queries. PostgreSQL resolves
        these via search_path, which includes the schema where pgvector is installed
        on all environments (local, CI, dev).
        these via search_path. The connection's search_path is configured at module
        load to include common extension schemas (public, extensions) where pgvector
        may be installed across different environments (local, CI, Supabase).

    Args:
        query_template: SQL query with {schema_prefix} and/or {schema} placeholders

@@ -155,12 +245,60 @@ async def _raw_with_schema(

    db_client = client if client else prisma_module.get_client()

    if execute:
        result = await db_client.execute_raw(formatted_query, *args)  # type: ignore
    else:
        result = await db_client.query_raw(formatted_query, *args)  # type: ignore
    # For queries that might use pgvector types (::vector or <=> operator),
    # we need to ensure search_path includes the schema where pgvector is installed.
    # PgBouncer in transaction mode may ignore connection-level options, so we
    # use SET LOCAL within a transaction to guarantee correct search_path.
    needs_vector_search_path = "::vector" in formatted_query or "<=>" in formatted_query

    return result
    try:
        if needs_vector_search_path and client is None:
            # Use transaction to set search_path for vector queries
            async with db_client.tx() as tx:
                # Log debug info BEFORE the query to capture which backend we're hitting
                debug_info = await get_connection_debug_info(tx)
                logger.info(
                    f"Vector query starting. backend_pid={debug_info.get('backend_pid')}, "
                    f"server_addr={debug_info.get('server_addr')}, "
                    f"pgvector_installed={debug_info.get('pgvector_installed')}, "
                    f"pgvector_schema={debug_info.get('pgvector_schema')}, "
                    f"search_path={debug_info.get('search_path')}, "
                    f"pod={debug_info.get('pod_name')}"
                )

                await tx.execute_raw(f"SET LOCAL search_path TO {SEARCH_PATH}")
                if execute:
                    result = await tx.execute_raw(formatted_query, *args)  # type: ignore
                else:
                    result = await tx.query_raw(formatted_query, *args)  # type: ignore

                logger.info(
                    f"Vector query SUCCESS. backend_pid={debug_info.get('backend_pid')}"
                )
        else:
            # Regular query without vector types, or already in a transaction
            if execute:
                result = await db_client.execute_raw(formatted_query, *args)  # type: ignore
            else:
                result = await db_client.query_raw(formatted_query, *args)  # type: ignore
        return result
    except Exception as e:
        error_msg = str(e)
        # Log connection debug info for "does not exist" errors to help diagnose
        # whether connections are going to different database instances
        if "does not exist" in error_msg:
            try:
                debug_info = await get_connection_debug_info()
                logger.error(
                    f"Vector query FAILED. Connection debug info: {debug_info}. "
                    f"Query template: {query_template[:200]}... Error: {error_msg}"
                )
            except Exception:
                logger.error(
                    f"Vector query FAILED (debug info unavailable). "
                    f"Query template: {query_template[:200]}... Error: {error_msg}"
                )
        raise


async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
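The `SET LOCAL` fallback in the hunk above works because the setting only lives for the duration of a single transaction, which is exactly the unit PgBouncer's transaction pooling preserves. A minimal standalone sketch of the same pattern (asyncpg and the `items` table here are illustrative assumptions, not part of this diff):

```python
# Sketch: SET LOCAL search_path inside a transaction survives PgBouncer
# transaction pooling, where connection-level "options" may be dropped.
import asyncpg


async def vector_query(dsn: str, embedding: str) -> list:
    conn = await asyncpg.connect(dsn)
    try:
        async with conn.transaction():
            # Scoped to this transaction only; resets automatically on commit.
            await conn.execute("SET LOCAL search_path TO public, extensions")
            # "items" is a hypothetical table with a pgvector "embedding" column.
            return await conn.fetch(
                "SELECT id FROM items ORDER BY embedding <=> $1::vector LIMIT 5",
                embedding,
            )
    finally:
        await conn.close()
```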
@@ -83,29 +83,12 @@ class ExecutionContext(BaseModel):

    model_config = {"extra": "ignore"}

    # Execution identity
    user_id: Optional[str] = None
    graph_id: Optional[str] = None
    graph_exec_id: Optional[str] = None
    graph_version: Optional[int] = None
    node_id: Optional[str] = None
    node_exec_id: Optional[str] = None

    # Safety settings
    human_in_the_loop_safe_mode: bool = True
    sensitive_action_safe_mode: bool = False

    # User settings
    user_timezone: str = "UTC"

    # Execution hierarchy
    root_execution_id: Optional[str] = None
    parent_execution_id: Optional[str] = None

    # Workspace
    workspace_id: Optional[str] = None
    session_id: Optional[str] = None


# -------------------------- Models -------------------------- #
@@ -41,7 +41,6 @@ FrontendOnboardingStep = Literal[
    OnboardingStep.AGENT_NEW_RUN,
    OnboardingStep.AGENT_INPUT,
    OnboardingStep.CONGRATS,
    OnboardingStep.VISIT_COPILOT,
    OnboardingStep.MARKETPLACE_VISIT,
    OnboardingStep.BUILDER_OPEN,
]

@@ -123,9 +122,6 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
    reward = 0
    match step:
        # Welcome bonus for visiting copilot ($5 = 500 credits)
        case OnboardingStep.VISIT_COPILOT:
            reward = 500
        # Reward user when they clicked New Run during onboarding
        # This is because they need credits before scheduling a run (next step)
        # This is seen as a reward for the GET_RESULTS step in the wallet
@@ -216,7 +216,27 @@ async def get_business_understanding(

    # Cache miss - load from database
    logger.debug(f"Business understanding cache miss for user {user_id}")
    record = await CoPilotUnderstanding.prisma().find_unique(where={"userId": user_id})
    try:
        record = await CoPilotUnderstanding.prisma().find_unique(where={"userId": user_id})
    except Exception as e:
        error_msg = str(e)
        if "does not exist" in error_msg:
            # Log connection debug info to diagnose if connections go to different DBs
            from backend.data.db import get_connection_debug_info

            try:
                debug_info = await get_connection_debug_info()
                logger.error(
                    f"CoPilotUnderstanding table not found. Connection debug: {debug_info}. "
                    f"Error: {error_msg}"
                )
            except Exception:
                logger.error(
                    f"CoPilotUnderstanding table not found (debug unavailable). "
                    f"Error: {error_msg}"
                )
        raise

    if record is None:
        return None
@@ -1,285 +0,0 @@
"""
Database CRUD operations for User Workspace.

This module provides functions for managing user workspaces and workspace files.
"""

import logging
from datetime import datetime, timezone
from typing import Optional

from prisma.enums import WorkspaceFileSource
from prisma.models import UserWorkspace, UserWorkspaceFile

from backend.util.json import SafeJson

logger = logging.getLogger(__name__)


async def get_or_create_workspace(user_id: str) -> UserWorkspace:
    """
    Get user's workspace, creating one if it doesn't exist.

    Uses upsert to handle race conditions when multiple concurrent requests
    attempt to create a workspace for the same user.

    Args:
        user_id: The user's ID

    Returns:
        UserWorkspace instance
    """
    workspace = await UserWorkspace.prisma().upsert(
        where={"userId": user_id},
        data={
            "create": {"userId": user_id},
            "update": {},  # No updates needed if exists
        },
    )

    return workspace


async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
    """
    Get user's workspace if it exists.

    Args:
        user_id: The user's ID

    Returns:
        UserWorkspace instance or None
    """
    return await UserWorkspace.prisma().find_unique(where={"userId": user_id})


async def create_workspace_file(
    workspace_id: str,
    file_id: str,
    name: str,
    path: str,
    storage_path: str,
    mime_type: str,
    size_bytes: int,
    checksum: Optional[str] = None,
    source: WorkspaceFileSource = WorkspaceFileSource.UPLOAD,
    source_exec_id: Optional[str] = None,
    source_session_id: Optional[str] = None,
    metadata: Optional[dict] = None,
) -> UserWorkspaceFile:
    """
    Create a new workspace file record.

    Args:
        workspace_id: The workspace ID
        file_id: The file ID (same as used in storage path for consistency)
        name: User-visible filename
        path: Virtual path (e.g., "/documents/report.pdf")
        storage_path: Actual storage path (GCS or local)
        mime_type: MIME type of the file
        size_bytes: File size in bytes
        checksum: Optional SHA256 checksum
        source: How the file was created
        source_exec_id: Graph execution ID if from execution
        source_session_id: Chat session ID if from CoPilot
        metadata: Optional additional metadata

    Returns:
        Created UserWorkspaceFile instance
    """
    # Normalize path to start with /
    if not path.startswith("/"):
        path = f"/{path}"

    file = await UserWorkspaceFile.prisma().create(
        data={
            "id": file_id,
            "workspaceId": workspace_id,
            "name": name,
            "path": path,
            "storagePath": storage_path,
            "mimeType": mime_type,
            "sizeBytes": size_bytes,
            "checksum": checksum,
            "source": source,
            "sourceExecId": source_exec_id,
            "sourceSessionId": source_session_id,
            "metadata": SafeJson(metadata or {}),
        }
    )

    logger.info(
        f"Created workspace file {file.id} at path {path} "
        f"in workspace {workspace_id}"
    )
    return file


async def get_workspace_file(
    file_id: str,
    workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
    """
    Get a workspace file by ID.

    Args:
        file_id: The file ID
        workspace_id: Optional workspace ID for validation

    Returns:
        UserWorkspaceFile instance or None
    """
    where_clause: dict = {"id": file_id, "isDeleted": False}
    if workspace_id:
        where_clause["workspaceId"] = workspace_id

    return await UserWorkspaceFile.prisma().find_first(where=where_clause)


async def get_workspace_file_by_path(
    workspace_id: str,
    path: str,
) -> Optional[UserWorkspaceFile]:
    """
    Get a workspace file by its virtual path.

    Args:
        workspace_id: The workspace ID
        path: Virtual path

    Returns:
        UserWorkspaceFile instance or None
    """
    # Normalize path
    if not path.startswith("/"):
        path = f"/{path}"

    return await UserWorkspaceFile.prisma().find_first(
        where={
            "workspaceId": workspace_id,
            "path": path,
            "isDeleted": False,
        }
    )


async def list_workspace_files(
    workspace_id: str,
    path_prefix: Optional[str] = None,
    include_deleted: bool = False,
    limit: Optional[int] = None,
    offset: int = 0,
) -> list[UserWorkspaceFile]:
    """
    List files in a workspace.

    Args:
        workspace_id: The workspace ID
        path_prefix: Optional path prefix to filter (e.g., "/documents/")
        include_deleted: Whether to include soft-deleted files
        limit: Maximum number of files to return
        offset: Number of files to skip

    Returns:
        List of UserWorkspaceFile instances
    """
    where_clause: dict = {"workspaceId": workspace_id}

    if not include_deleted:
        where_clause["isDeleted"] = False

    if path_prefix:
        # Normalize prefix
        if not path_prefix.startswith("/"):
            path_prefix = f"/{path_prefix}"
        where_clause["path"] = {"startswith": path_prefix}

    return await UserWorkspaceFile.prisma().find_many(
        where=where_clause,
        order={"createdAt": "desc"},
        take=limit,
        skip=offset,
    )


async def count_workspace_files(
    workspace_id: str,
    path_prefix: Optional[str] = None,
    include_deleted: bool = False,
) -> int:
    """
    Count files in a workspace.

    Args:
        workspace_id: The workspace ID
        path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
        include_deleted: Whether to include soft-deleted files

    Returns:
        Number of files
    """
    where_clause: dict = {"workspaceId": workspace_id}
    if not include_deleted:
        where_clause["isDeleted"] = False

    if path_prefix:
        # Normalize prefix
        if not path_prefix.startswith("/"):
            path_prefix = f"/{path_prefix}"
        where_clause["path"] = {"startswith": path_prefix}

    return await UserWorkspaceFile.prisma().count(where=where_clause)


async def soft_delete_workspace_file(
    file_id: str,
    workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
    """
    Soft-delete a workspace file.

    The path is modified to include a deletion timestamp to free up the original
    path for new files while preserving the record for potential recovery.

    Args:
        file_id: The file ID
        workspace_id: Optional workspace ID for validation

    Returns:
        Updated UserWorkspaceFile instance or None if not found
    """
    # First verify the file exists and belongs to workspace
    file = await get_workspace_file(file_id, workspace_id)
    if file is None:
        return None

    deleted_at = datetime.now(timezone.utc)
    # Modify path to free up the unique constraint for new files at original path
    # Format: {original_path}__deleted__{timestamp}
    deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"

    updated = await UserWorkspaceFile.prisma().update(
        where={"id": file_id},
        data={
            "isDeleted": True,
            "deletedAt": deleted_at,
            "path": deleted_path,
        },
    )

    logger.info(f"Soft-deleted workspace file {file_id}")
    return updated


async def get_workspace_total_size(workspace_id: str) -> int:
    """
    Get the total size of all files in a workspace.

    Args:
        workspace_id: The workspace ID

    Returns:
        Total size in bytes
    """
    files = await list_workspace_files(workspace_id)
    return sum(file.sizeBytes for file in files)
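The soft-delete convention in the removed module above renames the row's path to `{original_path}__deleted__{timestamp}`, freeing the unique (workspaceId, path) slot while keeping the record recoverable. A small self-contained sketch of that convention (the helper name is hypothetical):

```python
# Sketch of the soft-delete path convention from the removed module.
from datetime import datetime, timezone


def deleted_path_for(path: str, deleted_at: datetime | None = None) -> str:
    deleted_at = deleted_at or datetime.now(timezone.utc)
    return f"{path}__deleted__{int(deleted_at.timestamp())}"


assert (
    deleted_path_for("/docs/report.pdf", datetime(2024, 1, 1, tzinfo=timezone.utc))
    == "/docs/report.pdf__deleted__1704067200"
)
```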
@@ -236,14 +236,7 @@ async def execute_node(
    input_size = len(input_data_str)
    log_metadata.debug("Executed node with input", input=input_data_str)

    # Create node-specific execution context to avoid race conditions
    # (multiple nodes can execute concurrently and would otherwise mutate shared state)
    execution_context = execution_context.model_copy(
        update={"node_id": node_id, "node_exec_id": node_exec_id}
    )

    # Inject extra execution arguments for the blocks via kwargs
    # Keep individual kwargs for backwards compatibility with existing blocks
    extra_exec_kwargs: dict = {
        "graph_id": graph_id,
        "graph_version": graph_version,

@@ -892,19 +892,11 @@ async def add_graph_execution(
    settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)

    execution_context = ExecutionContext(
        # Execution identity
        user_id=user_id,
        graph_id=graph_id,
        graph_exec_id=graph_exec.id,
        graph_version=graph_exec.graph_version,
        # Safety settings
        human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
        sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
        # User settings
        user_timezone=(
            user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
        ),
        # Execution hierarchy
        root_execution_id=graph_exec.id,
    )
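The executor hunk above relies on pydantic's `model_copy(update=...)` returning a fresh instance, so each concurrently-running node sees its own context. A minimal sketch with a stand-in model (the `Ctx` class is illustrative only, not the real `ExecutionContext`):

```python
# Sketch: per-node context copies avoid races on a shared pydantic model.
from typing import Optional

from pydantic import BaseModel


class Ctx(BaseModel):
    graph_exec_id: Optional[str] = None
    node_id: Optional[str] = None
    node_exec_id: Optional[str] = None


shared = Ctx(graph_exec_id="exec-1")
node_ctx = shared.model_copy(update={"node_id": "n1", "node_exec_id": "ne1"})

assert shared.node_id is None  # the shared context is untouched
assert node_ctx.node_id == "n1"  # the copy carries node identity
```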
@@ -348,7 +348,6 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
    mock_graph_exec.id = "execution-id-123"
    mock_graph_exec.node_executions = []  # Add this to avoid AttributeError
    mock_graph_exec.status = ExecutionStatus.QUEUED  # Required for race condition check
    mock_graph_exec.graph_version = graph_version
    mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()

    # Mock the queue and event bus

@@ -435,9 +434,6 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
    # Create a second mock execution for the sanity check
    mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
    mock_graph_exec_2.id = "execution-id-456"
    mock_graph_exec_2.node_executions = []
    mock_graph_exec_2.status = ExecutionStatus.QUEUED
    mock_graph_exec_2.graph_version = graph_version
    mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()

    # Reset mocks and set up for second call

@@ -618,7 +614,6 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
    mock_graph_exec.id = "execution-id-123"
    mock_graph_exec.node_executions = []
    mock_graph_exec.status = ExecutionStatus.QUEUED  # Required for race condition check
    mock_graph_exec.graph_version = graph_version

    # Track what's passed to to_graph_execution_entry
    captured_kwargs = {}
@@ -13,7 +13,6 @@ import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage

from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url
from backend.util.settings import Config

logger = logging.getLogger(__name__)

@@ -252,7 +251,7 @@ class CloudStorageHandler:
            f"in_task: {current_task is not None}"
        )

        # Parse bucket and blob name from path (path already has gcs:// prefix removed)
        # Parse bucket and blob name from path
        parts = path.split("/", 1)
        if len(parts) != 2:
            raise ValueError(f"Invalid GCS path: {path}")

@@ -262,19 +261,50 @@ class CloudStorageHandler:
        # Authorization check
        self._validate_file_access(blob_name, user_id, graph_exec_id)

        logger.info(
            f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
        # Use a fresh client for each download to avoid session issues
        # This is less efficient but more reliable with the executor's event loop
        logger.info("[CloudStorage] Creating fresh GCS client for download")

        # Create a new session specifically for this download
        session = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(limit=10, force_close=True)
        )

        async_client = None
        try:
            content = await download_with_fresh_session(bucket_name, blob_name)
            # Create a new GCS client with the fresh session
            async_client = async_gcs_storage.Storage(session=session)

            logger.info(
                f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
            )

            # Download content using the fresh client
            content = await async_client.download(bucket_name, blob_name)
            logger.info(
                f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
            )

            # Clean up
            await async_client.close()
            await session.close()

            return content
        except FileNotFoundError:
            raise

        except Exception as e:
            # Always try to clean up
            if async_client is not None:
                try:
                    await async_client.close()
                except Exception as cleanup_error:
                    logger.warning(
                        f"[CloudStorage] Error closing GCS client: {cleanup_error}"
                    )
            try:
                await session.close()
            except Exception as cleanup_error:
                logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}")

            # Log the specific error for debugging
            logger.error(
                f"[CloudStorage] GCS download failed - error: {str(e)}, "

@@ -289,6 +319,10 @@ class CloudStorageHandler:
                f"current_task: {current_task}, "
                f"bucket: {bucket_name}, blob: redacted for privacy"
            )

            # Convert gcloud-aio exceptions to standard ones
            if "404" in str(e) or "Not Found" in str(e):
                raise FileNotFoundError(f"File not found: gcs://{path}")
            raise

    def _validate_file_access(

@@ -411,7 +445,8 @@ class CloudStorageHandler:
        graph_exec_id: str | None = None,
    ) -> str:
        """Generate signed URL for GCS with authorization."""
        # Parse bucket and blob name from path (path already has gcs:// prefix removed)

        # Parse bucket and blob name from path
        parts = path.split("/", 1)
        if len(parts) != 2:
            raise ValueError(f"Invalid GCS path: {path}")

@@ -421,11 +456,21 @@ class CloudStorageHandler:
        # Authorization check
        self._validate_file_access(blob_name, user_id, graph_exec_id)

        # Use sync client for signed URLs since gcloud-aio doesn't support them
        sync_client = self._get_sync_gcs_client()
        return await generate_signed_url(
            sync_client, bucket_name, blob_name, expiration_hours * 3600
        bucket = sync_client.bucket(bucket_name)
        blob = bucket.blob(blob_name)

        # Generate signed URL asynchronously using sync client
        url = await asyncio.to_thread(
            blob.generate_signed_url,
            version="v4",
            expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours),
            method="GET",
        )

        return url

    async def delete_expired_files(self, provider: str = "gcs") -> int:
        """
        Delete files that have passed their expiration time.
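The signed-URL change above moves the blocking `generate_signed_url` call off the event loop with `asyncio.to_thread`. Condensed into a standalone helper for reference (bucket and blob names are placeholders):

```python
# Sketch: offload google-cloud-storage's synchronous signed-URL generation
# to a worker thread so the event loop stays responsive.
import asyncio
from datetime import datetime, timedelta, timezone

from google.cloud import storage


async def signed_url(bucket_name: str, blob_name: str, hours: int = 1) -> str:
    blob = storage.Client().bucket(bucket_name).blob(blob_name)
    return await asyncio.to_thread(
        blob.generate_signed_url,
        version="v4",
        expiration=datetime.now(timezone.utc) + timedelta(hours=hours),
        method="GET",
    )
```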
@@ -5,28 +5,13 @@ import shutil
import tempfile
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from urllib.parse import urlparse

from prisma.enums import WorkspaceFileSource

from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.request import Requests
from backend.util.settings import Config
from backend.util.type import MediaFileType
from backend.util.virus_scanner import scan_content_safe

if TYPE_CHECKING:
    from backend.data.execution import ExecutionContext

# Return format options for store_media_file
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
MediaReturnFormat = Literal[
    "for_local_processing", "for_external_api", "for_block_output"
]

TEMP_DIR = Path(tempfile.gettempdir()).resolve()

# Maximum filename length (conservative limit for most filesystems)

@@ -82,56 +67,42 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:

async def store_media_file(
    graph_exec_id: str,
    file: MediaFileType,
    execution_context: "ExecutionContext",
    *,
    return_format: MediaReturnFormat,
    user_id: str,
    return_content: bool = False,
) -> MediaFileType:
    """
    Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path
    relative to {temp}/exec_file/{exec_id}), placing or verifying it under:
    Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
    placing or verifying it under:
        {tempdir}/exec_file/{exec_id}/...

    For each MediaFileType input:
      - Data URI: decode and store locally
      - URL: download and store locally
      - workspace:// reference: read from workspace, store locally
      - Local path: verify it exists in exec_file directory
    If 'return_content=True', return a data URI (data:<mime>;base64,<content>).
    Otherwise, returns the file media path relative to the exec_id folder.

    Return format options:
      - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
      - "for_external_api": Returns data URI (base64) - use when sending to external APIs
      - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
    For each MediaFileType type:
      - Data URI:
        -> decode and store in a new random file in that folder
      - URL:
        -> download and store in that folder
      - Local path:
        -> interpret as relative to that folder; verify it exists
           (no copying, as it's presumably already there).
    We realpath-check so no symlink or '..' can escape the folder.

    :param file: Data URI, URL, workspace://, or local (relative) path.
    :param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id.
    :param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output".
    :return: The requested result based on return_format.

    :param graph_exec_id: The unique ID of the graph execution.
    :param file: Data URI, URL, or local (relative) path.
    :param return_content: If True, return a data URI of the file content.
                           If False, return the *relative* path inside the exec_id folder.
    :return: The requested result: data URI or relative path of the media.
    """
    # Extract values from execution_context
    graph_exec_id = execution_context.graph_exec_id
    user_id = execution_context.user_id

    if not graph_exec_id:
        raise ValueError("execution_context.graph_exec_id is required")
    if not user_id:
        raise ValueError("execution_context.user_id is required")

    # Create workspace_manager if we have workspace_id (with session scoping)
    # Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py)
    from backend.util.workspace import WorkspaceManager

    workspace_manager: WorkspaceManager | None = None
    if execution_context.workspace_id:
        workspace_manager = WorkspaceManager(
            user_id, execution_context.workspace_id, execution_context.session_id
        )
    # Build base path
    base_path = Path(get_exec_file_path(graph_exec_id, ""))
    base_path.mkdir(parents=True, exist_ok=True)

    # Security fix: Add disk space limits to prevent DoS
    MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024
    MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB per file
    MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024  # 1GB total per execution directory

    # Check total disk usage in base_path

@@ -171,57 +142,9 @@ async def store_media_file(
    """
    return str(absolute_path.relative_to(base))

    # Get cloud storage handler for checking cloud paths
    cloud_storage = await get_cloud_storage_handler()

    # Track if the input came from workspace (don't re-save it)
    is_from_workspace = file.startswith("workspace://")

    # Check if this is a workspace file reference
    if is_from_workspace:
        if workspace_manager is None:
            raise ValueError(
                "Workspace file reference requires workspace context. "
                "This file type is only available in CoPilot sessions."
            )

        # Parse workspace reference
        # workspace://abc123 - by file ID
        # workspace:///path/to/file.txt - by virtual path
        file_ref = file[12:]  # Remove "workspace://"

        if file_ref.startswith("/"):
            # Path reference
            workspace_content = await workspace_manager.read_file(file_ref)
            file_info = await workspace_manager.get_file_info_by_path(file_ref)
            filename = sanitize_filename(
                file_info.name if file_info else f"{uuid.uuid4()}.bin"
            )
        else:
            # ID reference
            workspace_content = await workspace_manager.read_file_by_id(file_ref)
            file_info = await workspace_manager.get_file_info(file_ref)
            filename = sanitize_filename(
                file_info.name if file_info else f"{uuid.uuid4()}.bin"
            )

        try:
            target_path = _ensure_inside_base(base_path / filename, base_path)
        except OSError as e:
            raise ValueError(f"Invalid file path '{filename}': {e}") from e

        # Check file size limit
        if len(workspace_content) > MAX_FILE_SIZE_BYTES:
            raise ValueError(
                f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
            )

        # Virus scan the workspace content before writing locally
        await scan_content_safe(workspace_content, filename=filename)
        target_path.write_bytes(workspace_content)

    # Check if this is a cloud storage path
    elif cloud_storage.is_cloud_path(file):
    cloud_storage = await get_cloud_storage_handler()
    if cloud_storage.is_cloud_path(file):
        # Download from cloud storage and store locally
        cloud_content = await cloud_storage.retrieve_file(
            file, user_id=user_id, graph_exec_id=graph_exec_id

@@ -236,9 +159,9 @@ async def store_media_file(
            raise ValueError(f"Invalid file path '{filename}': {e}") from e

        # Check file size limit
        if len(cloud_content) > MAX_FILE_SIZE_BYTES:
        if len(cloud_content) > MAX_FILE_SIZE:
            raise ValueError(
                f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
                f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
            )

        # Virus scan the cloud content before writing locally

@@ -266,9 +189,9 @@ async def store_media_file(
        content = base64.b64decode(b64_content)

        # Check file size limit
        if len(content) > MAX_FILE_SIZE_BYTES:
        if len(content) > MAX_FILE_SIZE:
            raise ValueError(
                f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
                f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
            )

        # Virus scan the base64 content before writing

@@ -276,31 +199,23 @@ async def store_media_file(
        target_path.write_bytes(content)

    elif file.startswith(("http://", "https://")):
        # URL - download first to get Content-Type header
        resp = await Requests().get(file)

        # Check file size limit
        if len(resp.content) > MAX_FILE_SIZE_BYTES:
            raise ValueError(
                f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
            )

        # Extract filename from URL path
        # URL
        parsed_url = urlparse(file)
        filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")

        # If filename lacks extension, add one from Content-Type header
        if "." not in filename:
            content_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
            if content_type:
                ext = _extension_from_mime(content_type)
                filename = f"{filename}{ext}"

        try:
            target_path = _ensure_inside_base(base_path / filename, base_path)
        except OSError as e:
            raise ValueError(f"Invalid file path '{filename}': {e}") from e

        # Download and save
        resp = await Requests().get(file)

        # Check file size limit
        if len(resp.content) > MAX_FILE_SIZE:
            raise ValueError(
                f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
            )

        # Virus scan the downloaded content before writing
        await scan_content_safe(resp.content, filename=filename)
        target_path.write_bytes(resp.content)

@@ -315,45 +230,11 @@ async def store_media_file(
    if not target_path.is_file():
        raise ValueError(f"Local file does not exist: {target_path}")

    # Return based on requested format
    if return_format == "for_local_processing":
        # Use when processing files locally with tools like ffmpeg, MoviePy, PIL
        # Returns: relative path in exec_file directory (e.g., "image.png")
        return MediaFileType(_strip_base_prefix(target_path, base_path))

    elif return_format == "for_external_api":
        # Use when sending content to external APIs that need base64
        # Returns: data URI (e.g., "data:image/png;base64,iVBORw0...")
    # Return result
    if return_content:
        return MediaFileType(_file_to_data_uri(target_path))

    elif return_format == "for_block_output":
        # Use when returning output from a block to user/next block
        # Returns: workspace:// ref (CoPilot) or data URI (graph execution)
        if workspace_manager is None:
            # No workspace available (graph execution without CoPilot)
            # Fallback to data URI so the content can still be used/displayed
            return MediaFileType(_file_to_data_uri(target_path))

        # Don't re-save if input was already from workspace
        if is_from_workspace:
            # Return original workspace reference
            return MediaFileType(file)

        # Save new content to workspace
        content = target_path.read_bytes()
        filename = target_path.name

        file_record = await workspace_manager.write_file(
            content=content,
            filename=filename,
            source=WorkspaceFileSource.COPILOT,
            source_session_id=execution_context.session_id,
            overwrite=True,
        )
        return MediaFileType(f"workspace://{file_record.id}")

    else:
        raise ValueError(f"Invalid return_format: {return_format}")
    return MediaFileType(_strip_base_prefix(target_path, base_path))


def get_dir_size(path: Path) -> int:
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
def make_test_context(
|
||||
graph_exec_id: str = "test-exec-123",
|
||||
user_id: str = "test-user-123",
|
||||
) -> ExecutionContext:
|
||||
"""Helper to create test ExecutionContext."""
|
||||
return ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
|
||||
class TestFileCloudIntegration:
|
||||
"""Test cases for cloud storage integration in file utilities."""
|
||||
|
||||
@@ -82,9 +70,10 @@ class TestFileCloudIntegration:
|
||||
mock_path_class.side_effect = path_constructor
|
||||
|
||||
result = await store_media_file(
|
||||
file=MediaFileType(cloud_path),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
graph_exec_id,
|
||||
MediaFileType(cloud_path),
|
||||
"test-user-123",
|
||||
return_content=False,
|
||||
)
|
||||
|
||||
# Verify cloud storage operations
|
||||
@@ -155,9 +144,10 @@ class TestFileCloudIntegration:
|
||||
mock_path_obj.name = "image.png"
|
||||
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
||||
result = await store_media_file(
|
||||
file=MediaFileType(cloud_path),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_external_api",
|
||||
graph_exec_id,
|
||||
MediaFileType(cloud_path),
|
||||
"test-user-123",
|
||||
return_content=True,
|
||||
)
|
||||
|
||||
# Verify result is a data URI
|
||||
@@ -208,9 +198,10 @@ class TestFileCloudIntegration:
|
||||
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
||||
|
||||
await store_media_file(
|
||||
file=MediaFileType(data_uri),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
graph_exec_id,
|
||||
MediaFileType(data_uri),
|
||||
"test-user-123",
|
||||
return_content=False,
|
||||
)
|
||||
|
||||
# Verify cloud handler was checked but not used for retrieval
|
||||
@@ -243,7 +234,5 @@ class TestFileCloudIntegration:
|
||||
FileNotFoundError, match="File not found in cloud storage"
|
||||
):
|
||||
await store_media_file(
|
||||
file=MediaFileType(cloud_path),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
|
||||
)
|
||||
|
||||
@@ -1,160 +0,0 @@
"""
Shared GCS utilities for workspace and cloud storage backends.

This module provides common functionality for working with Google Cloud Storage,
including path parsing, client management, and signed URL generation.
"""

import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage

logger = logging.getLogger(__name__)


def parse_gcs_path(path: str) -> tuple[str, str]:
    """
    Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob).

    Args:
        path: GCS path string (e.g., "gcs://my-bucket/path/to/file")

    Returns:
        Tuple of (bucket_name, blob_name)

    Raises:
        ValueError: If the path format is invalid
    """
    if not path.startswith("gcs://"):
        raise ValueError(f"Invalid GCS path: {path}")

    path_without_prefix = path[6:]  # Remove "gcs://"
    parts = path_without_prefix.split("/", 1)
    if len(parts) != 2:
        raise ValueError(f"Invalid GCS path format: {path}")

    return parts[0], parts[1]

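Since this helper carries the parsing contract for every `gcs://` storage path in the backends below, a minimal usage sketch (the bucket and blob names are made up):

```python
# Round-trip sketch for parse_gcs_path as defined above.
bucket, blob = parse_gcs_path("gcs://my-bucket/path/to/file.png")
assert (bucket, blob) == ("my-bucket", "path/to/file.png")

# Anything without the gcs:// scheme or without a blob part is rejected.
for bad in ("s3://my-bucket/file.png", "gcs://bucket-only"):
    try:
        parse_gcs_path(bad)
    except ValueError as err:
        print(f"rejected as expected: {err}")
```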
class GCSClientManager:
    """
    Manages async and sync GCS clients with lazy initialization.

    This class provides a unified way to manage GCS client lifecycle,
    supporting both async operations (uploads, downloads) and sync
    operations that require service account credentials (signed URLs).
    """

    def __init__(self):
        self._async_client: Optional[async_gcs_storage.Storage] = None
        self._sync_client: Optional[gcs_storage.Client] = None
        self._session: Optional[aiohttp.ClientSession] = None

    async def get_async_client(self) -> async_gcs_storage.Storage:
        """
        Get or create async GCS client.

        Returns:
            Async GCS storage client
        """
        if self._async_client is None:
            self._session = aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(limit=100, force_close=False)
            )
            self._async_client = async_gcs_storage.Storage(session=self._session)
        return self._async_client

    def get_sync_client(self) -> gcs_storage.Client:
        """
        Get or create sync GCS client (used for signed URLs).

        Returns:
            Sync GCS storage client
        """
        if self._sync_client is None:
            self._sync_client = gcs_storage.Client()
        return self._sync_client

    async def close(self) -> None:
        """Close all client connections."""
        if self._async_client is not None:
            try:
                await self._async_client.close()
            except Exception as e:
                logger.warning(f"Error closing GCS client: {e}")
            self._async_client = None

        if self._session is not None:
            try:
                await self._session.close()
            except Exception as e:
                logger.warning(f"Error closing session: {e}")
            self._session = None


async def download_with_fresh_session(bucket: str, blob: str) -> bytes:
    """
    Download file content using a fresh session.

    This approach avoids event loop issues that can occur when reusing
    sessions across different async contexts (e.g., in executors).

    Args:
        bucket: GCS bucket name
        blob: Blob path within the bucket

    Returns:
        File content as bytes

    Raises:
        FileNotFoundError: If the file doesn't exist
    """
    session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(limit=10, force_close=True)
    )
    try:
        client = async_gcs_storage.Storage(session=session)
        content = await client.download(bucket, blob)
        await client.close()
        return content
    except Exception as e:
        if "404" in str(e) or "Not Found" in str(e):
            raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}")
        raise
    finally:
        await session.close()

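A hedged sketch of calling `download_with_fresh_session` (the bucket and blob are placeholders; credentials are assumed to come from the environment, e.g. `GOOGLE_APPLICATION_CREDENTIALS`):

```python
import asyncio

async def fetch_example() -> None:
    # Placeholder names; a fresh aiohttp session is created and closed
    # inside download_with_fresh_session, so this is safe in executors.
    data = await download_with_fresh_session("my-bucket", "path/to/file.bin")
    print(f"downloaded {len(data)} bytes")

asyncio.run(fetch_example())
```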
async def generate_signed_url(
    sync_client: gcs_storage.Client,
    bucket_name: str,
    blob_name: str,
    expires_in: int,
) -> str:
    """
    Generate a signed URL for temporary access to a GCS file.

    Uses asyncio.to_thread() to run the sync operation without blocking.

    Args:
        sync_client: Sync GCS client with service account credentials
        bucket_name: GCS bucket name
        blob_name: Blob path within the bucket
        expires_in: URL expiration time in seconds

    Returns:
        Signed URL string
    """
    bucket = sync_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    return await asyncio.to_thread(
        blob.generate_signed_url,
        version="v4",
        expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in),
        method="GET",
    )
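Usage sketch for `generate_signed_url`; this only works with service-account credentials that carry a private key, and the names below are illustrative:

```python
import asyncio
from google.cloud import storage as gcs_storage

async def sign_example() -> None:
    sync_client = gcs_storage.Client()  # needs service-account credentials
    url = await generate_signed_url(
        sync_client,
        "my-bucket",
        "workspaces/ws-1/file-1/report.pdf",
        expires_in=3600,  # one hour
    )
    print(url)

asyncio.run(sign_example())
```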
@@ -263,12 +263,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        description="The name of the Google Cloud Storage bucket for media files",
    )

    workspace_storage_dir: str = Field(
        default="",
        description="Local directory for workspace file storage when GCS is not configured. "
        "If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.",
    )

    reddit_user_agent: str = Field(
        default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
        description="The user agent for the Reddit API",
@@ -365,8 +359,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        description="The port for the Agent Generator service",
    )
    agentgenerator_timeout: int = Field(
        default=600,
        description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
        default=120,
        description="The timeout in seconds for Agent Generator service requests",
    )

    enable_example_blocks: bool = Field(
@@ -395,13 +389,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        description="Maximum file size in MB for file uploads (1-1024 MB)",
    )

    max_file_size_mb: int = Field(
        default=100,
        ge=1,
        le=1024,
        description="Maximum file size in MB for workspace files (1-1024 MB)",
    )

    # AutoMod configuration
    automod_enabled: bool = Field(
        default=False,

@@ -140,29 +140,14 @@ async def execute_block_test(block: Block):
        setattr(block, mock_name, mock_obj)

    # Populate credentials argument(s)
    # Generate IDs for execution context
    graph_id = str(uuid.uuid4())
    node_id = str(uuid.uuid4())
    graph_exec_id = str(uuid.uuid4())
    node_exec_id = str(uuid.uuid4())
    user_id = str(uuid.uuid4())
    graph_version = 1  # Default version for tests

    extra_exec_kwargs: dict = {
        "graph_id": graph_id,
        "node_id": node_id,
        "graph_exec_id": graph_exec_id,
        "node_exec_id": node_exec_id,
        "user_id": user_id,
        "graph_version": graph_version,
        "execution_context": ExecutionContext(
            user_id=user_id,
            graph_id=graph_id,
            graph_exec_id=graph_exec_id,
            graph_version=graph_version,
            node_id=node_id,
            node_exec_id=node_exec_id,
        ),
        "graph_id": str(uuid.uuid4()),
        "node_id": str(uuid.uuid4()),
        "graph_exec_id": str(uuid.uuid4()),
        "node_exec_id": str(uuid.uuid4()),
        "user_id": str(uuid.uuid4()),
        "graph_version": 1,  # Default version for tests
        "execution_context": ExecutionContext(),
    }
    input_model = cast(type[BlockSchema], block.input_schema)


@@ -1,432 +0,0 @@
"""
WorkspaceManager for managing user workspace file operations.

This module provides a high-level interface for workspace file operations,
combining the storage backend and database layer.
"""

import logging
import mimetypes
import uuid
from typing import Optional

from prisma.enums import WorkspaceFileSource
from prisma.errors import UniqueViolationError
from prisma.models import UserWorkspaceFile

from backend.data.workspace import (
    count_workspace_files,
    create_workspace_file,
    get_workspace_file,
    get_workspace_file_by_path,
    list_workspace_files,
    soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage

logger = logging.getLogger(__name__)


class WorkspaceManager:
    """
    Manages workspace file operations.

    Combines storage backend operations with database record management.
    Supports session-scoped file segmentation where files are stored in
    session-specific virtual paths: /sessions/{session_id}/{filename}
    """

    def __init__(
        self, user_id: str, workspace_id: str, session_id: Optional[str] = None
    ):
        """
        Initialize WorkspaceManager.

        Args:
            user_id: The user's ID
            workspace_id: The workspace ID
            session_id: Optional session ID for session-scoped file access
        """
        self.user_id = user_id
        self.workspace_id = workspace_id
        self.session_id = session_id
        # Session path prefix for file isolation
        self.session_path = f"/sessions/{session_id}" if session_id else ""

    def _resolve_path(self, path: str) -> str:
        """
        Resolve a path, defaulting to session folder if session_id is set.

        Cross-session access is allowed by explicitly using /sessions/other-session-id/...

        Args:
            path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt")

        Returns:
            Resolved path with session prefix if applicable
        """
        # If path explicitly references a session folder, use it as-is
        if path.startswith("/sessions/"):
            return path

        # If we have a session context, prepend session path
        if self.session_path:
            # Normalize the path
            if not path.startswith("/"):
                path = f"/{path}"
            return f"{self.session_path}{path}"

        # No session context, use path as-is
        return path if path.startswith("/") else f"/{path}"

    def _get_effective_path(
        self, path: Optional[str], include_all_sessions: bool
    ) -> Optional[str]:
        """
        Get effective path for list/count operations based on session context.

        Args:
            path: Optional path prefix to filter
            include_all_sessions: If True, don't apply session scoping

        Returns:
            Effective path prefix for database query
        """
        if include_all_sessions:
            # Normalize path to ensure leading slash (stored paths are normalized)
            if path is not None and not path.startswith("/"):
                return f"/{path}"
            return path
        elif path is not None:
            # Resolve the provided path with session scoping
            return self._resolve_path(path)
        elif self.session_path:
            # Default to session folder with trailing slash to prevent prefix collisions
            # e.g., "/sessions/abc" should not match "/sessions/abc123"
            return self.session_path.rstrip("/") + "/"
        else:
            # No session context, use path as-is
            return path

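To make the session-scoping rules concrete, a small worked example (the IDs are invented; `_resolve_path` and `_get_effective_path` are the private helpers defined above, used here purely for illustration):

```python
wm = WorkspaceManager(user_id="user-1", workspace_id="ws-1", session_id="abc123")

# Bare paths default into the current session's folder.
assert wm._resolve_path("/report.pdf") == "/sessions/abc123/report.pdf"
assert wm._resolve_path("notes.txt") == "/sessions/abc123/notes.txt"

# Explicit /sessions/... paths pass through, which is what enables
# deliberate cross-session access.
assert wm._resolve_path("/sessions/other/x.txt") == "/sessions/other/x.txt"

# List/count scoping keeps the trailing slash, so "/sessions/abc"
# can never prefix-match "/sessions/abc123".
assert wm._get_effective_path(None, include_all_sessions=False) == "/sessions/abc123/"
assert wm._get_effective_path(None, include_all_sessions=True) is None
```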
    async def read_file(self, path: str) -> bytes:
        """
        Read file from workspace by virtual path.

        When session_id is set, paths are resolved relative to the session folder
        unless they explicitly reference /sessions/...

        Args:
            path: Virtual path (e.g., "/documents/report.pdf")

        Returns:
            File content as bytes

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        resolved_path = self._resolve_path(path)
        file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
        if file is None:
            raise FileNotFoundError(f"File not found at path: {resolved_path}")

        storage = await get_workspace_storage()
        return await storage.retrieve(file.storagePath)

    async def read_file_by_id(self, file_id: str) -> bytes:
        """
        Read file from workspace by file ID.

        Args:
            file_id: The file's ID

        Returns:
            File content as bytes

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        file = await get_workspace_file(file_id, self.workspace_id)
        if file is None:
            raise FileNotFoundError(f"File not found: {file_id}")

        storage = await get_workspace_storage()
        return await storage.retrieve(file.storagePath)

    async def write_file(
        self,
        content: bytes,
        filename: str,
        path: Optional[str] = None,
        mime_type: Optional[str] = None,
        source: WorkspaceFileSource = WorkspaceFileSource.UPLOAD,
        source_exec_id: Optional[str] = None,
        source_session_id: Optional[str] = None,
        overwrite: bool = False,
    ) -> UserWorkspaceFile:
        """
        Write file to workspace.

        When session_id is set, files are written to /sessions/{session_id}/...
        by default. Use explicit /sessions/... paths for cross-session access.

        Args:
            content: File content as bytes
            filename: Filename for the file
            path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
            mime_type: MIME type (auto-detected if not provided)
            source: How the file was created
            source_exec_id: Graph execution ID if from execution
            source_session_id: Chat session ID if from CoPilot
            overwrite: Whether to overwrite existing file at path

        Returns:
            Created UserWorkspaceFile instance

        Raises:
            ValueError: If file exceeds size limit or path already exists
        """
        # Enforce file size limit
        max_file_size = Config().max_file_size_mb * 1024 * 1024
        if len(content) > max_file_size:
            raise ValueError(
                f"File too large: {len(content)} bytes exceeds "
                f"{Config().max_file_size_mb}MB limit"
            )

        # Determine path with session scoping
        if path is None:
            path = f"/{filename}"
        elif not path.startswith("/"):
            path = f"/{path}"

        # Resolve path with session prefix
        path = self._resolve_path(path)

        # Check if file exists at path
        existing = await get_workspace_file_by_path(self.workspace_id, path)
        if existing is not None:
            if overwrite:
                # Delete existing file first
                await self.delete_file(existing.id)
            else:
                raise ValueError(f"File already exists at path: {path}")

        # Auto-detect MIME type if not provided
        if mime_type is None:
            mime_type, _ = mimetypes.guess_type(filename)
            mime_type = mime_type or "application/octet-stream"

        # Compute checksum
        checksum = compute_file_checksum(content)

        # Generate unique file ID for storage
        file_id = str(uuid.uuid4())

        # Store file in storage backend
        storage = await get_workspace_storage()
        storage_path = await storage.store(
            workspace_id=self.workspace_id,
            file_id=file_id,
            filename=filename,
            content=content,
        )

        # Create database record - handle race condition where another request
        # created a file at the same path between our check and create
        try:
            file = await create_workspace_file(
                workspace_id=self.workspace_id,
                file_id=file_id,
                name=filename,
                path=path,
                storage_path=storage_path,
                mime_type=mime_type,
                size_bytes=len(content),
                checksum=checksum,
                source=source,
                source_exec_id=source_exec_id,
                source_session_id=source_session_id,
            )
        except UniqueViolationError:
            # Race condition: another request created a file at this path
            if overwrite:
                # Re-fetch and delete the conflicting file, then retry
                existing = await get_workspace_file_by_path(self.workspace_id, path)
                if existing:
                    await self.delete_file(existing.id)
                # Retry the create - if this also fails, clean up storage file
                try:
                    file = await create_workspace_file(
                        workspace_id=self.workspace_id,
                        file_id=file_id,
                        name=filename,
                        path=path,
                        storage_path=storage_path,
                        mime_type=mime_type,
                        size_bytes=len(content),
                        checksum=checksum,
                        source=source,
                        source_exec_id=source_exec_id,
                        source_session_id=source_session_id,
                    )
                except Exception:
                    # Clean up orphaned storage file on retry failure
                    try:
                        await storage.delete(storage_path)
                    except Exception as e:
                        logger.warning(f"Failed to clean up orphaned storage file: {e}")
                    raise
            else:
                # Clean up the orphaned storage file before raising
                try:
                    await storage.delete(storage_path)
                except Exception as e:
                    logger.warning(f"Failed to clean up orphaned storage file: {e}")
                raise ValueError(f"File already exists at path: {path}")
        except Exception:
            # Any other database error (connection, validation, etc.) - clean up storage
            try:
                await storage.delete(storage_path)
            except Exception as e:
                logger.warning(f"Failed to clean up orphaned storage file: {e}")
            raise

        logger.info(
            f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} "
            f"at path {path}, size={len(content)} bytes"
        )

        return file

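A hedged end-to-end sketch of `write_file` plus `read_file` (this assumes a configured database and storage backend; the values are illustrative):

```python
async def write_and_read_back(manager: WorkspaceManager) -> None:
    # Defaults put the file at "/hello.txt", session-scoped if a
    # session_id was passed to the manager.
    record = await manager.write_file(
        content=b"hello",
        filename="hello.txt",
        source=WorkspaceFileSource.COPILOT,
        source_session_id=manager.session_id,
        overwrite=True,
    )
    print(record.id, record.path)

    # The same virtual path resolves back to the same record.
    assert await manager.read_file("/hello.txt") == b"hello"
```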
    async def list_files(
        self,
        path: Optional[str] = None,
        limit: Optional[int] = None,
        offset: int = 0,
        include_all_sessions: bool = False,
    ) -> list[UserWorkspaceFile]:
        """
        List files in workspace.

        When session_id is set and include_all_sessions is False (default),
        only files in the current session's folder are listed.

        Args:
            path: Optional path prefix to filter (e.g., "/documents/")
            limit: Maximum number of files to return
            offset: Number of files to skip
            include_all_sessions: If True, list files from all sessions.
                If False (default), only list current session's files.

        Returns:
            List of UserWorkspaceFile instances
        """
        effective_path = self._get_effective_path(path, include_all_sessions)

        return await list_workspace_files(
            workspace_id=self.workspace_id,
            path_prefix=effective_path,
            limit=limit,
            offset=offset,
        )

    async def delete_file(self, file_id: str) -> bool:
        """
        Delete a file (soft-delete).

        Args:
            file_id: The file's ID

        Returns:
            True if deleted, False if not found
        """
        file = await get_workspace_file(file_id, self.workspace_id)
        if file is None:
            return False

        # Delete from storage
        storage = await get_workspace_storage()
        try:
            await storage.delete(file.storagePath)
        except Exception as e:
            logger.warning(f"Failed to delete file from storage: {e}")
            # Continue with database soft-delete even if storage delete fails

        # Soft-delete database record
        result = await soft_delete_workspace_file(file_id, self.workspace_id)
        return result is not None

    async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
        """
        Get download URL for a file.

        Args:
            file_id: The file's ID
            expires_in: URL expiration in seconds (default 1 hour)

        Returns:
            Download URL (signed URL for GCS, API endpoint for local)

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        file = await get_workspace_file(file_id, self.workspace_id)
        if file is None:
            raise FileNotFoundError(f"File not found: {file_id}")

        storage = await get_workspace_storage()
        return await storage.get_download_url(file.storagePath, expires_in)

    async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
        """
        Get file metadata.

        Args:
            file_id: The file's ID

        Returns:
            UserWorkspaceFile instance or None
        """
        return await get_workspace_file(file_id, self.workspace_id)

    async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
        """
        Get file metadata by path.

        When session_id is set, paths are resolved relative to the session folder
        unless they explicitly reference /sessions/...

        Args:
            path: Virtual path

        Returns:
            UserWorkspaceFile instance or None
        """
        resolved_path = self._resolve_path(path)
        return await get_workspace_file_by_path(self.workspace_id, resolved_path)

    async def get_file_count(
        self,
        path: Optional[str] = None,
        include_all_sessions: bool = False,
    ) -> int:
        """
        Get number of files in workspace.

        When session_id is set and include_all_sessions is False (default),
        only counts files in the current session's folder.

        Args:
            path: Optional path prefix to filter (e.g., "/documents/")
            include_all_sessions: If True, count all files in workspace.
                If False (default), only count current session's files.

        Returns:
            Number of files
        """
        effective_path = self._get_effective_path(path, include_all_sessions)

        return await count_workspace_files(
            self.workspace_id, path_prefix=effective_path
        )
@@ -1,398 +0,0 @@
"""
Workspace storage backend abstraction for supporting both cloud and local deployments.

This module provides a unified interface for storing workspace files, with implementations
for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments).
"""

import asyncio
import hashlib
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import aiofiles
import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage

from backend.util.data import get_data_path
from backend.util.gcs_utils import (
    download_with_fresh_session,
    generate_signed_url,
    parse_gcs_path,
)
from backend.util.settings import Config

logger = logging.getLogger(__name__)


class WorkspaceStorageBackend(ABC):
    """Abstract interface for workspace file storage."""

    @abstractmethod
    async def store(
        self,
        workspace_id: str,
        file_id: str,
        filename: str,
        content: bytes,
    ) -> str:
        """
        Store file content, return storage path.

        Args:
            workspace_id: The workspace ID
            file_id: Unique file ID for storage
            filename: Original filename
            content: File content as bytes

        Returns:
            Storage path string (cloud path or local path)
        """
        pass

    @abstractmethod
    async def retrieve(self, storage_path: str) -> bytes:
        """
        Retrieve file content from storage.

        Args:
            storage_path: The storage path returned from store()

        Returns:
            File content as bytes
        """
        pass

    @abstractmethod
    async def delete(self, storage_path: str) -> None:
        """
        Delete file from storage.

        Args:
            storage_path: The storage path to delete
        """
        pass

    @abstractmethod
    async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
        """
        Get URL for downloading the file.

        Args:
            storage_path: The storage path
            expires_in: URL expiration time in seconds (default 1 hour)

        Returns:
            Download URL (signed URL for GCS, direct API path for local)
        """
        pass


class GCSWorkspaceStorage(WorkspaceStorageBackend):
    """Google Cloud Storage implementation for workspace storage."""

    def __init__(self, bucket_name: str):
        self.bucket_name = bucket_name
        self._async_client: Optional[async_gcs_storage.Storage] = None
        self._sync_client: Optional[gcs_storage.Client] = None
        self._session: Optional[aiohttp.ClientSession] = None

    async def _get_async_client(self) -> async_gcs_storage.Storage:
        """Get or create async GCS client."""
        if self._async_client is None:
            self._session = aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(limit=100, force_close=False)
            )
            self._async_client = async_gcs_storage.Storage(session=self._session)
        return self._async_client

    def _get_sync_client(self) -> gcs_storage.Client:
        """Get or create sync GCS client (for signed URLs)."""
        if self._sync_client is None:
            self._sync_client = gcs_storage.Client()
        return self._sync_client

    async def close(self) -> None:
        """Close all client connections."""
        if self._async_client is not None:
            try:
                await self._async_client.close()
            except Exception as e:
                logger.warning(f"Error closing GCS client: {e}")
            self._async_client = None

        if self._session is not None:
            try:
                await self._session.close()
            except Exception as e:
                logger.warning(f"Error closing session: {e}")
            self._session = None

    def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str:
        """Build the blob path for workspace files."""
        return f"workspaces/{workspace_id}/{file_id}/{filename}"

    async def store(
        self,
        workspace_id: str,
        file_id: str,
        filename: str,
        content: bytes,
    ) -> str:
        """Store file in GCS."""
        client = await self._get_async_client()
        blob_name = self._build_blob_name(workspace_id, file_id, filename)

        # Upload with metadata
        upload_time = datetime.now(timezone.utc)
        await client.upload(
            self.bucket_name,
            blob_name,
            content,
            metadata={
                "uploaded_at": upload_time.isoformat(),
                "workspace_id": workspace_id,
                "file_id": file_id,
            },
        )

        return f"gcs://{self.bucket_name}/{blob_name}"

    async def retrieve(self, storage_path: str) -> bytes:
        """Retrieve file from GCS."""
        bucket_name, blob_name = parse_gcs_path(storage_path)
        return await download_with_fresh_session(bucket_name, blob_name)

    async def delete(self, storage_path: str) -> None:
        """Delete file from GCS."""
        bucket_name, blob_name = parse_gcs_path(storage_path)
        client = await self._get_async_client()

        try:
            await client.delete(bucket_name, blob_name)
        except Exception as e:
            if "404" not in str(e) and "Not Found" not in str(e):
                raise
            # File already deleted, that's fine

    async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
        """
        Generate download URL for GCS file.

        Attempts to generate a signed URL if running with service account credentials.
        Falls back to an API proxy endpoint if signed URL generation fails
        (e.g., when running locally with user OAuth credentials).
        """
        bucket_name, blob_name = parse_gcs_path(storage_path)

        # Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename}
        blob_parts = blob_name.split("/")
        file_id = blob_parts[2] if len(blob_parts) >= 3 else None

        # Try to generate signed URL (requires service account credentials)
        try:
            sync_client = self._get_sync_client()
            return await generate_signed_url(
                sync_client, bucket_name, blob_name, expires_in
            )
        except AttributeError as e:
            # Signed URL generation requires service account with private key.
            # When running with user OAuth credentials, fall back to API proxy.
            if "private key" in str(e) and file_id:
                logger.debug(
                    "Cannot generate signed URL (no service account credentials), "
                    "falling back to API proxy endpoint"
                )
                return f"/api/workspace/files/{file_id}/download"
            raise


class LocalWorkspaceStorage(WorkspaceStorageBackend):
    """Local filesystem implementation for workspace storage (self-hosted deployments)."""

    def __init__(self, base_dir: Optional[str] = None):
        """
        Initialize local storage backend.

        Args:
            base_dir: Base directory for workspace storage.
                If None, defaults to {app_data}/workspaces
        """
        if base_dir:
            self.base_dir = Path(base_dir)
        else:
            self.base_dir = Path(get_data_path()) / "workspaces"

        # Ensure base directory exists
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path:
        """Build the local file path with path traversal protection."""
        # Import here to avoid circular import
        # (file.py imports workspace.py which imports workspace_storage.py)
        from backend.util.file import sanitize_filename

        # Sanitize filename to prevent path traversal (removes / and \ among others)
        safe_filename = sanitize_filename(filename)
        file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve()

        # Verify the resolved path is still under base_dir
        if not file_path.is_relative_to(self.base_dir.resolve()):
            raise ValueError("Invalid filename: path traversal detected")

        return file_path

    def _parse_storage_path(self, storage_path: str) -> Path:
        """Parse local storage path to filesystem path."""
        if storage_path.startswith("local://"):
            relative_path = storage_path[8:]  # Remove "local://"
        else:
            relative_path = storage_path

        full_path = (self.base_dir / relative_path).resolve()

        # Security check: ensure path is under base_dir
        # Use is_relative_to() for robust path containment check
        # (handles case-insensitive filesystems and edge cases)
        if not full_path.is_relative_to(self.base_dir.resolve()):
            raise ValueError("Invalid storage path: path traversal detected")

        return full_path

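The two traversal guards above are easiest to see with a throwaway base directory (a sketch; the hostile path is contrived):

```python
import tempfile

storage = LocalWorkspaceStorage(base_dir=tempfile.mkdtemp())

# A storage path that escapes base_dir is refused before any I/O happens.
try:
    storage._parse_storage_path("local://../../etc/passwd")
except ValueError as err:
    print(err)  # Invalid storage path: path traversal detected
```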
    async def store(
        self,
        workspace_id: str,
        file_id: str,
        filename: str,
        content: bytes,
    ) -> str:
        """Store file locally."""
        file_path = self._build_file_path(workspace_id, file_id, filename)

        # Create parent directories
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Write file asynchronously
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        # Return relative path as storage path
        relative_path = file_path.relative_to(self.base_dir)
        return f"local://{relative_path}"

    async def retrieve(self, storage_path: str) -> bytes:
        """Retrieve file from local storage."""
        file_path = self._parse_storage_path(storage_path)

        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {storage_path}")

        async with aiofiles.open(file_path, "rb") as f:
            return await f.read()

    async def delete(self, storage_path: str) -> None:
        """Delete file from local storage."""
        file_path = self._parse_storage_path(storage_path)

        if file_path.exists():
            # Remove file
            file_path.unlink()

            # Clean up empty parent directories
            parent = file_path.parent
            while parent != self.base_dir:
                try:
                    if parent.exists() and not any(parent.iterdir()):
                        parent.rmdir()
                    else:
                        break
                except OSError:
                    break
                parent = parent.parent

    async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
        """
        Get download URL for local file.

        For local storage, this returns an API endpoint path.
        The actual serving is handled by the API layer.
        """
        # Parse the storage path to get the components
        if storage_path.startswith("local://"):
            relative_path = storage_path[8:]
        else:
            relative_path = storage_path

        # Return the API endpoint for downloading
        # The file_id is extracted from the path: {workspace_id}/{file_id}/{filename}
        parts = relative_path.split("/")
        if len(parts) >= 2:
            file_id = parts[1]  # Second component is file_id
            return f"/api/workspace/files/{file_id}/download"
        else:
            raise ValueError(f"Invalid storage path format: {storage_path}")


# Global storage backend instance
_workspace_storage: Optional[WorkspaceStorageBackend] = None
_storage_lock = asyncio.Lock()


async def get_workspace_storage() -> WorkspaceStorageBackend:
    """
    Get the workspace storage backend instance.

    Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
    """
    global _workspace_storage

    if _workspace_storage is None:
        async with _storage_lock:
            if _workspace_storage is None:
                config = Config()

                if config.media_gcs_bucket_name:
                    logger.info(
                        f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
                    )
                    _workspace_storage = GCSWorkspaceStorage(
                        config.media_gcs_bucket_name
                    )
                else:
                    storage_dir = (
                        config.workspace_storage_dir
                        if config.workspace_storage_dir
                        else None
                    )
                    logger.info(
                        f"Using local workspace storage: {storage_dir or 'default'}"
                    )
                    _workspace_storage = LocalWorkspaceStorage(storage_dir)

    return _workspace_storage


async def shutdown_workspace_storage() -> None:
    """
    Properly shutdown the global workspace storage backend.

    Closes aiohttp sessions and other resources for GCS backend.
    Should be called during application shutdown.
    """
    global _workspace_storage

    if _workspace_storage is not None:
        async with _storage_lock:
            if _workspace_storage is not None:
                if isinstance(_workspace_storage, GCSWorkspaceStorage):
                    await _workspace_storage.close()
                _workspace_storage = None


def compute_file_checksum(content: bytes) -> str:
    """Compute SHA256 checksum of file content."""
    return hashlib.sha256(content).hexdigest()
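Because the checksum is a plain SHA-256 hex digest, integrity can be re-verified with nothing beyond the standard library:

```python
import hashlib

content = b"example payload"
assert compute_file_checksum(content) == hashlib.sha256(content).hexdigest()
```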
@@ -1,2 +0,0 @@
-- AlterEnum
ALTER TYPE "OnboardingStep" ADD VALUE 'VISIT_COPILOT';
@@ -1,52 +0,0 @@
-- CreateEnum
CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT');

-- CreateTable
CREATE TABLE "UserWorkspace" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "userId" TEXT NOT NULL,

    CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "UserWorkspaceFile" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "workspaceId" TEXT NOT NULL,
    "name" TEXT NOT NULL,
    "path" TEXT NOT NULL,
    "storagePath" TEXT NOT NULL,
    "mimeType" TEXT NOT NULL,
    "sizeBytes" BIGINT NOT NULL,
    "checksum" TEXT,
    "isDeleted" BOOLEAN NOT NULL DEFAULT false,
    "deletedAt" TIMESTAMP(3),
    "source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD',
    "sourceExecId" TEXT,
    "sourceSessionId" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId");

-- CreateIndex
CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId");

-- CreateIndex
CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted");

-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path");

-- AddForeignKey
ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE;
@@ -63,7 +63,6 @@ model User {
  IntegrationWebhooks IntegrationWebhook[]
  NotificationBatches UserNotificationBatch[]
  PendingHumanReviews PendingHumanReview[]
  Workspace           UserWorkspace?

  // OAuth Provider relations
  OAuthApplications OAuthApplication[]
@@ -82,7 +81,6 @@ enum OnboardingStep {
  AGENT_INPUT
  CONGRATS
  // First Wins
  VISIT_COPILOT
  GET_RESULTS
  MARKETPLACE_VISIT
  MARKETPLACE_ADD_AGENT
@@ -138,66 +136,6 @@ model CoPilotUnderstanding {
  @@index([userId])
}

////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////
////////////////    USER WORKSPACE TABLES    /////////////////
////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////

// User's persistent file storage workspace
model UserWorkspace {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  userId String @unique
  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)

  Files UserWorkspaceFile[]

  @@index([userId])
}

// Source of workspace file creation
enum WorkspaceFileSource {
  UPLOAD    // Direct user upload
  EXECUTION // Created by graph execution
  COPILOT   // Created by CoPilot session
  IMPORT    // Imported from external source
}

// Individual files in a user's workspace
model UserWorkspaceFile {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  workspaceId String
  Workspace   UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)

  // File metadata
  name        String  // User-visible filename
  path        String  // Virtual path (e.g., "/documents/report.pdf")
  storagePath String  // Actual GCS or local storage path
  mimeType    String
  sizeBytes   BigInt
  checksum    String? // SHA256 for integrity

  // File state
  isDeleted Boolean   @default(false)
  deletedAt DateTime?

  // Source tracking
  source          WorkspaceFileSource @default(UPLOAD)
  sourceExecId    String? // graph_exec_id if from execution
  sourceSessionId String? // chat_session_id if from CoPilot

  metadata Json @default("{}")

  @@unique([workspaceId, path])
  @@index([workspaceId, isDeleted])
}

model BuilderSearchHistory {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())

@@ -1,10 +1,12 @@
"use client";

import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Text } from "@/components/atoms/Text/Text";
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
import type { ReactNode } from "react";
import { useEffect } from "react";
import { useCopilotStore } from "../../copilot-page-store";
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
import { LoadingState } from "./components/LoadingState/LoadingState";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
import { useCopilotShell } from "./useCopilotShell";
@@ -18,21 +20,38 @@ export function CopilotShell({ children }: Props) {
    isMobile,
    isDrawerOpen,
    isLoading,
    isCreatingSession,
    isLoggedIn,
    hasActiveSession,
    sessions,
    currentSessionId,
    handleSelectSession,
    handleOpenDrawer,
    handleCloseDrawer,
    handleDrawerOpenChange,
    handleNewChatClick,
    handleSessionClick,
    handleNewChat,
    hasNextPage,
    isFetchingNextPage,
    fetchNextPage,
    isReadyToShowContent,
  } = useCopilotShell();

  const setNewChatHandler = useCopilotStore((s) => s.setNewChatHandler);
  const requestNewChat = useCopilotStore((s) => s.requestNewChat);

  useEffect(
    function registerNewChatHandler() {
      setNewChatHandler(handleNewChat);
      return function cleanup() {
        setNewChatHandler(null);
      };
    },
    [handleNewChat],
  );

  function handleNewChatClick() {
    requestNewChat();
  }

  if (!isLoggedIn) {
    return (
      <div className="flex h-full items-center justify-center">
@@ -53,7 +72,7 @@ export function CopilotShell({ children }: Props) {
          isLoading={isLoading}
          hasNextPage={hasNextPage}
          isFetchingNextPage={isFetchingNextPage}
          onSelectSession={handleSessionClick}
          onSelectSession={handleSelectSession}
          onFetchNextPage={fetchNextPage}
          onNewChat={handleNewChatClick}
          hasActiveSession={Boolean(hasActiveSession)}
@@ -63,18 +82,7 @@ export function CopilotShell({ children }: Props) {
        <div className="relative flex min-h-0 flex-1 flex-col">
          {isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
          <div className="flex min-h-0 flex-1 flex-col">
            {isCreatingSession ? (
              <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
                <div className="flex flex-col items-center gap-4">
                  <ChatLoader />
                  <Text variant="body" className="text-zinc-500">
                    Creating your chat...
                  </Text>
                </div>
              </div>
            ) : (
              children
            )}
            {isReadyToShowContent ? children : <LoadingState />}
          </div>
        </div>

@@ -86,7 +94,7 @@ export function CopilotShell({ children }: Props) {
          isLoading={isLoading}
          hasNextPage={hasNextPage}
          isFetchingNextPage={isFetchingNextPage}
          onSelectSession={handleSessionClick}
          onSelectSession={handleSelectSession}
          onFetchNextPage={fetchNextPage}
          onNewChat={handleNewChatClick}
          onClose={handleCloseDrawer}

@@ -0,0 +1,15 @@
import { Text } from "@/components/atoms/Text/Text";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";

export function LoadingState() {
  return (
    <div className="flex flex-1 items-center justify-center">
      <div className="flex flex-col items-center gap-4">
        <ChatLoader />
        <Text variant="body" className="text-zinc-500">
          Loading your chats...
        </Text>
      </div>
    </div>
  );
}
@@ -3,17 +3,17 @@ import { useState } from "react";
export function useMobileDrawer() {
  const [isDrawerOpen, setIsDrawerOpen] = useState(false);

  const handleOpenDrawer = () => {
  function handleOpenDrawer() {
    setIsDrawerOpen(true);
  };
  }

  const handleCloseDrawer = () => {
  function handleCloseDrawer() {
    setIsDrawerOpen(false);
  };
  }

  const handleDrawerOpenChange = (open: boolean) => {
  function handleDrawerOpenChange(open: boolean) {
    setIsDrawerOpen(open);
  };
  }

  return {
    isDrawerOpen,

@@ -1,6 +1,11 @@
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import {
  getGetV2ListSessionsQueryKey,
  useGetV2ListSessions,
} from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { okData } from "@/app/api/helpers";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useState } from "react";

const PAGE_SIZE = 50;
@@ -11,12 +16,12 @@ export interface UseSessionsPaginationArgs {

export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
  const [offset, setOffset] = useState(0);

  const [accumulatedSessions, setAccumulatedSessions] = useState<
    SessionSummaryResponse[]
  >([]);

  const [totalCount, setTotalCount] = useState<number | null>(null);
  const queryClient = useQueryClient();
  const onStreamComplete = useChatStore((state) => state.onStreamComplete);

  const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
    { limit: PAGE_SIZE, offset },
@@ -27,23 +32,38 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
    },
  );

  useEffect(() => {
    const responseData = okData(data);
    if (responseData) {
      const newSessions = responseData.sessions;
      const total = responseData.total;
      setTotalCount(total);

      if (offset === 0) {
        setAccumulatedSessions(newSessions);
      } else {
        setAccumulatedSessions((prev) => [...prev, ...newSessions]);
      }
    } else if (!enabled) {
  useEffect(function refreshOnStreamComplete() {
    const unsubscribe = onStreamComplete(function handleStreamComplete() {
      setOffset(0);
      setAccumulatedSessions([]);
      setTotalCount(null);
    }
  }, [data, offset, enabled]);
      queryClient.invalidateQueries({
        queryKey: getGetV2ListSessionsQueryKey(),
      });
    });
    return unsubscribe;
  }, []);

  useEffect(
    function updateSessionsFromResponse() {
      const responseData = okData(data);
      if (responseData) {
        const newSessions = responseData.sessions;
        const total = responseData.total;
        setTotalCount(total);

        if (offset === 0) {
          setAccumulatedSessions(newSessions);
        } else {
          setAccumulatedSessions((prev) => [...prev, ...newSessions]);
        }
      } else if (!enabled) {
        setAccumulatedSessions([]);
        setTotalCount(null);
      }
    },
    [data, offset, enabled],
  );

  const hasNextPage =
    totalCount !== null && accumulatedSessions.length < totalCount;
@@ -66,17 +86,17 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
    }
  }, [hasNextPage, isFetching, isLoading, isError, totalCount]);

  const fetchNextPage = () => {
  function fetchNextPage() {
    if (hasNextPage && !isFetching) {
      setOffset((prev) => prev + PAGE_SIZE);
    }
  };
  }

  const reset = () => {
  function reset() {
    setOffset(0);
    setAccumulatedSessions([]);
    setTotalCount(null);
  };
  }

  return {
    sessions: accumulatedSessions,

@@ -104,3 +104,76 @@ export function mergeCurrentSessionIntoList(
export function getCurrentSessionId(searchParams: URLSearchParams) {
  return searchParams.get("sessionId");
}

export function shouldAutoSelectSession(
  areAllSessionsLoaded: boolean,
  hasAutoSelectedSession: boolean,
  paramSessionId: string | null,
  visibleSessions: SessionSummaryResponse[],
  accumulatedSessions: SessionSummaryResponse[],
  isLoading: boolean,
  totalCount: number | null,
) {
  if (!areAllSessionsLoaded || hasAutoSelectedSession) {
    return {
      shouldSelect: false,
      sessionIdToSelect: null,
      shouldCreate: false,
    };
  }

  if (paramSessionId) {
    return {
      shouldSelect: false,
      sessionIdToSelect: null,
      shouldCreate: false,
    };
  }

  if (visibleSessions.length > 0) {
    return {
      shouldSelect: true,
      sessionIdToSelect: visibleSessions[0].id,
      shouldCreate: false,
    };
  }

  if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
    return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
  }

  if (totalCount === 0) {
    return {
      shouldSelect: false,
      sessionIdToSelect: null,
      shouldCreate: false,
    };
  }

  return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
}

export function checkReadyToShowContent(
  areAllSessionsLoaded: boolean,
  paramSessionId: string | null,
  accumulatedSessions: SessionSummaryResponse[],
  isCurrentSessionLoading: boolean,
  currentSessionData: SessionDetailResponse | null | undefined,
  hasAutoSelectedSession: boolean,
) {
  if (!areAllSessionsLoaded) return false;

  if (paramSessionId) {
    const sessionFound = accumulatedSessions.some(
      (s) => s.id === paramSessionId,
    );
    return (
      sessionFound ||
      (!isCurrentSessionLoading &&
        currentSessionData !== undefined &&
        currentSessionData !== null)
    );
  }

  return hasAutoSelectedSession;
}

@@ -1,22 +1,26 @@
"use client";

import {
  getGetV2GetSessionQueryKey,
  getGetV2ListSessionsQueryKey,
  useGetV2GetSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { okData } from "@/app/api/helpers";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import { parseAsString, useQueryState } from "nuqs";
import { usePathname, useSearchParams } from "next/navigation";
import { useRef } from "react";
import { useCopilotStore } from "../../copilot-page-store";
import { useCopilotSessionId } from "../../useCopilotSessionId";
import { useEffect, useRef, useState } from "react";
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
import { getCurrentSessionId } from "./helpers";
import { useShellSessionList } from "./useShellSessionList";
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
import {
  checkReadyToShowContent,
  convertSessionDetailToSummary,
  filterVisibleSessions,
  getCurrentSessionId,
  mergeCurrentSessionIntoList,
} from "./helpers";

export function useCopilotShell() {
  const pathname = usePathname();
@@ -27,7 +31,7 @@ export function useCopilotShell() {
  const isMobile =
    breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";

  const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
  const [, setUrlSessionId] = useQueryState("sessionId", parseAsString);

  const isOnHomepage = pathname === "/copilot";
  const paramSessionId = searchParams.get("sessionId");
@@ -41,80 +45,123 @@ export function useCopilotShell() {

  const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;

  const {
    sessions: accumulatedSessions,
    isLoading: isSessionsLoading,
    isFetching: isSessionsFetching,
    hasNextPage,
    areAllSessionsLoaded,
    fetchNextPage,
    reset: resetPagination,
  } = useSessionsPagination({
    enabled: paginationEnabled,
  });

  const currentSessionId = getCurrentSessionId(searchParams);

  const { data: currentSessionData } = useGetV2GetSession(
    currentSessionId || "",
    {
  const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
    useGetV2GetSession(currentSessionId || "", {
      query: {
        enabled: !!currentSessionId,
        select: okData,
      },
    },
  );

  const {
    sessions,
    isLoading,
    isSessionsFetching,
    hasNextPage,
    fetchNextPage,
    resetPagination,
    recentlyCreatedSessionsRef,
  } = useShellSessionList({
    paginationEnabled,
    currentSessionId,
    currentSessionData,
    isOnHomepage,
    paramSessionId,
  });

  const stopStream = useChatStore((s) => s.stopStream);
  const onStreamComplete = useChatStore((s) => s.onStreamComplete);
  const isStreaming = useCopilotStore((s) => s.isStreaming);
  const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
  const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
  const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);

  const pendingActionRef = useRef<(() => void) | null>(null);

  async function stopCurrentStream() {
    if (!currentSessionId) return;

    setIsSwitchingSession(true);
    await new Promise<void>((resolve) => {
      const unsubscribe = onStreamComplete((completedId) => {
        if (completedId === currentSessionId) {
          clearTimeout(timeout);
          unsubscribe();
          resolve();
        }
      });
      const timeout = setTimeout(() => {
        unsubscribe();
        resolve();
      }, 3000);
      stopStream(currentSessionId);
    });

    queryClient.invalidateQueries({
      queryKey: getGetV2GetSessionQueryKey(currentSessionId),
    });
    setIsSwitchingSession(false);
  }
  const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
  const hasAutoSelectedRef = useRef(false);
  const recentlyCreatedSessionsRef = useRef<
    Map<string, SessionSummaryResponse>
  >(new Map());

  function selectSession(sessionId: string) {
    if (sessionId === currentSessionId) return;
    if (recentlyCreatedSessionsRef.current.has(sessionId)) {
  // Mark as auto-selected when sessionId is in URL
  useEffect(() => {
    if (paramSessionId && !hasAutoSelectedRef.current) {
      hasAutoSelectedRef.current = true;
      setHasAutoSelectedSession(true);
    }
  }, [paramSessionId]);

  // On homepage without sessionId, mark as ready immediately
  useEffect(() => {
    if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
      hasAutoSelectedRef.current = true;
      setHasAutoSelectedSession(true);
    }
  }, [isOnHomepage, paramSessionId]);

  // Invalidate sessions list when navigating to homepage (to show newly created sessions)
  useEffect(() => {
    if (isOnHomepage && !paramSessionId) {
      queryClient.invalidateQueries({
        queryKey: getGetV2GetSessionQueryKey(sessionId),
        queryKey: getGetV2ListSessionsQueryKey(),
      });
    }
  }, [isOnHomepage, paramSessionId, queryClient]);

  // Track newly created sessions to ensure they stay visible even when switching away
  useEffect(() => {
    if (currentSessionId && currentSessionData) {
      const isNewSession =
        currentSessionData.updated_at === currentSessionData.created_at;
      const isNotInAccumulated = !accumulatedSessions.some(
        (s) => s.id === currentSessionId,
      );
      if (isNewSession || isNotInAccumulated) {
        const summary = convertSessionDetailToSummary(currentSessionData);
        recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
      }
    }
  }, [currentSessionId, currentSessionData, accumulatedSessions]);

  // Clean up recently created sessions that are now in the accumulated list
  useEffect(() => {
    for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
      if (accumulatedSessions.some((s) => s.id === sessionId)) {
        recentlyCreatedSessionsRef.current.delete(sessionId);
      }
    }
  }, [accumulatedSessions]);

  // Reset pagination when query becomes disabled
  const prevPaginationEnabledRef = useRef(paginationEnabled);
  useEffect(() => {
    if (prevPaginationEnabledRef.current && !paginationEnabled) {
      resetPagination();
      resetAutoSelect();
    }
    prevPaginationEnabledRef.current = paginationEnabled;
  }, [paginationEnabled, resetPagination]);

  const sessions = mergeCurrentSessionIntoList(
    accumulatedSessions,
    currentSessionId,
    currentSessionData,
    recentlyCreatedSessionsRef.current,
  );

  const visibleSessions = filterVisibleSessions(sessions);

  const sidebarSelectedSessionId =
    isOnHomepage && !paramSessionId ? null : currentSessionId;
||||
|
||||
const isReadyToShowContent = isOnHomepage
|
||||
? true
|
||||
: checkReadyToShowContent(
|
||||
areAllSessionsLoaded,
|
||||
paramSessionId,
|
||||
accumulatedSessions,
|
||||
isCurrentSessionLoading,
|
||||
currentSessionData,
|
||||
hasAutoSelectedSession,
|
||||
);
|
||||
|
||||
function handleSelectSession(sessionId: string) {
|
||||
setUrlSessionId(sessionId, { shallow: false });
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function startNewChat() {
|
||||
function handleNewChat() {
|
||||
resetAutoSelect();
|
||||
resetPagination();
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
@@ -123,31 +170,12 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
selectSession(sessionId);
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
selectSession(sessionId);
|
||||
}
|
||||
function resetAutoSelect() {
|
||||
hasAutoSelectedRef.current = false;
|
||||
setHasAutoSelectedSession(false);
|
||||
}
|
||||
|
||||
function handleNewChatClick() {
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
startNewChat();
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
startNewChat();
|
||||
}
|
||||
}
|
||||
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
|
||||
|
||||
return {
|
||||
isMobile,
|
||||
@@ -155,17 +183,17 @@ export function useCopilotShell() {
|
||||
isLoggedIn,
|
||||
hasActiveSession:
|
||||
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
|
||||
isLoading: isLoading || isCreatingSession,
|
||||
isCreatingSession,
|
||||
sessions,
|
||||
currentSessionId: urlSessionId,
|
||||
isLoading,
|
||||
sessions: visibleSessions,
|
||||
currentSessionId: sidebarSelectedSessionId,
|
||||
handleSelectSession,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
handleNewChatClick,
|
||||
handleSessionClick,
|
||||
handleNewChat,
|
||||
hasNextPage,
|
||||
isFetchingNextPage: isSessionsFetching,
|
||||
fetchNextPage,
|
||||
isReadyToShowContent,
|
||||
};
|
||||
}
|
||||
|
||||
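The `stopCurrentStream` helper above pairs an event subscription with a timeout, so switching sessions can never hang on a stream that fails to report completion. The same pattern in isolation (a sketch; `waitForEvent` and its parameters are illustrative names, not part of this codebase):

```typescript
// Sketch of the "resolve on matching event or on timeout" pattern used by
// stopCurrentStream. `subscribe` returns an unsubscribe function, as
// onStreamComplete does in the chat store; we assume the callback fires
// asynchronously, so `timer` is assigned before it can be read.
function waitForEvent(
  subscribe: (cb: (id: string) => void) => () => void,
  expectedId: string,
  timeoutMs = 3000,
): Promise<void> {
  return new Promise<void>((resolve) => {
    const unsubscribe = subscribe((id) => {
      if (id === expectedId) {
        clearTimeout(timer);
        unsubscribe();
        resolve();
      }
    });
    // Fallback so callers never wait forever if the event is dropped.
    const timer = setTimeout(() => {
      unsubscribe();
      resolve();
    }, timeoutMs);
  });
}
```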
@@ -1,113 +0,0 @@
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
import { useChatStore } from "@/components/contextual/Chat/chat-store";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useRef } from "react";
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
import {
  convertSessionDetailToSummary,
  filterVisibleSessions,
  mergeCurrentSessionIntoList,
} from "./helpers";

interface UseShellSessionListArgs {
  paginationEnabled: boolean;
  currentSessionId: string | null;
  currentSessionData: SessionDetailResponse | null | undefined;
  isOnHomepage: boolean;
  paramSessionId: string | null;
}

export function useShellSessionList({
  paginationEnabled,
  currentSessionId,
  currentSessionData,
  isOnHomepage,
  paramSessionId,
}: UseShellSessionListArgs) {
  const queryClient = useQueryClient();
  const onStreamComplete = useChatStore((s) => s.onStreamComplete);

  const {
    sessions: accumulatedSessions,
    isLoading: isSessionsLoading,
    isFetching: isSessionsFetching,
    hasNextPage,
    fetchNextPage,
    reset: resetPagination,
  } = useSessionsPagination({
    enabled: paginationEnabled,
  });

  const recentlyCreatedSessionsRef = useRef<
    Map<string, SessionSummaryResponse>
  >(new Map());

  useEffect(() => {
    if (isOnHomepage && !paramSessionId) {
      queryClient.invalidateQueries({
        queryKey: getGetV2ListSessionsQueryKey(),
      });
    }
  }, [isOnHomepage, paramSessionId, queryClient]);

  useEffect(() => {
    if (currentSessionId && currentSessionData) {
      const isNewSession =
        currentSessionData.updated_at === currentSessionData.created_at;
      const isNotInAccumulated = !accumulatedSessions.some(
        (s) => s.id === currentSessionId,
      );
      if (isNewSession || isNotInAccumulated) {
        const summary = convertSessionDetailToSummary(currentSessionData);
        recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
      }
    }
  }, [currentSessionId, currentSessionData, accumulatedSessions]);

  useEffect(() => {
    for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
      if (accumulatedSessions.some((s) => s.id === sessionId)) {
        recentlyCreatedSessionsRef.current.delete(sessionId);
      }
    }
  }, [accumulatedSessions]);

  useEffect(() => {
    const unsubscribe = onStreamComplete(() => {
      queryClient.invalidateQueries({
        queryKey: getGetV2ListSessionsQueryKey(),
      });
    });
    return unsubscribe;
  }, [onStreamComplete, queryClient]);

  const sessions = useMemo(
    () =>
      mergeCurrentSessionIntoList(
        accumulatedSessions,
        currentSessionId,
        currentSessionData,
        recentlyCreatedSessionsRef.current,
      ),
    [accumulatedSessions, currentSessionId, currentSessionData],
  );

  const visibleSessions = useMemo(
    () => filterVisibleSessions(sessions),
    [sessions],
  );

  const isLoading = isSessionsLoading && accumulatedSessions.length === 0;

  return {
    sessions: visibleSessions,
    isLoading,
    isSessionsFetching,
    hasNextPage,
    fetchNextPage,
    resetPagination,
    recentlyCreatedSessionsRef,
  };
}
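The hook leans on helpers from `./helpers` whose bodies are not part of this diff. As a rough sketch of the contract `mergeCurrentSessionIntoList` appears to satisfy at its call sites (an assumption, not the real implementation):

```typescript
// Hypothetical sketch only — the real helper lives in ./helpers and is not
// shown in this diff. Implied contract: the paginated list stays the source
// of truth, but the current session and any recently created sessions are
// present exactly once.
function mergeCurrentSessionIntoListSketch(
  accumulated: SessionSummaryResponse[],
  currentSessionId: string | null,
  currentSessionData: SessionDetailResponse | null | undefined,
  recentlyCreated: Map<string, SessionSummaryResponse>,
): SessionSummaryResponse[] {
  const byId = new Map(
    accumulated.map((s) => [s.id, s] as [string, SessionSummaryResponse]),
  );
  // Recently created sessions that pagination has not caught up with yet.
  for (const [id, summary] of recentlyCreated) {
    if (!byId.has(id)) byId.set(id, summary);
  }
  // The session currently open in the URL, if the list does not have it.
  if (currentSessionId && currentSessionData && !byId.has(currentSessionId)) {
    byId.set(currentSessionId, convertSessionDetailToSummary(currentSessionData));
  }
  return [...byId.values()];
}
```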
@@ -4,53 +4,51 @@ import { create } from "zustand";

interface CopilotStoreState {
  isStreaming: boolean;
  isSwitchingSession: boolean;
  isCreatingSession: boolean;
  isInterruptModalOpen: boolean;
  pendingAction: (() => void) | null;
  isNewChatModalOpen: boolean;
  newChatHandler: (() => void) | null;
}

interface CopilotStoreActions {
  setIsStreaming: (isStreaming: boolean) => void;
  setIsSwitchingSession: (isSwitchingSession: boolean) => void;
  setIsCreatingSession: (isCreating: boolean) => void;
  openInterruptModal: (onConfirm: () => void) => void;
  confirmInterrupt: () => void;
  cancelInterrupt: () => void;
  setNewChatHandler: (handler: (() => void) | null) => void;
  requestNewChat: () => void;
  confirmNewChat: () => void;
  cancelNewChat: () => void;
}

type CopilotStore = CopilotStoreState & CopilotStoreActions;

export const useCopilotStore = create<CopilotStore>((set, get) => ({
  isStreaming: false,
  isSwitchingSession: false,
  isCreatingSession: false,
  isInterruptModalOpen: false,
  pendingAction: null,
  isNewChatModalOpen: false,
  newChatHandler: null,

  setIsStreaming(isStreaming) {
    set({ isStreaming });
  },

  setIsSwitchingSession(isSwitchingSession) {
    set({ isSwitchingSession });
  setNewChatHandler(handler) {
    set({ newChatHandler: handler });
  },

  setIsCreatingSession(isCreatingSession) {
    set({ isCreatingSession });
  requestNewChat() {
    const { isStreaming, newChatHandler } = get();
    if (isStreaming) {
      set({ isNewChatModalOpen: true });
    } else if (newChatHandler) {
      newChatHandler();
    }
  },

  openInterruptModal(onConfirm) {
    set({ isInterruptModalOpen: true, pendingAction: onConfirm });
  confirmNewChat() {
    const { newChatHandler } = get();
    set({ isNewChatModalOpen: false });
    if (newChatHandler) {
      newChatHandler();
    }
  },

  confirmInterrupt() {
    const { pendingAction } = get();
    set({ isInterruptModalOpen: false, pendingAction: null });
    if (pendingAction) pendingAction();
  },

  cancelInterrupt() {
    set({ isInterruptModalOpen: false, pendingAction: null });
  cancelNewChat() {
    set({ isNewChatModalOpen: false });
  },
}));

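A plausible wiring of the new `requestNewChat` flow from a component, inferred from the store API above (illustrative only; `startNewChat` stands in for whatever actually resets the chat):

```typescript
// Illustrative only — shows how the store API above is meant to be composed.
import { useEffect } from "react";
import { useCopilotStore } from "./copilot-page-store";

function useNewChatWiring(startNewChat: () => void) {
  const setNewChatHandler = useCopilotStore((s) => s.setNewChatHandler);
  const requestNewChat = useCopilotStore((s) => s.requestNewChat);

  // Register the handler; requestNewChat either runs it immediately or,
  // if a stream is active, opens the confirmation modal first.
  useEffect(() => {
    setNewChatHandler(startNewChat);
    return () => setNewChatHandler(null);
  }, [setNewChatHandler, startNewChat]);

  return requestNewChat; // call this from the "New chat" button
}
```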
@@ -1,5 +1,28 @@
import type { User } from "@supabase/supabase-js";

export type PageState =
  | { type: "welcome" }
  | { type: "newChat" }
  | { type: "creating"; prompt: string }
  | { type: "chat"; sessionId: string; initialPrompt?: string };

export function getInitialPromptFromState(
  pageState: PageState,
  storedInitialPrompt: string | undefined,
) {
  if (storedInitialPrompt) return storedInitialPrompt;
  if (pageState.type === "creating") return pageState.prompt;
  if (pageState.type === "chat") return pageState.initialPrompt;
}

export function shouldResetToWelcome(pageState: PageState) {
  return (
    pageState.type !== "newChat" &&
    pageState.type !== "creating" &&
    pageState.type !== "welcome"
  );
}

export function getGreetingName(user?: User | null): string {
  if (!user) return "there";
  const metadata = user.user_metadata as Record<string, unknown> | undefined;

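A quick illustration of how the `PageState` discriminated union narrows in these helpers (values are made up):

```typescript
// Made-up values, just to show the narrowing behavior of the helpers above.
const creating: PageState = { type: "creating", prompt: "build me an agent" };
getInitialPromptFromState(creating, undefined); // -> "build me an agent"
getInitialPromptFromState(creating, "stored");  // stored prompt wins -> "stored"
shouldResetToWelcome(creating);                 // -> false (creation in flight)
shouldResetToWelcome({ type: "chat", sessionId: "s1" }); // -> true
```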
@@ -1,25 +1,25 @@
"use client";

import { Button } from "@/components/atoms/Button/Button";

import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useCopilotStore } from "./copilot-page-store";
import { useCopilotPage } from "./useCopilotPage";

export default function CopilotPage() {
  const { state, handlers } = useCopilotPage();
  const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
  const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
  const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
  const confirmNewChat = useCopilotStore((s) => s.confirmNewChat);
  const {
    greetingName,
    quickActions,
    isLoading,
    hasSession,
    initialPrompt,
    pageState,
    isNewChatModalOpen,
    isReady,
  } = state;
  const {
@@ -27,16 +27,20 @@ export default function CopilotPage() {
    startChatWithPrompt,
    handleSessionNotFound,
    handleStreamingChange,
    handleCancelNewChat,
    handleNewChatModalOpen,
  } = handlers;

  if (!isReady) return null;

  if (hasSession) {
  if (pageState.type === "chat") {
    return (
      <div className="flex h-full flex-col">
        <Chat
          key={pageState.sessionId ?? "welcome"}
          className="flex-1"
          initialPrompt={initialPrompt}
          urlSessionId={pageState.sessionId}
          initialPrompt={pageState.initialPrompt}
          onSessionNotFound={handleSessionNotFound}
          onStreamingChange={handleStreamingChange}
        />
@@ -44,33 +48,31 @@ export default function CopilotPage() {
          title="Interrupt current chat?"
          styling={{ maxWidth: 300, width: "100%" }}
          controlled={{
            isOpen: isInterruptModalOpen,
            set: (open) => {
              if (!open) cancelInterrupt();
            },
            isOpen: isNewChatModalOpen,
            set: handleNewChatModalOpen,
          }}
          onClose={cancelInterrupt}
          onClose={handleCancelNewChat}
        >
          <Dialog.Content>
            <div className="flex flex-col gap-4">
              <Text variant="body">
                The current chat response will be interrupted. Are you sure you
                want to continue?
                want to start a new chat?
              </Text>
              <Dialog.Footer>
                <Button
                  type="button"
                  variant="outline"
                  onClick={cancelInterrupt}
                  onClick={handleCancelNewChat}
                >
                  Cancel
                </Button>
                <Button
                  type="button"
                  variant="primary"
                  onClick={confirmInterrupt}
                  onClick={confirmNewChat}
                >
                  Continue
                  Start new chat
                </Button>
              </Dialog.Footer>
            </div>
@@ -80,6 +82,19 @@ export default function CopilotPage() {
    );
  }

  if (pageState.type === "newChat" || pageState.type === "creating") {
    return (
      <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
        <div className="flex flex-col items-center gap-4">
          <ChatLoader />
          <Text variant="body" className="text-zinc-500">
            Loading your chats...
          </Text>
        </div>
      </div>
    );
  }

  return (
    <div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
      <div className="w-full text-center">

@@ -5,40 +5,79 @@ import {
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import {
  Flag,
  type FlagValues,
  useGetFlag,
} from "@/services/feature-flags/use-get-flag";
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
import * as Sentry from "@sentry/nextjs";
import { useQueryClient } from "@tanstack/react-query";
import { useFlags } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { useEffect, useReducer } from "react";
import { useCopilotStore } from "./copilot-page-store";
import { getGreetingName, getQuickActions } from "./helpers";
import { useCopilotSessionId } from "./useCopilotSessionId";
import { getGreetingName, getQuickActions, type PageState } from "./helpers";
import { useCopilotURLState } from "./useCopilotURLState";

type CopilotState = {
  pageState: PageState;
  initialPrompts: Record<string, string>;
  previousSessionId: string | null;
};

type CopilotAction =
  | { type: "setPageState"; pageState: PageState }
  | { type: "setInitialPrompt"; sessionId: string; prompt: string }
  | { type: "setPreviousSessionId"; sessionId: string | null };

function isSamePageState(next: PageState, current: PageState) {
  if (next.type !== current.type) return false;
  if (next.type === "creating" && current.type === "creating") {
    return next.prompt === current.prompt;
  }
  if (next.type === "chat" && current.type === "chat") {
    return (
      next.sessionId === current.sessionId &&
      next.initialPrompt === current.initialPrompt
    );
  }
  return true;
}

function copilotReducer(
  state: CopilotState,
  action: CopilotAction,
): CopilotState {
  if (action.type === "setPageState") {
    if (isSamePageState(action.pageState, state.pageState)) return state;
    return { ...state, pageState: action.pageState };
  }
  if (action.type === "setInitialPrompt") {
    if (state.initialPrompts[action.sessionId] === action.prompt) return state;
    return {
      ...state,
      initialPrompts: {
        ...state.initialPrompts,
        [action.sessionId]: action.prompt,
      },
    };
  }
  if (action.type === "setPreviousSessionId") {
    if (state.previousSessionId === action.sessionId) return state;
    return { ...state, previousSessionId: action.sessionId };
  }
  return state;
}

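Worth noting about `copilotReducer`: every action bails out with the existing `state` reference when nothing would change (via `isSamePageState` and the equality checks), which lets React skip re-rendering on redundant dispatches. An illustrative check with hypothetical values:

```typescript
// Hypothetical values; demonstrates the bail-out behavior of copilotReducer.
const state: CopilotState = {
  pageState: { type: "welcome" },
  initialPrompts: {},
  previousSessionId: null,
};
const next = copilotReducer(state, {
  type: "setPageState",
  pageState: { type: "welcome" },
});
console.assert(next === state); // same reference -> React can bail out
```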
export function useCopilotPage() {
  const router = useRouter();
  const queryClient = useQueryClient();
  const { user, isLoggedIn, isUserLoading } = useSupabase();
  const { toast } = useToast();
  const { completeStep } = useOnboarding();

  const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
  const isNewChatModalOpen = useCopilotStore((s) => s.isNewChatModalOpen);
  const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
  const isCreating = useCopilotStore((s) => s.isCreatingSession);
  const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);

  // Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
  useEffect(() => {
    if (isLoggedIn) {
      completeStep("VISIT_COPILOT");
    }
  }, [completeStep, isLoggedIn]);
  const cancelNewChat = useCopilotStore((s) => s.cancelNewChat);

  const isChatEnabled = useGetFlag(Flag.CHAT);
  const flags = useFlags<FlagValues>();
@@ -49,27 +88,72 @@ export function useCopilotPage() {
  const isFlagReady =
    !isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;

  const [state, dispatch] = useReducer(copilotReducer, {
    pageState: { type: "welcome" },
    initialPrompts: {},
    previousSessionId: null,
  });

  const greetingName = getGreetingName(user);
  const quickActions = getQuickActions();

  const hasSession = Boolean(urlSessionId);
  const initialPrompt = urlSessionId
    ? getInitialPrompt(urlSessionId)
    : undefined;
  function setPageState(pageState: PageState) {
    dispatch({ type: "setPageState", pageState });
  }

  useEffect(() => {
    if (!isFlagReady) return;
    if (isChatEnabled === false) {
      router.replace(homepageRoute);
    }
  }, [homepageRoute, isChatEnabled, isFlagReady, router]);
  function setInitialPrompt(sessionId: string, prompt: string) {
    dispatch({ type: "setInitialPrompt", sessionId, prompt });
  }

  function setPreviousSessionId(sessionId: string | null) {
    dispatch({ type: "setPreviousSessionId", sessionId });
  }

  const { setUrlSessionId } = useCopilotURLState({
    pageState: state.pageState,
    initialPrompts: state.initialPrompts,
    previousSessionId: state.previousSessionId,
    setPageState,
    setInitialPrompt,
    setPreviousSessionId,
  });

  useEffect(
    function transitionNewChatToWelcome() {
      if (state.pageState.type === "newChat") {
        function setWelcomeState() {
          dispatch({ type: "setPageState", pageState: { type: "welcome" } });
        }

        const timer = setTimeout(setWelcomeState, 300);

        return function cleanup() {
          clearTimeout(timer);
        };
      }
    },
    [state.pageState.type],
  );

  useEffect(
    function ensureAccess() {
      if (!isFlagReady) return;
      if (isChatEnabled === false) {
        router.replace(homepageRoute);
      }
    },
    [homepageRoute, isChatEnabled, isFlagReady, router],
  );

  async function startChatWithPrompt(prompt: string) {
    if (!prompt?.trim()) return;
    if (isCreating) return;
    if (state.pageState.type === "creating") return;

    const trimmedPrompt = prompt.trim();
    setIsCreating(true);
    dispatch({
      type: "setPageState",
      pageState: { type: "creating", prompt: trimmedPrompt },
    });

    try {
      const sessionResponse = await postV2CreateSession({
@@ -81,19 +165,27 @@ export function useCopilotPage() {
      }

      const sessionId = sessionResponse.data.id;
      setInitialPrompt(sessionId, trimmedPrompt);

      dispatch({
        type: "setInitialPrompt",
        sessionId,
        prompt: trimmedPrompt,
      });

      await queryClient.invalidateQueries({
        queryKey: getGetV2ListSessionsQueryKey(),
      });

      await setUrlSessionId(sessionId, { shallow: true });
      await setUrlSessionId(sessionId, { shallow: false });
      dispatch({
        type: "setPageState",
        pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt },
      });
    } catch (error) {
      console.error("[CopilotPage] Failed to start chat:", error);
      toast({ title: "Failed to start chat", variant: "destructive" });
      Sentry.captureException(error);
    } finally {
      setIsCreating(false);
      dispatch({ type: "setPageState", pageState: { type: "welcome" } });
    }
  }

@@ -109,13 +201,21 @@ export function useCopilotPage() {
    setIsStreaming(isStreamingValue);
  }

  function handleCancelNewChat() {
    cancelNewChat();
  }

  function handleNewChatModalOpen(isOpen: boolean) {
    if (!isOpen) cancelNewChat();
  }

  return {
    state: {
      greetingName,
      quickActions,
      isLoading: isUserLoading,
      hasSession,
      initialPrompt,
      pageState: state.pageState,
      isNewChatModalOpen,
      isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
    },
    handlers: {
@@ -123,32 +223,8 @@ export function useCopilotPage() {
      startChatWithPrompt,
      handleSessionNotFound,
      handleStreamingChange,
      handleCancelNewChat,
      handleNewChatModalOpen,
    },
  };
}

function getInitialPrompt(sessionId: string): string | undefined {
  try {
    const prompts = JSON.parse(
      sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
    );
    return prompts[sessionId];
  } catch {
    return undefined;
  }
}

function setInitialPrompt(sessionId: string, prompt: string): void {
  try {
    const prompts = JSON.parse(
      sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
    );
    prompts[sessionId] = prompt;
    sessionStorage.set(
      SessionKey.CHAT_INITIAL_PROMPTS,
      JSON.stringify(prompts),
    );
  } catch {
    // Ignore storage errors
  }
}

@@ -1,10 +0,0 @@
import { parseAsString, useQueryState } from "nuqs";

export function useCopilotSessionId() {
  const [urlSessionId, setUrlSessionId] = useQueryState(
    "sessionId",
    parseAsString,
  );

  return { urlSessionId, setUrlSessionId };
}
@@ -0,0 +1,80 @@
import { parseAsString, useQueryState } from "nuqs";
import { useLayoutEffect } from "react";
import {
  getInitialPromptFromState,
  type PageState,
  shouldResetToWelcome,
} from "./helpers";

interface UseCopilotUrlStateArgs {
  pageState: PageState;
  initialPrompts: Record<string, string>;
  previousSessionId: string | null;
  setPageState: (pageState: PageState) => void;
  setInitialPrompt: (sessionId: string, prompt: string) => void;
  setPreviousSessionId: (sessionId: string | null) => void;
}

export function useCopilotURLState({
  pageState,
  initialPrompts,
  previousSessionId,
  setPageState,
  setInitialPrompt,
  setPreviousSessionId,
}: UseCopilotUrlStateArgs) {
  const [urlSessionId, setUrlSessionId] = useQueryState(
    "sessionId",
    parseAsString,
  );

  function syncSessionFromUrl() {
    if (urlSessionId) {
      if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
        setPreviousSessionId(urlSessionId);
        return;
      }

      const storedInitialPrompt = initialPrompts[urlSessionId];
      const currentInitialPrompt = getInitialPromptFromState(
        pageState,
        storedInitialPrompt,
      );

      if (currentInitialPrompt) {
        setInitialPrompt(urlSessionId, currentInitialPrompt);
      }

      setPageState({
        type: "chat",
        sessionId: urlSessionId,
        initialPrompt: currentInitialPrompt,
      });
      setPreviousSessionId(urlSessionId);
      return;
    }

    const wasInChat = previousSessionId !== null && pageState.type === "chat";
    setPreviousSessionId(null);
    if (wasInChat) {
      setPageState({ type: "newChat" });
      return;
    }

    if (shouldResetToWelcome(pageState)) {
      setPageState({ type: "welcome" });
    }
  }

  useLayoutEffect(syncSessionFromUrl, [
    urlSessionId,
    pageState.type,
    previousSessionId,
    initialPrompts,
  ]);

  return {
    urlSessionId,
    setUrlSessionId,
  };
}
@@ -4594,7 +4594,6 @@
        "AGENT_NEW_RUN",
        "AGENT_INPUT",
        "CONGRATS",
        "VISIT_COPILOT",
        "MARKETPLACE_VISIT",
        "BUILDER_OPEN"
      ],
@@ -5928,40 +5927,6 @@
        }
      }
    },
    "/api/workspace/files/{file_id}/download": {
      "get": {
        "tags": ["v2", "workspace"],
        "summary": "Download file by ID",
        "description": "Download a file by its ID.\n\nReturns the file content directly or redirects to a signed URL for GCS.",
        "operationId": "getV2Download file by id",
        "security": [{ "HTTPBearerJWT": [] }],
        "parameters": [
          {
            "name": "file_id",
            "in": "path",
            "required": true,
            "schema": { "type": "string", "title": "File Id" }
          }
        ],
        "responses": {
          "200": {
            "description": "Successful Response",
            "content": { "application/json": { "schema": {} } }
          },
          "401": {
            "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
          },
          "422": {
            "description": "Validation Error",
            "content": {
              "application/json": {
                "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
              }
            }
          }
        }
      }
    },
    "/health": {
      "get": {
        "tags": ["health"],
@@ -8789,7 +8754,6 @@
        "AGENT_NEW_RUN",
        "AGENT_INPUT",
        "CONGRATS",
        "VISIT_COPILOT",
        "GET_RESULTS",
        "MARKETPLACE_VISIT",
        "MARKETPLACE_ADD_AGENT",

@@ -1,6 +1,5 @@
import {
  ApiError,
  getServerAuthToken,
  makeAuthenticatedFileUpload,
  makeAuthenticatedRequest,
} from "@/lib/autogpt-server-api/helpers";
@@ -16,69 +15,6 @@ function buildBackendUrl(path: string[], queryString: string): string {
  return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
}

/**
 * Check if this is a workspace file download request that needs binary response handling.
 */
function isWorkspaceDownloadRequest(path: string[]): boolean {
  // Match pattern: api/workspace/files/{id}/download (5 segments)
  return (
    path.length >= 5 &&
    path[0] === "api" &&
    path[1] === "workspace" &&
    path[2] === "files" &&
    path[path.length - 1] === "download"
  );
}

/**
 * Handle workspace file download requests with proper binary response streaming.
 */
async function handleWorkspaceDownload(
  req: NextRequest,
  backendUrl: string,
): Promise<NextResponse> {
  const token = await getServerAuthToken();

  const headers: Record<string, string> = {};
  if (token && token !== "no-token-found") {
    headers["Authorization"] = `Bearer ${token}`;
  }

  const response = await fetch(backendUrl, {
    method: "GET",
    headers,
    redirect: "follow", // Follow redirects to signed URLs
  });

  if (!response.ok) {
    return NextResponse.json(
      { error: `Failed to download file: ${response.statusText}` },
      { status: response.status },
    );
  }

  // Get the content type from the backend response
  const contentType =
    response.headers.get("Content-Type") || "application/octet-stream";
  const contentDisposition = response.headers.get("Content-Disposition");

  // Stream the response body
  const responseHeaders: Record<string, string> = {
    "Content-Type": contentType,
  };

  if (contentDisposition) {
    responseHeaders["Content-Disposition"] = contentDisposition;
  }

  // Return the binary content
  const arrayBuffer = await response.arrayBuffer();
  return new NextResponse(arrayBuffer, {
    status: 200,
    headers: responseHeaders,
  });
}

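One design note on `handleWorkspaceDownload`: despite the "Stream the response body" comment, it buffers the whole file with `arrayBuffer()` before responding. If memory use on large downloads ever matters, a streaming variant is possible — a sketch, assuming the deployment target supports streaming bodies (standard Fetch `Response` objects do):

```typescript
// Sketch: pipe the upstream body through instead of buffering it.
// NextResponse extends the standard Response, which accepts a ReadableStream,
// so this could replace the arrayBuffer() block at the end of
// handleWorkspaceDownload above.
return new NextResponse(response.body, {
  status: 200,
  headers: responseHeaders,
});
```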
async function handleJsonRequest(
  req: NextRequest,
  method: string,
@@ -244,11 +180,6 @@ async function handler(
  };

  try {
    // Handle workspace file downloads separately (binary response)
    if (method === "GET" && isWorkspaceDownloadRequest(path)) {
      return await handleWorkspaceDownload(req, backendUrl);
    }

    if (method === "GET" || method === "DELETE") {
      responseBody = await handleGetDeleteRequest(method, backendUrl, req);
    } else if (contentType?.includes("application/json")) {

@@ -1,17 +1,16 @@
"use client";

import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { useEffect, useRef } from "react";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
import { ChatLoader } from "./components/ChatLoader/ChatLoader";
import { useChat } from "./useChat";

export interface ChatProps {
  className?: string;
  urlSessionId?: string | null;
  initialPrompt?: string;
  onSessionNotFound?: () => void;
  onStreamingChange?: (isStreaming: boolean) => void;
@@ -19,13 +18,12 @@ export interface ChatProps {

export function Chat({
  className,
  urlSessionId,
  initialPrompt,
  onSessionNotFound,
  onStreamingChange,
}: ChatProps) {
  const { urlSessionId } = useCopilotSessionId();
  const hasHandledNotFoundRef = useRef(false);
  const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
  const {
    messages,
    isLoading,
@@ -35,59 +33,49 @@ export function Chat({
    sessionId,
    createSession,
    showLoader,
    startPollingForOperation,
  } = useChat({ urlSessionId });

  useEffect(() => {
    if (!onSessionNotFound) return;
    if (!urlSessionId) return;
    if (!isSessionNotFound || isLoading || isCreating) return;
    if (hasHandledNotFoundRef.current) return;
    hasHandledNotFoundRef.current = true;
    onSessionNotFound();
  }, [
    onSessionNotFound,
    urlSessionId,
    isSessionNotFound,
    isLoading,
    isCreating,
  ]);

  const shouldShowLoader =
    (showLoader && (isLoading || isCreating)) || isSwitchingSession;
  useEffect(
    function handleMissingSession() {
      if (!onSessionNotFound) return;
      if (!urlSessionId) return;
      if (!isSessionNotFound || isLoading || isCreating) return;
      if (hasHandledNotFoundRef.current) return;
      hasHandledNotFoundRef.current = true;
      onSessionNotFound();
    },
    [onSessionNotFound, urlSessionId, isSessionNotFound, isLoading, isCreating],
  );

  return (
    <div className={cn("flex h-full flex-col", className)}>
      {/* Main Content */}
      <main className="flex min-h-0 w-full flex-1 flex-col overflow-hidden bg-[#f8f8f9]">
        {/* Loading State */}
        {shouldShowLoader && (
        {showLoader && (isLoading || isCreating) && (
          <div className="flex flex-1 items-center justify-center">
            <div className="flex flex-col items-center gap-3">
              <LoadingSpinner size="large" className="text-neutral-400" />
            <div className="flex flex-col items-center gap-4">
              <ChatLoader />
              <Text variant="body" className="text-zinc-500">
                {isSwitchingSession
                  ? "Switching chat..."
                  : "Loading your chat..."}
                Loading your chats...
              </Text>
            </div>
          </div>
        )}

        {/* Error State */}
        {error && !isLoading && !isSwitchingSession && (
        {error && !isLoading && (
          <ChatErrorState error={error} onRetry={createSession} />
        )}

        {/* Session Content */}
        {sessionId && !isLoading && !error && !isSwitchingSession && (
        {sessionId && !isLoading && !error && (
          <ChatContainer
            sessionId={sessionId}
            initialMessages={messages}
            initialPrompt={initialPrompt}
            className="flex-1"
            onStreamingChange={onStreamingChange}
            onOperationStarted={startPollingForOperation}
          />
        )}
      </main>

@@ -58,17 +58,39 @@ function notifyStreamComplete(
  }
}

function cleanupExpiredStreams(
  completedStreams: Map<string, StreamResult>,
): Map<string, StreamResult> {
function cleanupCompletedStreams(completedStreams: Map<string, StreamResult>) {
  const now = Date.now();
  const cleaned = new Map(completedStreams);
  for (const [sessionId, result] of cleaned) {
  for (const [sessionId, result] of completedStreams) {
    if (now - result.completedAt > COMPLETED_STREAM_TTL) {
      cleaned.delete(sessionId);
      completedStreams.delete(sessionId);
    }
  }
  return cleaned;
}

function moveToCompleted(
  activeStreams: Map<string, ActiveStream>,
  completedStreams: Map<string, StreamResult>,
  streamCompleteCallbacks: Set<StreamCompleteCallback>,
  sessionId: string,
) {
  const stream = activeStreams.get(sessionId);
  if (!stream) return;

  const result: StreamResult = {
    sessionId,
    status: stream.status,
    chunks: stream.chunks,
    completedAt: Date.now(),
    error: stream.error,
  };

  completedStreams.set(sessionId, result);
  activeStreams.delete(sessionId);
  cleanupCompletedStreams(completedStreams);

  if (stream.status === "completed" || stream.status === "error") {
    notifyStreamComplete(streamCompleteCallbacks, sessionId);
  }
}

export const useChatStore = create<ChatStore>((set, get) => ({
@@ -84,31 +106,17 @@ export const useChatStore = create<ChatStore>((set, get) => ({
    context,
    onChunk,
  ) {
    const state = get();
    const newActiveStreams = new Map(state.activeStreams);
    let newCompletedStreams = new Map(state.completedStreams);
    const callbacks = state.streamCompleteCallbacks;
    const { activeStreams, completedStreams, streamCompleteCallbacks } = get();

    const existingStream = newActiveStreams.get(sessionId);
    const existingStream = activeStreams.get(sessionId);
    if (existingStream) {
      existingStream.abortController.abort();
      const normalizedStatus =
        existingStream.status === "streaming"
          ? "completed"
          : existingStream.status;
      const result: StreamResult = {
      moveToCompleted(
        activeStreams,
        completedStreams,
        streamCompleteCallbacks,
        sessionId,
        status: normalizedStatus,
        chunks: existingStream.chunks,
        completedAt: Date.now(),
        error: existingStream.error,
      };
      newCompletedStreams.set(sessionId, result);
      newActiveStreams.delete(sessionId);
      newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
      if (normalizedStatus === "completed" || normalizedStatus === "error") {
        notifyStreamComplete(callbacks, sessionId);
      }
      );
    }

    const abortController = new AbortController();
@@ -124,76 +132,36 @@ export const useChatStore = create<ChatStore>((set, get) => ({
      onChunkCallbacks: initialCallbacks,
    };

    newActiveStreams.set(sessionId, stream);
    set({
      activeStreams: newActiveStreams,
      completedStreams: newCompletedStreams,
    });
    activeStreams.set(sessionId, stream);

    try {
      await executeStream(stream, message, isUserMessage, context);
    } finally {
      if (onChunk) stream.onChunkCallbacks.delete(onChunk);
      if (stream.status !== "streaming") {
        const currentState = get();
        const finalActiveStreams = new Map(currentState.activeStreams);
        let finalCompletedStreams = new Map(currentState.completedStreams);

        const storedStream = finalActiveStreams.get(sessionId);
        if (storedStream === stream) {
          const result: StreamResult = {
            sessionId,
            status: stream.status,
            chunks: stream.chunks,
            completedAt: Date.now(),
            error: stream.error,
          };
          finalCompletedStreams.set(sessionId, result);
          finalActiveStreams.delete(sessionId);
          finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
          set({
            activeStreams: finalActiveStreams,
            completedStreams: finalCompletedStreams,
          });
          if (stream.status === "completed" || stream.status === "error") {
            notifyStreamComplete(
              currentState.streamCompleteCallbacks,
              sessionId,
            );
          }
        }
        moveToCompleted(
          activeStreams,
          completedStreams,
          streamCompleteCallbacks,
          sessionId,
        );
      }
    }
  },

  stopStream: function stopStream(sessionId) {
    const state = get();
    const stream = state.activeStreams.get(sessionId);
    if (!stream) return;

    stream.abortController.abort();
    stream.status = "completed";

    const newActiveStreams = new Map(state.activeStreams);
    let newCompletedStreams = new Map(state.completedStreams);

    const result: StreamResult = {
      sessionId,
      status: stream.status,
      chunks: stream.chunks,
      completedAt: Date.now(),
      error: stream.error,
    };
    newCompletedStreams.set(sessionId, result);
    newActiveStreams.delete(sessionId);
    newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);

    set({
      activeStreams: newActiveStreams,
      completedStreams: newCompletedStreams,
    });

    notifyStreamComplete(state.streamCompleteCallbacks, sessionId);
    const { activeStreams, completedStreams, streamCompleteCallbacks } = get();
    const stream = activeStreams.get(sessionId);
    if (stream) {
      stream.abortController.abort();
      stream.status = "completed";
      moveToCompleted(
        activeStreams,
        completedStreams,
        streamCompleteCallbacks,
        sessionId,
      );
    }
  },

  subscribeToStream: function subscribeToStream(
@@ -201,18 +169,16 @@ export const useChatStore = create<ChatStore>((set, get) => ({
    onChunk,
    skipReplay = false,
  ) {
    const state = get();
    const stream = state.activeStreams.get(sessionId);
    const { activeStreams } = get();

    const stream = activeStreams.get(sessionId);
    if (stream) {
      if (!skipReplay) {
        for (const chunk of stream.chunks) {
          onChunk(chunk);
        }
      }

      stream.onChunkCallbacks.add(onChunk);

      return function unsubscribe() {
        stream.onChunkCallbacks.delete(onChunk);
      };
@@ -238,12 +204,7 @@ export const useChatStore = create<ChatStore>((set, get) => ({
  },

  clearCompletedStream: function clearCompletedStream(sessionId) {
    const state = get();
    if (!state.completedStreams.has(sessionId)) return;

    const newCompletedStreams = new Map(state.completedStreams);
    newCompletedStreams.delete(sessionId);
    set({ completedStreams: newCompletedStreams });
    get().completedStreams.delete(sessionId);
  },

  isStreaming: function isStreaming(sessionId) {
@@ -252,21 +213,11 @@ export const useChatStore = create<ChatStore>((set, get) => ({
  },

  registerActiveSession: function registerActiveSession(sessionId) {
    const state = get();
    if (state.activeSessions.has(sessionId)) return;

    const newActiveSessions = new Set(state.activeSessions);
    newActiveSessions.add(sessionId);
    set({ activeSessions: newActiveSessions });
    get().activeSessions.add(sessionId);
  },

  unregisterActiveSession: function unregisterActiveSession(sessionId) {
    const state = get();
    if (!state.activeSessions.has(sessionId)) return;

    const newActiveSessions = new Set(state.activeSessions);
    newActiveSessions.delete(sessionId);
    set({ activeSessions: newActiveSessions });
    get().activeSessions.delete(sessionId);
  },

  isSessionActive: function isSessionActive(sessionId) {
@@ -274,16 +225,10 @@ export const useChatStore = create<ChatStore>((set, get) => ({
  },

  onStreamComplete: function onStreamComplete(callback) {
    const state = get();
    const newCallbacks = new Set(state.streamCompleteCallbacks);
    newCallbacks.add(callback);
    set({ streamCompleteCallbacks: newCallbacks });

    const { streamCompleteCallbacks } = get();
    streamCompleteCallbacks.add(callback);
    return function unsubscribe() {
      const currentState = get();
      const cleanedCallbacks = new Set(currentState.streamCompleteCallbacks);
      cleanedCallbacks.delete(callback);
      set({ streamCompleteCallbacks: cleanedCallbacks });
      streamCompleteCallbacks.delete(callback);
    };
  },
}));

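Because the refactored store now mutates `activeStreams`, `completedStreams`, and `streamCompleteCallbacks` in place rather than copying them through `set()`, consumers are expected to react through the callback API rather than by selecting those Maps as reactive state. Typical usage, based on the `onStreamComplete` contract shown above (illustrative):

```typescript
// Illustrative: subscribe for a single completion event, then clean up.
// onStreamComplete returns an unsubscribe function, per the store code above.
const unsubscribe = useChatStore.getState().onStreamComplete((sessionId) => {
  console.log(`stream for ${sessionId} finished`);
  unsubscribe();
});
```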
@@ -16,7 +16,6 @@ export interface ChatContainerProps {
  initialPrompt?: string;
  className?: string;
  onStreamingChange?: (isStreaming: boolean) => void;
  onOperationStarted?: () => void;
}

export function ChatContainer({
@@ -25,7 +24,6 @@ export function ChatContainer({
  initialPrompt,
  className,
  onStreamingChange,
  onOperationStarted,
}: ChatContainerProps) {
  const {
    messages,
@@ -40,7 +38,6 @@ export function ChatContainer({
    sessionId,
    initialMessages,
    initialPrompt,
    onOperationStarted,
  });

  useEffect(() => {

@@ -22,7 +22,6 @@ export interface HandlerDependencies {
  setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
  setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
  sessionId: string;
  onOperationStarted?: () => void;
}

export function isRegionBlockedError(chunk: StreamChunk): boolean {
@@ -49,15 +48,6 @@ export function handleTextEnded(
  const completedText = deps.streamingChunksRef.current.join("");
  if (completedText.trim()) {
    deps.setMessages((prev) => {
      // Check if this exact message already exists to prevent duplicates
      const exists = prev.some(
        (msg) =>
          msg.type === "message" &&
          msg.role === "assistant" &&
          msg.content === completedText,
      );
      if (exists) return prev;

      const assistantMessage: ChatMessageData = {
        type: "message",
        role: "assistant",
@@ -164,11 +154,6 @@ export function handleToolResponse(
    }
    return;
  }
  // Trigger polling when operation_started is received
  if (responseMessage.type === "operation_started") {
    deps.onOperationStarted?.();
  }

  deps.setMessages((prev) => {
    const toolCallIndex = prev.findIndex(
      (msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
@@ -218,24 +203,13 @@ export function handleStreamEnd(
    ]);
  }
  if (completedContent.trim()) {
    deps.setMessages((prev) => {
      // Check if this exact message already exists to prevent duplicates
      const exists = prev.some(
        (msg) =>
          msg.type === "message" &&
          msg.role === "assistant" &&
          msg.content === completedContent,
      );
      if (exists) return prev;

      const assistantMessage: ChatMessageData = {
        type: "message",
        role: "assistant",
        content: completedContent,
        timestamp: new Date(),
      };
      return [...prev, assistantMessage];
    });
    const assistantMessage: ChatMessageData = {
      type: "message",
      role: "assistant",
      content: completedContent,
      timestamp: new Date(),
    };
    deps.setMessages((prev) => [...prev, assistantMessage]);
  }
  deps.setStreamingChunks([]);
  deps.streamingChunksRef.current = [];

@@ -304,7 +304,6 @@ export function parseToolResponse(
  if (isAgentArray(agentsData)) {
    return {
      type: "agent_carousel",
      toolId,
      toolName: "agent_carousel",
      agents: agentsData,
      totalCount: parsedResult.total_count as number | undefined,
@@ -317,7 +316,6 @@ export function parseToolResponse(
  if (responseType === "execution_started") {
    return {
      type: "execution_started",
      toolId,
      toolName: "execution_started",
      executionId: (parsedResult.execution_id as string) || "",
      agentName: (parsedResult.graph_name as string) || undefined,
@@ -343,41 +341,6 @@ export function parseToolResponse(
      timestamp: timestamp || new Date(),
    };
  }
  if (responseType === "operation_started") {
    return {
      type: "operation_started",
      toolName: (parsedResult.tool_name as string) || toolName,
      toolId,
      operationId: (parsedResult.operation_id as string) || "",
      message:
        (parsedResult.message as string) ||
        "Operation started. You can close this tab.",
      timestamp: timestamp || new Date(),
    };
  }
  if (responseType === "operation_pending") {
    return {
      type: "operation_pending",
      toolName: (parsedResult.tool_name as string) || toolName,
      toolId,
      operationId: (parsedResult.operation_id as string) || "",
      message:
        (parsedResult.message as string) ||
        "Operation in progress. Please wait...",
      timestamp: timestamp || new Date(),
    };
  }
  if (responseType === "operation_in_progress") {
    return {
      type: "operation_in_progress",
      toolName: (parsedResult.tool_name as string) || toolName,
      toolCallId: (parsedResult.tool_call_id as string) || toolId,
      message:
        (parsedResult.message as string) ||
        "Operation already in progress. Please wait...",
      timestamp: timestamp || new Date(),
    };
  }
  if (responseType === "need_login") {
    return {
      type: "login_needed",

@@ -14,40 +14,16 @@ import {
|
||||
processInitialMessages,
|
||||
} from "./helpers";
|
||||
|
||||
// Helper to generate deduplication key for a message
|
||||
function getMessageKey(msg: ChatMessageData): string {
|
||||
if (msg.type === "message") {
|
||||
// Don't include timestamp - dedupe by role + content only
|
||||
// This handles the case where local and server timestamps differ
|
||||
// Server messages are authoritative, so duplicates from local state are filtered
|
||||
return `msg:${msg.role}:${msg.content}`;
|
||||
} else if (msg.type === "tool_call") {
|
||||
return `toolcall:${msg.toolId}`;
|
||||
} else if (msg.type === "tool_response") {
|
||||
return `toolresponse:${(msg as any).toolId}`;
|
||||
} else if (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
) {
|
||||
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
|
||||
} else {
|
||||
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
||||
}
|
||||
}
|
||||
|
||||
interface Args {
|
||||
sessionId: string | null;
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
initialPrompt?: string;
|
||||
onOperationStarted?: () => void;
|
||||
}
|
||||
|
||||
export function useChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
initialPrompt,
|
||||
onOperationStarted,
|
||||
}: Args) {
|
||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||
@@ -97,102 +73,20 @@ export function useChatContainer({
|
||||
setIsRegionBlockedModalOpen,
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
});
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
const skipReplay = initialMessages.length > 0;
|
||||
return subscribeToStream(sessionId, dispatcher, skipReplay);
|
||||
},
|
||||
[
|
||||
sessionId,
|
||||
stopStreaming,
|
||||
activeStreams,
|
||||
subscribeToStream,
|
||||
onOperationStarted,
|
||||
],
|
||||
[sessionId, stopStreaming, activeStreams, subscribeToStream],
|
||||
);
|
||||
|
||||
// Collect toolIds from completed tool results in initialMessages
|
||||
// Used to filter out operation messages when their results arrive
|
||||
const completedToolIds = useMemo(() => {
|
||||
const processedInitial = processInitialMessages(initialMessages);
|
||||
const ids = new Set<string>();
|
||||
for (const msg of processedInitial) {
|
||||
if (
|
||||
msg.type === "tool_response" ||
|
||||
msg.type === "agent_carousel" ||
|
||||
msg.type === "execution_started"
|
||||
) {
|
||||
const toolId = (msg as any).toolId;
|
||||
if (toolId) {
|
||||
ids.add(toolId);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ids;
|
||||
}, [initialMessages]);
|
||||
|
||||
// Clean up local operation messages when their completed results arrive from polling
|
||||
// This effect runs when completedToolIds changes (i.e., when polling brings new results)
|
||||
useEffect(
|
||||
function cleanupCompletedOperations() {
|
||||
if (completedToolIds.size === 0) return;
|
||||
|
||||
setMessages((prev) => {
|
||||
const filtered = prev.filter((msg) => {
|
||||
if (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
) {
|
||||
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||
if (toolId && completedToolIds.has(toolId)) {
|
||||
return false; // Remove - operation completed
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
// Only update state if something was actually filtered
|
||||
return filtered.length === prev.length ? prev : filtered;
|
||||
});
|
||||
},
|
||||
[completedToolIds],
|
||||
const allMessages = useMemo(
|
||||
() => [...processInitialMessages(initialMessages), ...messages],
|
||||
[initialMessages, messages],
|
||||
);
|
||||
|
||||
  // Combine initial messages from backend with local streaming messages.
  // Server messages maintain correct order; only append truly new local messages
  const allMessages = useMemo(() => {
    const processedInitial = processInitialMessages(initialMessages);

    // Build a set of keys from server messages for deduplication
    const serverKeys = new Set<string>();
    for (const msg of processedInitial) {
      serverKeys.add(getMessageKey(msg));
    }

    // Filter local messages: remove duplicates and completed operation messages
    const newLocalMessages = messages.filter((msg) => {
      // Remove operation messages for completed tools
      if (
        msg.type === "operation_started" ||
        msg.type === "operation_pending" ||
        msg.type === "operation_in_progress"
      ) {
        const toolId = (msg as any).toolId || (msg as any).toolCallId;
        if (toolId && completedToolIds.has(toolId)) {
          return false;
        }
      }
      // Remove messages that already exist in server data
      const key = getMessageKey(msg);
      return !serverKeys.has(key);
    });

    // Server messages first (correct order), then new local messages
    return [...processedInitial, ...newLocalMessages];
  }, [initialMessages, messages, completedToolIds]);
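The dedup above hinges on `getMessageKey` producing the same identity for a locally streamed message and the copy the server later returns. The real helper lives elsewhere in the codebase; a minimal sketch of the idea (the field choices here are assumptions, not the actual implementation):

```ts
// Hypothetical sketch - the actual getMessageKey is defined elsewhere in the repo.
// The key must match between the local copy of a message and the server copy,
// or deduplication silently fails and messages render twice.
function getMessageKey(msg: ChatMessageData): string {
  if ("toolId" in msg && msg.toolId) return `${msg.type}:${msg.toolId}`;
  if (msg.type === "message") return `${msg.type}:${msg.role}:${msg.content}`;
  return `${msg.type}:${JSON.stringify(msg)}`;
}
```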
  async function sendMessage(
    content: string,
    isUserMessage: boolean = true,
@@ -224,7 +118,6 @@ export function useChatContainer({
      setIsRegionBlockedModalOpen,
      sessionId,
      setIsStreamingInitiated,
      onOperationStarted,
    });

    try {
@@ -16,7 +16,6 @@ import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
import { ClarificationQuestionsWidget } from "../ClarificationQuestionsWidget/ClarificationQuestionsWidget";
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
import { PendingOperationWidget } from "../PendingOperationWidget/PendingOperationWidget";
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
@@ -72,9 +71,6 @@ export function ChatMessage({
    isLoginNeeded,
    isCredentialsNeeded,
    isClarificationNeeded,
    isOperationStarted,
    isOperationPending,
    isOperationInProgress,
  } = useChatMessage(message);
  const displayContent = getDisplayContent(message, isUser);

@@ -130,6 +126,10 @@ export function ChatMessage({
    [displayContent, message],
  );

  function isLongResponse(content: string): boolean {
    return content.split("\n").length > 5;
  }

  const handleTryAgain = useCallback(() => {
    if (message.type !== "message" || !onSendMessage) return;
    onSendMessage(message.content, message.role === "user");
@@ -294,42 +294,6 @@ export function ChatMessage({
    );
  }

  // Render operation_started messages (long-running background operations)
  if (isOperationStarted && message.type === "operation_started") {
    return (
      <PendingOperationWidget
        status="started"
        message={message.message}
        toolName={message.toolName}
        className={className}
      />
    );
  }

  // Render operation_pending messages (operations in progress when refreshing)
  if (isOperationPending && message.type === "operation_pending") {
    return (
      <PendingOperationWidget
        status="pending"
        message={message.message}
        toolName={message.toolName}
        className={className}
      />
    );
  }

  // Render operation_in_progress messages (duplicate request while operation running)
  if (isOperationInProgress && message.type === "operation_in_progress") {
    return (
      <PendingOperationWidget
        status="in_progress"
        message={message.message}
        toolName={message.toolName}
        className={className}
      />
    );
  }

  // Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
  if (isToolResponse && message.type === "tool_response") {
    return (
@@ -394,7 +358,7 @@ export function ChatMessage({
            <ArrowsClockwiseIcon className="size-4 text-zinc-600" />
          </Button>
        )}
        {!isUser && isFinalMessage && !isStreaming && (
        {!isUser && isFinalMessage && isLongResponse(displayContent) && (
          <Button
            variant="ghost"
            size="icon"
@@ -61,7 +61,6 @@ export type ChatMessageData =
    }
  | {
      type: "agent_carousel";
      toolId: string;
      toolName: string;
      agents: Array<{
        id: string;
@@ -75,7 +74,6 @@ export type ChatMessageData =
    }
  | {
      type: "execution_started";
      toolId: string;
      toolName: string;
      executionId: string;
      agentName?: string;
@@ -105,29 +103,6 @@ export type ChatMessageData =
      message: string;
      sessionId: string;
      timestamp?: string | Date;
    }
  | {
      type: "operation_started";
      toolName: string;
      toolId: string;
      operationId: string;
      message: string;
      timestamp?: string | Date;
    }
  | {
      type: "operation_pending";
      toolName: string;
      toolId: string;
      operationId: string;
      message: string;
      timestamp?: string | Date;
    }
  | {
      type: "operation_in_progress";
      toolName: string;
      toolCallId: string;
      message: string;
      timestamp?: string | Date;
    };

export function useChatMessage(message: ChatMessageData) {
@@ -149,8 +124,5 @@ export function useChatMessage(message: ChatMessageData) {
    isExecutionStarted: message.type === "execution_started",
    isInputsNeeded: message.type === "inputs_needed",
    isClarificationNeeded: message.type === "clarification_needed",
    isOperationStarted: message.type === "operation_started",
    isOperationPending: message.type === "operation_pending",
    isOperationInProgress: message.type === "operation_in_progress",
  };
}
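To make the new operation variants concrete, an illustrative value for one of them follows; the IDs and message text are made up, but the fields match the type declared above:

```ts
const example: ChatMessageData = {
  type: "operation_started",
  toolName: "create_agent",
  toolId: "tool_abc", // hypothetical id
  operationId: "op_123", // hypothetical id
  message: "Agent creation is running in the background.",
  timestamp: new Date().toISOString(),
};
```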
@@ -30,7 +30,6 @@ export function ClarificationQuestionsWidget({
  className,
}: Props) {
  const [answers, setAnswers] = useState<Record<string, string>>({});
  const [isSubmitted, setIsSubmitted] = useState(false);

  function handleAnswerChange(keyword: string, value: string) {
    setAnswers((prev) => ({ ...prev, [keyword]: value }));
@@ -42,42 +41,11 @@ export function ClarificationQuestionsWidget({
    if (!allAnswered) {
      return;
    }
    setIsSubmitted(true);
    onSubmitAnswers(answers);
  }

  const allAnswered = questions.every((q) => answers[q.keyword]?.trim());

  // Show submitted state after answers are submitted
  if (isSubmitted) {
    return (
      <div
        className={cn(
          "group relative flex w-full justify-start gap-3 px-4 py-3",
          className,
        )}
      >
        <div className="flex w-full max-w-3xl gap-3">
          <div className="flex-shrink-0">
            <div className="flex h-7 w-7 items-center justify-center rounded-lg bg-green-500">
              <CheckCircleIcon className="h-4 w-4 text-white" weight="bold" />
            </div>
          </div>
          <div className="flex min-w-0 flex-1 flex-col">
            <Card className="p-4">
              <Text variant="h4" className="mb-1 text-slate-900">
                Answers submitted
              </Text>
              <Text variant="small" className="text-slate-600">
                Processing your responses...
              </Text>
            </Card>
          </div>
        </div>
      </div>
    );
  }

  return (
    <div
      className={cn(
@@ -1,8 +1,6 @@
"use client";

import { getGetV2DownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
import { cn } from "@/lib/utils";
import { EyeSlash } from "@phosphor-icons/react";
import React from "react";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
@@ -31,88 +29,12 @@ interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
  type?: string;
}

/**
 * Converts a workspace:// URL to a proxy URL that routes through Next.js to the backend.
 * workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
 *
 * Uses the generated API URL helper and routes through the Next.js proxy
 * which handles authentication and proper backend routing.
 */
/**
 * URL transformer for ReactMarkdown.
 * Converts workspace:// URLs to proxy URLs that route through Next.js to the backend.
 * workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
 *
 * This is needed because ReactMarkdown sanitizes URLs and only allows
 * http, https, mailto, and tel protocols by default.
 */
function resolveWorkspaceUrl(src: string): string {
  if (src.startsWith("workspace://")) {
    const fileId = src.replace("workspace://", "");
    // Use the generated API URL helper to get the correct path
    const apiPath = getGetV2DownloadFileByIdUrl(fileId);
    // Route through the Next.js proxy (same pattern as customMutator for client-side)
    return `/api/proxy${apiPath}`;
  }
  return src;
}

/**
 * Check if the image URL is a workspace file (AI cannot see these yet).
 * After URL transformation, workspace files have URLs like /api/proxy/api/workspace/files/...
 */
function isWorkspaceImage(src: string | undefined): boolean {
  return src?.includes("/workspace/files/") ?? false;
}

/**
 * Custom image component that shows an indicator when the AI cannot see the image.
 * Note: src is already transformed by urlTransform, so workspace:// is now /api/workspace/...
 */
function MarkdownImage(props: Record<string, unknown>) {
  const src = props.src as string | undefined;
  const alt = props.alt as string | undefined;

  const aiCannotSee = isWorkspaceImage(src);

  // If no src, show a placeholder
  if (!src) {
    return (
      <span className="my-2 inline-block rounded border border-amber-200 bg-amber-50 px-2 py-1 text-sm text-amber-700">
        [Image: {alt || "missing src"}]
      </span>
    );
  }

  return (
    <span className="relative my-2 inline-block">
      {/* eslint-disable-next-line @next/next/no-img-element */}
      <img
        src={src}
        alt={alt || "Image"}
        className="h-auto max-w-full rounded-md border border-zinc-200"
        loading="lazy"
      />
      {aiCannotSee && (
        <span
          className="absolute bottom-2 right-2 flex items-center gap-1 rounded bg-black/70 px-2 py-1 text-xs text-white"
          title="The AI cannot see this image"
        >
          <EyeSlash size={14} />
          <span>AI cannot see this image</span>
        </span>
      )}
    </span>
  );
}

export function MarkdownContent({ content, className }: MarkdownContentProps) {
  return (
    <div className={cn("markdown-content", className)}>
      <ReactMarkdown
        skipHtml={true}
        remarkPlugins={[remarkGfm]}
        urlTransform={resolveWorkspaceUrl}
        components={{
          code: ({ children, className, ...props }: CodeProps) => {
            const isInline = !className?.includes("language-");
@@ -284,9 +206,6 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
              {children}
            </td>
          ),
          img: ({ src, alt, ...props }) => (
            <MarkdownImage src={src} alt={alt} {...props} />
          ),
        }}
      >
        {content}
@@ -1,109 +0,0 @@
"use client";

import { Card } from "@/components/atoms/Card/Card";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { CircleNotch, CheckCircle, XCircle } from "@phosphor-icons/react";

type OperationStatus =
  | "pending"
  | "started"
  | "in_progress"
  | "completed"
  | "error";

interface Props {
  status: OperationStatus;
  message: string;
  toolName?: string;
  className?: string;
}

function getOperationTitle(toolName?: string): string {
  if (!toolName) return "Operation";
  // Convert tool name to human-readable format
  // e.g., "create_agent" -> "Creating Agent", "edit_agent" -> "Editing Agent"
  if (toolName === "create_agent") return "Creating Agent";
  if (toolName === "edit_agent") return "Editing Agent";
  // Default: capitalize and format tool name
  return toolName
    .split("_")
    .map((word) => word.charAt(0).toUpperCase() + word.slice(1))
    .join(" ");
}

export function PendingOperationWidget({
  status,
  message,
  toolName,
  className,
}: Props) {
  const isPending =
    status === "pending" || status === "started" || status === "in_progress";
  const isCompleted = status === "completed";
  const isError = status === "error";

  const operationTitle = getOperationTitle(toolName);

  return (
    <div
      className={cn(
        "group relative flex w-full justify-start gap-3 px-4 py-3",
        className,
      )}
    >
      <div className="flex w-full max-w-3xl gap-3">
        <div className="flex-shrink-0">
          <div
            className={cn(
              "flex h-7 w-7 items-center justify-center rounded-lg",
              isPending && "bg-blue-500",
              isCompleted && "bg-green-500",
              isError && "bg-red-500",
            )}
          >
            {isPending && (
              <CircleNotch
                className="h-4 w-4 animate-spin text-white"
                weight="bold"
              />
            )}
            {isCompleted && (
              <CheckCircle className="h-4 w-4 text-white" weight="bold" />
            )}
            {isError && (
              <XCircle className="h-4 w-4 text-white" weight="bold" />
            )}
          </div>
        </div>

        <div className="flex min-w-0 flex-1 flex-col">
          <Card className="space-y-2 p-4">
            <div>
              <Text variant="h4" className="mb-1 text-slate-900">
                {isPending && operationTitle}
                {isCompleted && `${operationTitle} Complete`}
                {isError && `${operationTitle} Failed`}
              </Text>
              <Text variant="small" className="text-slate-600">
                {message}
              </Text>
            </div>

            {isPending && (
              <Text variant="small" className="italic text-slate-500">
                Check your library in a few minutes.
              </Text>
            )}

            {toolName && (
              <Text variant="small" className="text-slate-400">
                Tool: {toolName}
              </Text>
            )}
          </Card>
        </div>
      </div>
    </div>
  );
}
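For reference, this is how `ChatMessage` (shown earlier) mounts the widget for an in-flight operation; the message text here is illustrative:

```tsx
<PendingOperationWidget
  status="in_progress"
  message="A previous request for this agent is still running."
  toolName="create_agent"
/>
```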
@@ -37,87 +37,6 @@ export function getErrorMessage(result: unknown): string {
  return "An error occurred";
}

/**
 * Check if a value is a workspace file reference.
 */
function isWorkspaceRef(value: unknown): value is string {
  return typeof value === "string" && value.startsWith("workspace://");
}

/**
 * Check if a workspace reference appears to be an image based on common patterns.
 * Since workspace refs don't have extensions, we check the context or assume image
 * for certain block types.
 *
 * TODO: Replace keyword matching with MIME type encoded in workspace ref.
 * e.g., workspace://abc123#image/png or workspace://abc123#video/mp4
 * This would let frontend render correctly without fragile keyword matching.
 */
function isLikelyImageRef(value: string, outputKey?: string): boolean {
  if (!isWorkspaceRef(value)) return false;

  // Check output key name for video-related hints (these are NOT images)
  const videoKeywords = ["video", "mp4", "mov", "avi", "webm", "movie", "clip"];
  if (outputKey) {
    const lowerKey = outputKey.toLowerCase();
    if (videoKeywords.some((kw) => lowerKey.includes(kw))) {
      return false;
    }
  }

  // Check output key name for image-related hints
  const imageKeywords = [
    "image",
    "img",
    "photo",
    "picture",
    "thumbnail",
    "avatar",
    "icon",
    "screenshot",
  ];
  if (outputKey) {
    const lowerKey = outputKey.toLowerCase();
    if (imageKeywords.some((kw) => lowerKey.includes(kw))) {
      return true;
    }
  }

  // Default to treating workspace refs as potential images
  // since that's the most common case for generated content
  return true;
}
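The TODO above proposes encoding the MIME type in the ref itself (e.g., `workspace://abc123#image/png`). A sketch of what parsing that convention could look like - hypothetical, since the fragment scheme is not implemented yet:

```ts
// Hypothetical parser for the proposed "workspace://<id>#<mime>" convention.
function parseWorkspaceRef(ref: string): { fileId: string; mime?: string } | null {
  if (!ref.startsWith("workspace://")) return null;
  const [fileId, mime] = ref.slice("workspace://".length).split("#");
  return mime ? { fileId, mime } : { fileId };
}

// parseWorkspaceRef("workspace://abc123#video/mp4")
// -> { fileId: "abc123", mime: "video/mp4" }
// With the MIME available, isLikelyImageRef could reduce to mime.startsWith("image/").
```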
/**
 * Format a single output value, converting workspace refs to markdown images.
 */
function formatOutputValue(value: unknown, outputKey?: string): string {
  if (isWorkspaceRef(value) && isLikelyImageRef(value, outputKey)) {
    // Format as markdown image
    return `![${outputKey || "image"}](${value})`;
  }

  if (typeof value === "string") {
    // Check for data URIs (images)
    if (value.startsWith("data:image/")) {
      return `![image](${value})`;
    }
    return value;
  }

  if (Array.isArray(value)) {
    return value
      .map((item, idx) => formatOutputValue(item, `${outputKey}_${idx}`))
      .join("\n\n");
  }

  if (typeof value === "object" && value !== null) {
    return JSON.stringify(value, null, 2);
  }

  return String(value);
}
function getToolCompletionPhrase(toolName: string): string {
  const toolCompletionPhrases: Record<string, string> = {
    add_understanding: "Updated your business information",
@@ -208,26 +127,10 @@ export function formatToolResponse(result: unknown, toolName: string): string {

    case "block_output":
      const blockName = (response.block_name as string) || "Block";
      const outputs = response.outputs as Record<string, unknown[]> | undefined;
      const outputs = response.outputs as Record<string, unknown> | undefined;
      if (outputs && Object.keys(outputs).length > 0) {
        const formattedOutputs: string[] = [];

        for (const [key, values] of Object.entries(outputs)) {
          if (!Array.isArray(values) || values.length === 0) continue;

          // Format each value in the output array
          for (const value of values) {
            const formatted = formatOutputValue(value, key);
            if (formatted) {
              formattedOutputs.push(formatted);
            }
          }
        }

        if (formattedOutputs.length > 0) {
          return `${blockName} executed successfully.\n\n${formattedOutputs.join("\n\n")}`;
        }
        return `${blockName} executed successfully.`;
        const outputKeys = Object.keys(outputs);
        return `${blockName} executed successfully. Outputs: ${outputKeys.join(", ")}`;
      }
      return `${blockName} executed successfully.`;
@@ -10,7 +10,7 @@ export function UserChatBubble({ children, className }: UserChatBubbleProps) {
  return (
    <div
      className={cn(
        "group relative min-w-20 overflow-hidden rounded-xl bg-purple-100 px-3 text-left text-[1rem] leading-relaxed transition-all duration-500 ease-in-out",
        "group relative min-w-20 overflow-hidden rounded-xl bg-purple-100 px-3 text-right text-[1rem] leading-relaxed transition-all duration-500 ease-in-out",
        className,
      )}
      style={{
@@ -26,7 +26,6 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
    claimSession,
    clearSession: clearSessionBase,
    loadSession,
    startPollingForOperation,
  } = useChatSession({
    urlSessionId,
    autoCreate: false,
@@ -95,6 +94,5 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
    loadSession,
    sessionId: sessionIdFromHook,
    showLoader,
    startPollingForOperation,
  };
}
@@ -59,7 +59,6 @@ export function useChatSession({
    query: {
      enabled: !!sessionId,
      select: okData,
      staleTime: 0,
      retry: shouldRetrySessionLoad,
      retryDelay: getSessionRetryDelay,
    },
@@ -103,123 +102,15 @@ export function useChatSession({
    }
  }, [createError, loadError]);

  // Track if we should be polling (set by external callers when they receive operation_started via SSE)
  const [forcePolling, setForcePolling] = useState(false);
  // Track if we've seen server acknowledge the pending operation (to avoid clearing forcePolling prematurely)
  const hasSeenServerPendingRef = useRef(false);

  // Check if there are any pending operations in the messages
  // Must check all operation types: operation_pending, operation_started, operation_in_progress
  const hasPendingOperationsFromServer = useMemo(() => {
    if (!messages || messages.length === 0) return false;
    const pendingTypes = new Set([
      "operation_pending",
      "operation_in_progress",
      "operation_started",
    ]);
    return messages.some((msg) => {
      if (msg.role !== "tool" || !msg.content) return false;
      try {
        const content =
          typeof msg.content === "string"
            ? JSON.parse(msg.content)
            : msg.content;
        return pendingTypes.has(content?.type);
      } catch {
        return false;
      }
    });
  }, [messages]);

  // Track when server has acknowledged the pending operation
  useEffect(() => {
    if (hasPendingOperationsFromServer) {
      hasSeenServerPendingRef.current = true;
    }
  }, [hasPendingOperationsFromServer]);

  // Combined: poll if server has pending ops OR if we received operation_started via SSE
  const hasPendingOperations = hasPendingOperationsFromServer || forcePolling;

  // Clear forcePolling only after server has acknowledged AND completed the operation
  useEffect(() => {
    if (
      forcePolling &&
      !hasPendingOperationsFromServer &&
      hasSeenServerPendingRef.current
    ) {
      // Server acknowledged the operation and it's now complete
      setForcePolling(false);
      hasSeenServerPendingRef.current = false;
    }
  }, [forcePolling, hasPendingOperationsFromServer]);

  // Function to trigger polling (called when operation_started is received via SSE)
  function startPollingForOperation() {
    setForcePolling(true);
    hasSeenServerPendingRef.current = false; // Reset for new operation
  }

  // Refresh sessions list when a pending operation completes
  // (hasPendingOperations transitions from true to false)
  const prevHasPendingOperationsRef = useRef(hasPendingOperations);
  useEffect(
    function refreshSessionsListOnOperationComplete() {
      const wasHasPending = prevHasPendingOperationsRef.current;
      prevHasPendingOperationsRef.current = hasPendingOperations;

      // Only invalidate when transitioning from pending to not pending
      if (wasHasPending && !hasPendingOperations && sessionId) {
    function refreshSessionsListOnLoad() {
      if (sessionId && sessionData && !isLoadingSession) {
        queryClient.invalidateQueries({
          queryKey: getGetV2ListSessionsQueryKey(),
        });
      }
    },
    [hasPendingOperations, sessionId, queryClient],
  );

  // Poll for updates when there are pending operations
  // Backoff: 2s, 4s, 6s, 8s, 10s, ... up to 30s max
  const pollAttemptRef = useRef(0);
  const hasPendingOperationsRef = useRef(hasPendingOperations);
  hasPendingOperationsRef.current = hasPendingOperations;

  useEffect(
    function pollForPendingOperations() {
      if (!sessionId || !hasPendingOperations) {
        pollAttemptRef.current = 0;
        return;
      }

      let cancelled = false;
      let timeoutId: ReturnType<typeof setTimeout> | null = null;

      function schedule() {
        // 2s, 4s, 6s, 8s, 10s, ... 30s (max)
        const delay = Math.min((pollAttemptRef.current + 1) * 2000, 30000);
        timeoutId = setTimeout(async () => {
          if (cancelled) return;
          pollAttemptRef.current += 1;
          try {
            await refetch();
          } catch (err) {
            console.error("[useChatSession] Poll failed:", err);
          } finally {
            if (!cancelled && hasPendingOperationsRef.current) {
              schedule();
            }
          }
        }, delay);
      }

      schedule();

      return () => {
        cancelled = true;
        if (timeoutId) clearTimeout(timeoutId);
      };
    },
    [sessionId, hasPendingOperations, refetch],
    [sessionId, sessionData, isLoadingSession, queryClient],
  );

  async function createSession() {
@@ -348,13 +239,11 @@ export function useChatSession({
    isCreating,
    error,
    isSessionNotFound: isNotFoundError(loadError),
    hasPendingOperations,
    createSession,
    loadSession,
    refreshSession,
    claimSession,
    clearSession,
    startPollingForOperation,
  };
}
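The polling schedule above is linear with a cap rather than exponential. Spelled out with the same formula `schedule()` uses:

```ts
// Delay for attempt n: min((n + 1) * 2000, 30000) milliseconds.
// n = 0 -> 2s, n = 1 -> 4s, n = 2 -> 6s, ..., n >= 14 -> capped at 30s.
const delays = Array.from({ length: 16 }, (_, n) =>
  Math.min((n + 1) * 2000, 30_000),
);
```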
@@ -255,18 +255,13 @@ export function Wallet() {
    (notification: WebSocketNotification) => {
      if (
        notification.type !== "onboarding" ||
        notification.event !== "step_completed"
        notification.event !== "step_completed" ||
        !walletRef.current
      ) {
        return;
      }

      // Always refresh credits when any onboarding step completes
      fetchCredits();

      // Only trigger confetti for tasks that are in displayed groups
      if (!walletRef.current) {
        return;
      }
      // Only trigger confetti for tasks that are in groups
      const taskIds = groups
        .flatMap((group) => group.tasks)
        .map((task) => task.id);
@@ -279,6 +274,7 @@ export function Wallet() {
        return;
      }

      fetchCredits();
      party.confetti(walletRef.current, {
        count: 30,
        spread: 120,
@@ -288,7 +284,7 @@ export function Wallet() {
        modules: [fadeOut],
      });
    },
    [fetchCredits, fadeOut, groups],
    [fetchCredits, fadeOut],
  );

  // WebSocket setup for onboarding notifications
@@ -1003,7 +1003,6 @@ export type OnboardingStep =
  | "AGENT_INPUT"
  | "CONGRATS"
  // First Wins
  | "VISIT_COPILOT"
  | "GET_RESULTS"
  | "MARKETPLACE_VISIT"
  | "MARKETPLACE_ADD_AGENT"
@@ -3,7 +3,6 @@ import { environment } from "../environment";

export enum SessionKey {
  CHAT_SENT_INITIAL_PROMPTS = "chat_sent_initial_prompts",
  CHAT_INITIAL_PROMPTS = "chat_initial_prompts",
}

function get(key: SessionKey) {
@@ -37,13 +37,9 @@ export class LoginPage {
    this.page.on("load", (page) => console.log(`ℹ️ Now at URL: ${page.url()}`));

    // Start waiting for navigation before clicking
    // Wait for redirect to marketplace, onboarding, library, or copilot (new landing pages)
    const leaveLoginPage = this.page
      .waitForURL(
        (url: URL) =>
          /^\/(marketplace|onboarding(\/.*)?|library|copilot)?$/.test(
            url.pathname,
          ),
        (url) => /^\/(marketplace|onboarding(\/.*)?)?$/.test(url.pathname),
        { timeout: 10_000 },
      )
      .catch((reason) => {
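A quick check of what the updated regex accepts (illustrative values):

```ts
const landing = /^\/(marketplace|onboarding(\/.*)?|library|copilot)?$/;
landing.test("/copilot"); // true - new landing page
landing.test("/library"); // true - new landing page
landing.test("/onboarding/1"); // true - nested onboarding routes
landing.test("/settings"); // false - not a post-login landing page
```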
@@ -36,16 +36,14 @@ export async function signupTestUser(
  const signupButton = getButton("Sign up");
  await signupButton.click();

  // Wait for successful signup - could redirect to various pages depending on onboarding state
  // Wait for successful signup - could redirect to onboarding or marketplace

  try {
    // Wait for redirect to onboarding, marketplace, copilot, or library
    // Use a single waitForURL with a callback to avoid Promise.race race conditions
    await page.waitForURL(
      (url: URL) =>
        /\/(onboarding|marketplace|copilot|library)/.test(url.pathname),
      { timeout: 15000 },
    );
    // Wait for either onboarding or marketplace redirect
    await Promise.race([
      page.waitForURL(/\/onboarding/, { timeout: 15000 }),
      page.waitForURL(/\/marketplace/, { timeout: 15000 }),
    ]);
  } catch (error) {
    console.error(
      "❌ Timeout waiting for redirect, current URL:",
@@ -56,19 +54,14 @@ export async function signupTestUser(

  const currentUrl = page.url();

  // Handle onboarding redirect if needed
  // Handle onboarding or marketplace redirect
  if (currentUrl.includes("/onboarding") && ignoreOnboarding) {
    await page.goto("http://localhost:3000/marketplace");
    await page.waitForLoadState("domcontentloaded", { timeout: 10000 });
  }

  // Verify we're on an expected final page and user is authenticated
  if (currentUrl.includes("/copilot") || currentUrl.includes("/library")) {
    // For copilot/library landing pages, just verify user is authenticated
    await page
      .getByTestId("profile-popout-menu-trigger")
      .waitFor({ state: "visible", timeout: 10000 });
  } else if (ignoreOnboarding || currentUrl.includes("/marketplace")) {
  // Verify we're on the expected final page
  if (ignoreOnboarding || currentUrl.includes("/marketplace")) {
    // Verify we're on marketplace
    await page
      .getByText(
backend/blocks/video/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# Video editing blocks
@@ -53,7 +53,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
| [Block Installation](block-integrations/basic.md#block-installation) | Given a code string, this block allows the verification and installation of a block code into the system |
| [Concatenate Lists](block-integrations/basic.md#concatenate-lists) | Concatenates multiple lists into a single list |
| [Dictionary Is Empty](block-integrations/basic.md#dictionary-is-empty) | Checks if a dictionary is empty |
| [File Store](block-integrations/basic.md#file-store) | Downloads and stores a file from a URL, data URI, or local path |
| [File Store](block-integrations/basic.md#file-store) | Stores the input file in the temporary directory |
| [Find In Dictionary](block-integrations/basic.md#find-in-dictionary) | A block that looks up a value in a dictionary, list, or object by key or index and returns the corresponding value |
| [Find In List](block-integrations/basic.md#find-in-list) | Finds the index of the value in the list |
| [Get All Memories](block-integrations/basic.md#get-all-memories) | Retrieve all memories from Mem0 with optional conversation filtering |
@@ -709,7 +709,7 @@ This is useful for conditional logic where you need to verify if data was retur
## File Store

### What it is
Downloads and stores a file from a URL, data URI, or local path. Use this to fetch images, documents, or other files for processing. In CoPilot: saves to workspace (use list_workspace_files to see it). In graphs: outputs a data URI to pass to other blocks.
Stores the input file in the temporary directory.

### How it works
<!-- MANUAL: how_it_works -->
@@ -722,15 +722,15 @@ The block outputs a file path that other blocks can use to access the stored fil

| Input | Description | Type | Required |
|-------|-------------|------|----------|
| file_in | The file to download and store. Can be a URL (https://...), data URI, or local path. | str (file) | Yes |
| base_64 | Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks). | bool | No |
| file_in | The file to store in the temporary directory, it can be a URL, data URI, or local path. | str (file) | Yes |
| base_64 | Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks). | bool | No |

### Outputs

| Output | Description | Type |
|--------|-------------|------|
| error | Error message if the operation failed | str |
| file_out | Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks. | str (file) |
| file_out | The relative path to the stored file in the temporary directory. | str (file) |

### Possible use case
<!-- MANUAL: use_case -->
@@ -12,7 +12,7 @@ Block to attach an audio file to a video file using moviepy.
<!-- MANUAL: how_it_works -->
This block combines a video file with an audio file using the moviepy library. The audio track is attached to the video, optionally with volume adjustment via the volume parameter (1.0 = original volume).

Input files can be URLs, data URIs, or local paths. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions.
Input files can be URLs, data URIs, or local paths. The output can be returned as either a file path or base64 data URI.
<!-- END MANUAL -->

### Inputs
@@ -22,6 +22,7 @@ Input files can be URLs, data URIs, or local paths. The output format is automat
| video_in | Video input (URL, data URI, or local path). | str (file) | Yes |
| audio_in | Audio input (URL, data URI, or local path). | str (file) | Yes |
| volume | Volume scale for the newly attached audio track (1.0 = original). | float | No |
| output_return_type | Return the final output as a relative path or base64 data URI. | "file_path" \| "data_uri" | No |

### Outputs

@@ -50,7 +51,7 @@ Block to loop a video to a given duration or number of repeats.
<!-- MANUAL: how_it_works -->
This block extends a video by repeating it to reach a target duration or number of loops. Set duration to specify the total length in seconds, or use n_loops to repeat the video a specific number of times.

The looped video is seamlessly concatenated. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions.
The looped video is seamlessly concatenated and can be output as a file path or base64 data URI.
<!-- END MANUAL -->

### Inputs
@@ -60,6 +61,7 @@ The looped video is seamlessly concatenated. The output format is automatically
| video_in | The input video (can be a URL, data URI, or local path). | str (file) | Yes |
| duration | Target duration (in seconds) to loop the video to. If omitted, defaults to no looping. | float | No |
| n_loops | Number of times to repeat the video. If omitted, defaults to 1 (no repeat). | int | No |
| output_return_type | How to return the output video. Either a relative path or base64 data URI. | "file_path" \| "data_uri" | No |

### Outputs
@@ -277,50 +277,6 @@ async def run(
    token = credentials.api_key.get_secret_value()
```

### Handling Files

When your block works with files (images, videos, documents), use `store_media_file()`:

```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType

async def run(
    self,
    input_data: Input,
    *,
    execution_context: ExecutionContext,
    **kwargs,
):
    # PROCESSING: Need local file path for tools like ffmpeg, MoviePy, PIL
    local_path = await store_media_file(
        file=input_data.video,
        execution_context=execution_context,
        return_format="for_local_processing",
    )

    # EXTERNAL API: Need base64 content for APIs like Replicate, OpenAI
    image_b64 = await store_media_file(
        file=input_data.image,
        execution_context=execution_context,
        return_format="for_external_api",
    )

    # OUTPUT: Return to user/next block (auto-adapts to context)
    result = await store_media_file(
        file=generated_url,
        execution_context=execution_context,
        return_format="for_block_output",  # workspace:// in CoPilot, data URI in graphs
    )
    yield "image_url", result
```

**Return format options:**

- `"for_local_processing"` - Local file path for processing tools
- `"for_external_api"` - Data URI for external APIs needing base64
- `"for_block_output"` - **Always use for outputs** - automatically picks best format

## Testing Your Block

```bash
@@ -111,71 +111,6 @@ Follow these steps to create and test a new block:
   - `graph_exec_id`: The ID of the execution of the agent. This changes every time the agent has a new "run"
   - `node_exec_id`: The ID of the execution of the node. This changes every time the node is executed
   - `node_id`: The ID of the node that is being executed. It changes every version of the graph, but not every time the node is executed.
   - `execution_context`: An `ExecutionContext` object containing user_id, graph_exec_id, workspace_id, and session_id. Required for file handling.

### Handling Files in Blocks

When your block needs to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. This function handles downloading, validation, virus scanning, and storage.

**Import:**

```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
```

**The `return_format` parameter determines what you get back:**

| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |

**Examples:**

```python
async def run(
    self,
    input_data: Input,
    *,
    execution_context: ExecutionContext,
    **kwargs,
) -> BlockOutput:
    # PROCESSING: Need to work with file locally (ffmpeg, MoviePy, PIL)
    local_path = await store_media_file(
        file=input_data.video,
        execution_context=execution_context,
        return_format="for_local_processing",
    )
    # local_path = "video.mp4" - use with Path, ffmpeg, subprocess, etc.
    full_path = get_exec_file_path(execution_context.graph_exec_id, local_path)

    # EXTERNAL API: Need to send content to an API like Replicate
    image_b64 = await store_media_file(
        file=input_data.image,
        execution_context=execution_context,
        return_format="for_external_api",
    )
    # image_b64 = "data:image/png;base64,iVBORw0..." - send to external API

    # OUTPUT: Returning result from block to user/next block
    result_url = await store_media_file(
        file=generated_image_url,
        execution_context=execution_context,
        return_format="for_block_output",
    )
    yield "image_url", result_url
    # In CoPilot: result_url = "workspace://abc123" (persistent, context-efficient)
    # In graphs: result_url = "data:image/png;base64,..." (for next block/display)
```

**Key points:**

- `for_block_output` is the **only** format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never manually check for `workspace_id` - let `for_block_output` handle the logic
- The function handles URLs, data URIs, `workspace://` references, and local paths as input

### Field Types