mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-28 16:38:17 -05:00
Compare commits
13 Commits
feat/text-
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e0dfae5732 | ||
|
|
7df867d645 | ||
|
|
d855f79874 | ||
|
|
dac99694fe | ||
|
|
0953983944 | ||
|
|
0058cd3ba6 | ||
|
|
ea035224bc | ||
|
|
62813a1ea6 | ||
|
|
67405f7eb9 | ||
|
|
171ff6e776 | ||
|
|
349b1f9c79 | ||
|
|
277b0537e9 | ||
|
|
071b3bb5cd |
@@ -33,9 +33,15 @@ class ChatConfig(BaseSettings):
|
|||||||
|
|
||||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||||
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
|
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||||
max_agent_schedules: int = Field(
|
max_agent_schedules: int = Field(
|
||||||
default=3, description="Maximum number of agent schedules"
|
default=30, description="Maximum number of agent schedules"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Long-running operation configuration
|
||||||
|
long_running_operation_ttl: int = Field(
|
||||||
|
default=600,
|
||||||
|
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Langfuse Prompt Management Configuration
|
# Langfuse Prompt Management Configuration
|
||||||
|
|||||||
@@ -247,3 +247,45 @@ async def get_chat_session_message_count(session_id: str) -> int:
|
|||||||
"""Get the number of messages in a chat session."""
|
"""Get the number of messages in a chat session."""
|
||||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
async def update_tool_message_content(
|
||||||
|
session_id: str,
|
||||||
|
tool_call_id: str,
|
||||||
|
new_content: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Update the content of a tool message in chat history.
|
||||||
|
|
||||||
|
Used by background tasks to update pending operation messages with final results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The chat session ID.
|
||||||
|
tool_call_id: The tool call ID to find the message.
|
||||||
|
new_content: The new content to set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a message was updated, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await PrismaChatMessage.prisma().update_many(
|
||||||
|
where={
|
||||||
|
"sessionId": session_id,
|
||||||
|
"toolCallId": tool_call_id,
|
||||||
|
},
|
||||||
|
data={
|
||||||
|
"content": new_content,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if result == 0:
|
||||||
|
logger.warning(
|
||||||
|
f"No message found to update for session {session_id}, "
|
||||||
|
f"tool_call_id {tool_call_id}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to update tool message for session {session_id}, "
|
||||||
|
f"tool_call_id {tool_call_id}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|||||||
@@ -295,6 +295,21 @@ async def cache_chat_session(session: ChatSession) -> None:
|
|||||||
await _cache_session(session)
|
await _cache_session(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def invalidate_session_cache(session_id: str) -> None:
|
||||||
|
"""Invalidate a chat session from Redis cache.
|
||||||
|
|
||||||
|
Used by background tasks to ensure fresh data is loaded on next access.
|
||||||
|
This is best-effort - Redis failures are logged but don't fail the operation.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
redis_key = _get_session_cache_key(session_id)
|
||||||
|
async_redis = await get_redis_async()
|
||||||
|
await async_redis.delete(redis_key)
|
||||||
|
except Exception as e:
|
||||||
|
# Best-effort: log but don't fail - cache will expire naturally
|
||||||
|
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||||
"""Get a chat session from the database."""
|
"""Get a chat session from the database."""
|
||||||
prisma_session = await chat_db.get_chat_session(session_id)
|
prisma_session = await chat_db.get_chat_session(session_id)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from openai import (
|
|||||||
)
|
)
|
||||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||||
|
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
from backend.data.understanding import (
|
from backend.data.understanding import (
|
||||||
format_understanding_for_prompt,
|
format_understanding_for_prompt,
|
||||||
get_business_understanding,
|
get_business_understanding,
|
||||||
@@ -24,6 +25,7 @@ from backend.data.understanding import (
|
|||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
from . import db as chat_db
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import (
|
from .model import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
@@ -31,6 +33,7 @@ from .model import (
|
|||||||
Usage,
|
Usage,
|
||||||
cache_chat_session,
|
cache_chat_session,
|
||||||
get_chat_session,
|
get_chat_session,
|
||||||
|
invalidate_session_cache,
|
||||||
update_session_title,
|
update_session_title,
|
||||||
upsert_chat_session,
|
upsert_chat_session,
|
||||||
)
|
)
|
||||||
@@ -48,8 +51,13 @@ from .response_model import (
|
|||||||
StreamToolOutputAvailable,
|
StreamToolOutputAvailable,
|
||||||
StreamUsage,
|
StreamUsage,
|
||||||
)
|
)
|
||||||
from .tools import execute_tool, tools
|
from .tools import execute_tool, get_tool, tools
|
||||||
from .tools.models import ErrorResponse
|
from .tools.models import (
|
||||||
|
ErrorResponse,
|
||||||
|
OperationInProgressResponse,
|
||||||
|
OperationPendingResponse,
|
||||||
|
OperationStartedResponse,
|
||||||
|
)
|
||||||
from .tracking import track_user_message
|
from .tracking import track_user_message
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -61,11 +69,126 @@ client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
|||||||
|
|
||||||
langfuse = get_client()
|
langfuse = get_client()
|
||||||
|
|
||||||
|
# Redis key prefix for tracking running long-running operations
|
||||||
|
# Used for idempotency across Kubernetes pods - prevents duplicate executions on browser refresh
|
||||||
|
RUNNING_OPERATION_PREFIX = "chat:running_operation:"
|
||||||
|
|
||||||
class LangfuseNotConfiguredError(Exception):
|
# Default system prompt used when Langfuse is not configured
|
||||||
"""Raised when Langfuse is required but not configured."""
|
# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11)
|
||||||
|
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
|
||||||
|
|
||||||
pass
|
Here is everything you know about the current user from previous interactions:
|
||||||
|
|
||||||
|
<users_information>
|
||||||
|
{users_information}
|
||||||
|
</users_information>
|
||||||
|
|
||||||
|
## YOUR CORE MANDATE
|
||||||
|
|
||||||
|
You are action-oriented. Your success is measured by:
|
||||||
|
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
|
||||||
|
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
|
||||||
|
- **Time Saved**: Focus on tangible efficiency gains
|
||||||
|
- **Quality Output**: Deliver results that meet or exceed expectations
|
||||||
|
|
||||||
|
## YOUR WORKFLOW
|
||||||
|
|
||||||
|
Adapt flexibly to the conversation context. Not every interaction requires all stages:
|
||||||
|
|
||||||
|
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
|
||||||
|
|
||||||
|
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
|
||||||
|
|
||||||
|
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
|
||||||
|
|
||||||
|
4. **Discover or Create Agents**:
|
||||||
|
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
|
||||||
|
- Search the marketplace with `find_agent` for pre-built automations
|
||||||
|
- Find reusable components with `find_block`
|
||||||
|
- Create custom solutions with `create_agent` if nothing suitable exists
|
||||||
|
- Modify existing library agents with `edit_agent`
|
||||||
|
|
||||||
|
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
|
||||||
|
|
||||||
|
6. **Show Results**: Display outputs using `agent_output`.
|
||||||
|
|
||||||
|
## AVAILABLE TOOLS
|
||||||
|
|
||||||
|
**Understanding & Discovery:**
|
||||||
|
- `add_understanding`: Create a memory about the user's business or use cases for future sessions
|
||||||
|
- `search_docs`: Search platform documentation for specific technical information
|
||||||
|
- `get_doc_page`: Retrieve full text of a specific documentation page
|
||||||
|
|
||||||
|
**Agent Discovery:**
|
||||||
|
- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized)
|
||||||
|
- `find_agent`: Search the marketplace for pre-built automations
|
||||||
|
- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks)
|
||||||
|
|
||||||
|
**Agent Creation & Editing:**
|
||||||
|
- `create_agent`: Create a new automation agent
|
||||||
|
- `edit_agent`: Modify an agent in the user's library
|
||||||
|
|
||||||
|
**Execution & Output:**
|
||||||
|
- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger
|
||||||
|
- `run_block`: Test or run a specific block independently
|
||||||
|
- `agent_output`: View results from previous agent runs
|
||||||
|
|
||||||
|
## BEHAVIORAL GUIDELINES
|
||||||
|
|
||||||
|
**Be Concise:**
|
||||||
|
- Target 2-5 short lines maximum
|
||||||
|
- Make every word count—no repetition or filler
|
||||||
|
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
|
||||||
|
- Avoid jargon (blocks, slugs, cron) unless the user asks
|
||||||
|
|
||||||
|
**Be Proactive:**
|
||||||
|
- Suggest next steps before being asked
|
||||||
|
- Anticipate needs based on conversation context and user information
|
||||||
|
- Look for opportunities to expand scope when relevant
|
||||||
|
- Reveal capabilities through action, not explanation
|
||||||
|
|
||||||
|
**Use Tools Effectively:**
|
||||||
|
- Select the right tool for each task
|
||||||
|
- **Always check `find_library_agent` before searching the marketplace**
|
||||||
|
- Use `add_understanding` to capture valuable business context
|
||||||
|
- When tool calls fail, try alternative approaches
|
||||||
|
|
||||||
|
## CRITICAL REMINDER
|
||||||
|
|
||||||
|
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
|
||||||
|
|
||||||
|
# Module-level set to hold strong references to background tasks.
|
||||||
|
# This prevents asyncio from garbage collecting tasks before they complete.
|
||||||
|
# Tasks are automatically removed on completion via done_callback.
|
||||||
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
|
async def _mark_operation_started(tool_call_id: str) -> bool:
|
||||||
|
"""Mark a long-running operation as started (Redis-based).
|
||||||
|
|
||||||
|
Returns True if successfully marked (operation was not already running),
|
||||||
|
False if operation was already running (lost race condition).
|
||||||
|
Raises exception if Redis is unavailable (fail-closed).
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
|
||||||
|
# SETNX with TTL - atomic "set if not exists"
|
||||||
|
result = await redis.set(key, "1", ex=config.long_running_operation_ttl, nx=True)
|
||||||
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def _mark_operation_completed(tool_call_id: str) -> None:
|
||||||
|
"""Mark a long-running operation as completed (remove Redis key).
|
||||||
|
|
||||||
|
This is best-effort - if Redis fails, the TTL will eventually clean up.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
|
||||||
|
await redis.delete(key)
|
||||||
|
except Exception as e:
|
||||||
|
# Non-critical: TTL will clean up eventually
|
||||||
|
logger.warning(f"Failed to delete running operation key {tool_call_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _is_langfuse_configured() -> bool:
|
def _is_langfuse_configured() -> bool:
|
||||||
@@ -75,6 +198,30 @@ def _is_langfuse_configured() -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_system_prompt_template(context: str) -> str:
|
||||||
|
"""Get the system prompt, trying Langfuse first with fallback to default.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The user context/information to compile into the prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The compiled system prompt string.
|
||||||
|
"""
|
||||||
|
if _is_langfuse_configured():
|
||||||
|
try:
|
||||||
|
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||||
|
# Use asyncio.to_thread to avoid blocking the event loop
|
||||||
|
prompt = await asyncio.to_thread(
|
||||||
|
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
|
||||||
|
)
|
||||||
|
return prompt.compile(users_information=context)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch prompt from Langfuse, using default: {e}")
|
||||||
|
|
||||||
|
# Fallback to default prompt
|
||||||
|
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
||||||
|
|
||||||
|
|
||||||
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||||
"""Build the full system prompt including business understanding if available.
|
"""Build the full system prompt including business understanding if available.
|
||||||
|
|
||||||
@@ -83,12 +230,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
|||||||
If "default" and this is the user's first session, will use "onboarding" instead.
|
If "default" and this is the user's first session, will use "onboarding" instead.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (compiled prompt string, Langfuse prompt object for tracing)
|
Tuple of (compiled prompt string, business understanding object)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
|
||||||
prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
|
|
||||||
|
|
||||||
# If user is authenticated, try to fetch their business understanding
|
# If user is authenticated, try to fetch their business understanding
|
||||||
understanding = None
|
understanding = None
|
||||||
if user_id:
|
if user_id:
|
||||||
@@ -97,12 +240,13 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to fetch business understanding: {e}")
|
logger.warning(f"Failed to fetch business understanding: {e}")
|
||||||
understanding = None
|
understanding = None
|
||||||
|
|
||||||
if understanding:
|
if understanding:
|
||||||
context = format_understanding_for_prompt(understanding)
|
context = format_understanding_for_prompt(understanding)
|
||||||
else:
|
else:
|
||||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||||
|
|
||||||
compiled = prompt.compile(users_information=context)
|
compiled = await _get_system_prompt_template(context)
|
||||||
return compiled, understanding
|
return compiled, understanding
|
||||||
|
|
||||||
|
|
||||||
@@ -210,16 +354,6 @@ async def stream_chat_completion(
|
|||||||
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if Langfuse is configured - required for chat functionality
|
|
||||||
if not _is_langfuse_configured():
|
|
||||||
logger.error("Chat request failed: Langfuse is not configured")
|
|
||||||
yield StreamError(
|
|
||||||
errorText="Chat service is not available. Langfuse must be configured "
|
|
||||||
"with LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
|
||||||
return
|
|
||||||
|
|
||||||
# Only fetch from Redis if session not provided (initial call)
|
# Only fetch from Redis if session not provided (initial call)
|
||||||
if session is None:
|
if session is None:
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
@@ -315,6 +449,7 @@ async def stream_chat_completion(
|
|||||||
has_yielded_end = False
|
has_yielded_end = False
|
||||||
has_yielded_error = False
|
has_yielded_error = False
|
||||||
has_done_tool_call = False
|
has_done_tool_call = False
|
||||||
|
has_long_running_tool_call = False # Track if we had a long-running tool call
|
||||||
has_received_text = False
|
has_received_text = False
|
||||||
text_streaming_ended = False
|
text_streaming_ended = False
|
||||||
tool_response_messages: list[ChatMessage] = []
|
tool_response_messages: list[ChatMessage] = []
|
||||||
@@ -336,7 +471,6 @@ async def stream_chat_completion(
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
text_block_id=text_block_id,
|
text_block_id=text_block_id,
|
||||||
):
|
):
|
||||||
|
|
||||||
if isinstance(chunk, StreamTextStart):
|
if isinstance(chunk, StreamTextStart):
|
||||||
# Emit text-start before first text delta
|
# Emit text-start before first text delta
|
||||||
if not has_received_text:
|
if not has_received_text:
|
||||||
@@ -394,6 +528,27 @@ async def stream_chat_completion(
|
|||||||
if isinstance(chunk.output, str)
|
if isinstance(chunk.output, str)
|
||||||
else orjson.dumps(chunk.output).decode("utf-8")
|
else orjson.dumps(chunk.output).decode("utf-8")
|
||||||
)
|
)
|
||||||
|
# Skip saving long-running operation responses - messages already saved in _yield_tool_call
|
||||||
|
# Use JSON parsing instead of substring matching to avoid false positives
|
||||||
|
is_long_running_response = False
|
||||||
|
try:
|
||||||
|
parsed = orjson.loads(result_content)
|
||||||
|
if isinstance(parsed, dict) and parsed.get("type") in (
|
||||||
|
"operation_started",
|
||||||
|
"operation_in_progress",
|
||||||
|
):
|
||||||
|
is_long_running_response = True
|
||||||
|
except (orjson.JSONDecodeError, TypeError):
|
||||||
|
pass # Not JSON or not a dict - treat as regular response
|
||||||
|
if is_long_running_response:
|
||||||
|
# Remove from accumulated_tool_calls since assistant message was already saved
|
||||||
|
accumulated_tool_calls[:] = [
|
||||||
|
tc
|
||||||
|
for tc in accumulated_tool_calls
|
||||||
|
if tc["id"] != chunk.toolCallId
|
||||||
|
]
|
||||||
|
has_long_running_tool_call = True
|
||||||
|
else:
|
||||||
tool_response_messages.append(
|
tool_response_messages.append(
|
||||||
ChatMessage(
|
ChatMessage(
|
||||||
role="tool",
|
role="tool",
|
||||||
@@ -576,7 +731,14 @@ async def stream_chat_completion(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Extended session messages, new message_count={len(session.messages)}"
|
f"Extended session messages, new message_count={len(session.messages)}"
|
||||||
)
|
)
|
||||||
if messages_to_save or has_appended_streaming_message:
|
# Save if there are regular (non-long-running) tool responses or streaming message.
|
||||||
|
# Long-running tools save their own state, but we still need to save regular tools
|
||||||
|
# that may be in the same response.
|
||||||
|
has_regular_tool_responses = len(tool_response_messages) > 0
|
||||||
|
if has_regular_tool_responses or (
|
||||||
|
not has_long_running_tool_call
|
||||||
|
and (messages_to_save or has_appended_streaming_message)
|
||||||
|
):
|
||||||
await upsert_chat_session(session)
|
await upsert_chat_session(session)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -585,7 +747,9 @@ async def stream_chat_completion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If we did a tool call, stream the chat completion again to get the next response
|
# If we did a tool call, stream the chat completion again to get the next response
|
||||||
if has_done_tool_call:
|
# Skip only if ALL tools were long-running (they handle their own completion)
|
||||||
|
has_regular_tools = len(tool_response_messages) > 0
|
||||||
|
if has_done_tool_call and (has_regular_tools or not has_long_running_tool_call):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Tool call executed, streaming chat completion again to get assistant response"
|
"Tool call executed, streaming chat completion again to get assistant response"
|
||||||
)
|
)
|
||||||
@@ -725,6 +889,114 @@ async def _summarize_messages(
|
|||||||
return summary or "No summary available."
|
return summary or "No summary available."
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_tool_pairs_intact(
|
||||||
|
recent_messages: list[dict],
|
||||||
|
all_messages: list[dict],
|
||||||
|
start_index: int,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Ensure tool_call/tool_response pairs stay together after slicing.
|
||||||
|
|
||||||
|
When slicing messages for context compaction, a naive slice can separate
|
||||||
|
an assistant message containing tool_calls from its corresponding tool
|
||||||
|
response messages. This causes API validation errors (e.g., Anthropic's
|
||||||
|
"unexpected tool_use_id found in tool_result blocks").
|
||||||
|
|
||||||
|
This function checks for orphan tool responses in the slice and extends
|
||||||
|
backwards to include their corresponding assistant messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recent_messages: The sliced messages to validate
|
||||||
|
all_messages: The complete message list (for looking up missing assistants)
|
||||||
|
start_index: The index in all_messages where recent_messages begins
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A potentially extended list of messages with tool pairs intact
|
||||||
|
"""
|
||||||
|
if not recent_messages:
|
||||||
|
return recent_messages
|
||||||
|
|
||||||
|
# Collect all tool_call_ids from assistant messages in the slice
|
||||||
|
available_tool_call_ids: set[str] = set()
|
||||||
|
for msg in recent_messages:
|
||||||
|
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||||
|
for tc in msg["tool_calls"]:
|
||||||
|
tc_id = tc.get("id")
|
||||||
|
if tc_id:
|
||||||
|
available_tool_call_ids.add(tc_id)
|
||||||
|
|
||||||
|
# Find orphan tool responses (tool messages whose tool_call_id is missing)
|
||||||
|
orphan_tool_call_ids: set[str] = set()
|
||||||
|
for msg in recent_messages:
|
||||||
|
if msg.get("role") == "tool":
|
||||||
|
tc_id = msg.get("tool_call_id")
|
||||||
|
if tc_id and tc_id not in available_tool_call_ids:
|
||||||
|
orphan_tool_call_ids.add(tc_id)
|
||||||
|
|
||||||
|
if not orphan_tool_call_ids:
|
||||||
|
# No orphans, slice is valid
|
||||||
|
return recent_messages
|
||||||
|
|
||||||
|
# Find the assistant messages that contain the orphan tool_call_ids
|
||||||
|
# Search backwards from start_index in all_messages
|
||||||
|
messages_to_prepend: list[dict] = []
|
||||||
|
for i in range(start_index - 1, -1, -1):
|
||||||
|
msg = all_messages[i]
|
||||||
|
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||||
|
msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
|
||||||
|
if msg_tool_ids & orphan_tool_call_ids:
|
||||||
|
# This assistant message has tool_calls we need
|
||||||
|
# Also collect its contiguous tool responses that follow it
|
||||||
|
assistant_and_responses: list[dict] = [msg]
|
||||||
|
|
||||||
|
# Scan forward from this assistant to collect tool responses
|
||||||
|
for j in range(i + 1, start_index):
|
||||||
|
following_msg = all_messages[j]
|
||||||
|
if following_msg.get("role") == "tool":
|
||||||
|
tool_id = following_msg.get("tool_call_id")
|
||||||
|
if tool_id and tool_id in msg_tool_ids:
|
||||||
|
assistant_and_responses.append(following_msg)
|
||||||
|
else:
|
||||||
|
# Stop at first non-tool message
|
||||||
|
break
|
||||||
|
|
||||||
|
# Prepend the assistant and its tool responses (maintain order)
|
||||||
|
messages_to_prepend = assistant_and_responses + messages_to_prepend
|
||||||
|
# Mark these as found
|
||||||
|
orphan_tool_call_ids -= msg_tool_ids
|
||||||
|
# Also add this assistant's tool_call_ids to available set
|
||||||
|
available_tool_call_ids |= msg_tool_ids
|
||||||
|
|
||||||
|
if not orphan_tool_call_ids:
|
||||||
|
# Found all missing assistants
|
||||||
|
break
|
||||||
|
|
||||||
|
if orphan_tool_call_ids:
|
||||||
|
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||||
|
# This shouldn't happen in normal operation but handles edge cases
|
||||||
|
logger.warning(
|
||||||
|
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
||||||
|
"Removing orphan tool responses."
|
||||||
|
)
|
||||||
|
recent_messages = [
|
||||||
|
msg
|
||||||
|
for msg in recent_messages
|
||||||
|
if not (
|
||||||
|
msg.get("role") == "tool"
|
||||||
|
and msg.get("tool_call_id") in orphan_tool_call_ids
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if messages_to_prepend:
|
||||||
|
logger.info(
|
||||||
|
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
||||||
|
f"tool_call/tool_response pairs"
|
||||||
|
)
|
||||||
|
return messages_to_prepend + recent_messages
|
||||||
|
|
||||||
|
return recent_messages
|
||||||
|
|
||||||
|
|
||||||
async def _stream_chat_chunks(
|
async def _stream_chat_chunks(
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
tools: list[ChatCompletionToolParam],
|
tools: list[ChatCompletionToolParam],
|
||||||
@@ -816,7 +1088,15 @@ async def _stream_chat_chunks(
|
|||||||
# Always attempt mitigation when over limit, even with few messages
|
# Always attempt mitigation when over limit, even with few messages
|
||||||
if messages:
|
if messages:
|
||||||
# Split messages based on whether system prompt exists
|
# Split messages based on whether system prompt exists
|
||||||
recent_messages = messages[-KEEP_RECENT:]
|
# Calculate start index for the slice
|
||||||
|
slice_start = max(0, len(messages_dict) - KEEP_RECENT)
|
||||||
|
recent_messages = messages_dict[-KEEP_RECENT:]
|
||||||
|
|
||||||
|
# Ensure tool_call/tool_response pairs stay together
|
||||||
|
# This prevents API errors from orphan tool responses
|
||||||
|
recent_messages = _ensure_tool_pairs_intact(
|
||||||
|
recent_messages, messages_dict, slice_start
|
||||||
|
)
|
||||||
|
|
||||||
if has_system_prompt:
|
if has_system_prompt:
|
||||||
# Keep system prompt separate, summarize everything between system and recent
|
# Keep system prompt separate, summarize everything between system and recent
|
||||||
@@ -903,6 +1183,13 @@ async def _stream_chat_chunks(
|
|||||||
if len(recent_messages) >= keep_count
|
if len(recent_messages) >= keep_count
|
||||||
else recent_messages
|
else recent_messages
|
||||||
)
|
)
|
||||||
|
# Ensure tool pairs stay intact in the reduced slice
|
||||||
|
reduced_slice_start = max(
|
||||||
|
0, len(recent_messages) - keep_count
|
||||||
|
)
|
||||||
|
reduced_recent = _ensure_tool_pairs_intact(
|
||||||
|
reduced_recent, recent_messages, reduced_slice_start
|
||||||
|
)
|
||||||
if has_system_prompt:
|
if has_system_prompt:
|
||||||
messages = [
|
messages = [
|
||||||
system_msg,
|
system_msg,
|
||||||
@@ -961,7 +1248,10 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# Create a base list excluding system prompt to avoid duplication
|
# Create a base list excluding system prompt to avoid duplication
|
||||||
# This is the pool of messages we'll slice from in the loop
|
# This is the pool of messages we'll slice from in the loop
|
||||||
base_msgs = messages[1:] if has_system_prompt else messages
|
# Use messages_dict for type consistency with _ensure_tool_pairs_intact
|
||||||
|
base_msgs = (
|
||||||
|
messages_dict[1:] if has_system_prompt else messages_dict
|
||||||
|
)
|
||||||
|
|
||||||
# Try progressively smaller keep counts
|
# Try progressively smaller keep counts
|
||||||
new_token_count = token_count # Initialize with current count
|
new_token_count = token_count # Initialize with current count
|
||||||
@@ -984,6 +1274,12 @@ async def _stream_chat_chunks(
|
|||||||
# Slice from base_msgs to get recent messages (without system prompt)
|
# Slice from base_msgs to get recent messages (without system prompt)
|
||||||
recent_messages = base_msgs[-keep_count:]
|
recent_messages = base_msgs[-keep_count:]
|
||||||
|
|
||||||
|
# Ensure tool pairs stay intact in the reduced slice
|
||||||
|
reduced_slice_start = max(0, len(base_msgs) - keep_count)
|
||||||
|
recent_messages = _ensure_tool_pairs_intact(
|
||||||
|
recent_messages, base_msgs, reduced_slice_start
|
||||||
|
)
|
||||||
|
|
||||||
if has_system_prompt:
|
if has_system_prompt:
|
||||||
messages = [system_msg] + recent_messages
|
messages = [system_msg] + recent_messages
|
||||||
else:
|
else:
|
||||||
@@ -1260,17 +1556,19 @@ async def _yield_tool_call(
|
|||||||
"""
|
"""
|
||||||
Yield a tool call and its execution result.
|
Yield a tool call and its execution result.
|
||||||
|
|
||||||
For long-running tools, yields heartbeat events every 15 seconds to keep
|
For tools marked with `is_long_running=True` (like agent generation), spawns a
|
||||||
the SSE connection alive through proxies and load balancers.
|
background task so the operation survives SSE disconnections. For other tools,
|
||||||
|
yields heartbeat events every 15 seconds to keep the SSE connection alive.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
|
orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
|
||||||
KeyError: If expected tool call fields are missing
|
KeyError: If expected tool call fields are missing
|
||||||
TypeError: If tool call structure is invalid
|
TypeError: If tool call structure is invalid
|
||||||
"""
|
"""
|
||||||
|
import uuid as uuid_module
|
||||||
|
|
||||||
tool_name = tool_calls[yield_idx]["function"]["name"]
|
tool_name = tool_calls[yield_idx]["function"]["name"]
|
||||||
tool_call_id = tool_calls[yield_idx]["id"]
|
tool_call_id = tool_calls[yield_idx]["id"]
|
||||||
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
|
|
||||||
|
|
||||||
# Parse tool call arguments - handle empty arguments gracefully
|
# Parse tool call arguments - handle empty arguments gracefully
|
||||||
raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
|
raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
|
||||||
@@ -1285,7 +1583,151 @@ async def _yield_tool_call(
|
|||||||
input=arguments,
|
input=arguments,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run tool execution in background task with heartbeats to keep connection alive
|
# Check if this tool is long-running (survives SSE disconnection)
|
||||||
|
tool = get_tool(tool_name)
|
||||||
|
if tool and tool.is_long_running:
|
||||||
|
# Atomic check-and-set: returns False if operation already running (lost race)
|
||||||
|
if not await _mark_operation_started(tool_call_id):
|
||||||
|
logger.info(
|
||||||
|
f"Tool call {tool_call_id} already in progress, returning status"
|
||||||
|
)
|
||||||
|
# Build dynamic message based on tool name
|
||||||
|
if tool_name == "create_agent":
|
||||||
|
in_progress_msg = "Agent creation already in progress. Please wait..."
|
||||||
|
elif tool_name == "edit_agent":
|
||||||
|
in_progress_msg = "Agent edit already in progress. Please wait..."
|
||||||
|
else:
|
||||||
|
in_progress_msg = f"{tool_name} already in progress. Please wait..."
|
||||||
|
|
||||||
|
yield StreamToolOutputAvailable(
|
||||||
|
toolCallId=tool_call_id,
|
||||||
|
toolName=tool_name,
|
||||||
|
output=OperationInProgressResponse(
|
||||||
|
message=in_progress_msg,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
).model_dump_json(),
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Generate operation ID
|
||||||
|
operation_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
|
# Build a user-friendly message based on tool and arguments
|
||||||
|
if tool_name == "create_agent":
|
||||||
|
agent_desc = arguments.get("description", "")
|
||||||
|
# Truncate long descriptions for the message
|
||||||
|
desc_preview = (
|
||||||
|
(agent_desc[:100] + "...") if len(agent_desc) > 100 else agent_desc
|
||||||
|
)
|
||||||
|
pending_msg = (
|
||||||
|
f"Creating your agent: {desc_preview}"
|
||||||
|
if desc_preview
|
||||||
|
else "Creating agent... This may take a few minutes."
|
||||||
|
)
|
||||||
|
started_msg = (
|
||||||
|
"Agent creation started. You can close this tab - "
|
||||||
|
"check your library in a few minutes."
|
||||||
|
)
|
||||||
|
elif tool_name == "edit_agent":
|
||||||
|
changes = arguments.get("changes", "")
|
||||||
|
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
|
||||||
|
pending_msg = (
|
||||||
|
f"Editing agent: {changes_preview}"
|
||||||
|
if changes_preview
|
||||||
|
else "Editing agent... This may take a few minutes."
|
||||||
|
)
|
||||||
|
started_msg = (
|
||||||
|
"Agent edit started. You can close this tab - "
|
||||||
|
"check your library in a few minutes."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pending_msg = f"Running {tool_name}... This may take a few minutes."
|
||||||
|
started_msg = (
|
||||||
|
f"{tool_name} started. You can close this tab - "
|
||||||
|
"check back in a few minutes."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track appended messages for rollback on failure
|
||||||
|
assistant_message: ChatMessage | None = None
|
||||||
|
pending_message: ChatMessage | None = None
|
||||||
|
|
||||||
|
# Wrap session save and task creation in try-except to release lock on failure
|
||||||
|
try:
|
||||||
|
# Save assistant message with tool_call FIRST (required by LLM)
|
||||||
|
assistant_message = ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
tool_calls=[tool_calls[yield_idx]],
|
||||||
|
)
|
||||||
|
session.messages.append(assistant_message)
|
||||||
|
|
||||||
|
# Then save pending tool result
|
||||||
|
pending_message = ChatMessage(
|
||||||
|
role="tool",
|
||||||
|
content=OperationPendingResponse(
|
||||||
|
message=pending_msg,
|
||||||
|
operation_id=operation_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
).model_dump_json(),
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
)
|
||||||
|
session.messages.append(pending_message)
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
logger.info(
|
||||||
|
f"Saved pending operation {operation_id} for tool {tool_name} "
|
||||||
|
f"in session {session.session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store task reference in module-level set to prevent GC before completion
|
||||||
|
task = asyncio.create_task(
|
||||||
|
_execute_long_running_tool(
|
||||||
|
tool_name=tool_name,
|
||||||
|
parameters=arguments,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
operation_id=operation_id,
|
||||||
|
session_id=session.session_id,
|
||||||
|
user_id=session.user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
except Exception as e:
|
||||||
|
# Roll back appended messages to prevent data corruption on subsequent saves
|
||||||
|
if (
|
||||||
|
pending_message
|
||||||
|
and session.messages
|
||||||
|
and session.messages[-1] == pending_message
|
||||||
|
):
|
||||||
|
session.messages.pop()
|
||||||
|
if (
|
||||||
|
assistant_message
|
||||||
|
and session.messages
|
||||||
|
and session.messages[-1] == assistant_message
|
||||||
|
):
|
||||||
|
session.messages.pop()
|
||||||
|
|
||||||
|
# Release the Redis lock since the background task won't be spawned
|
||||||
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
logger.error(
|
||||||
|
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Return immediately - don't wait for completion
|
||||||
|
yield StreamToolOutputAvailable(
|
||||||
|
toolCallId=tool_call_id,
|
||||||
|
toolName=tool_name,
|
||||||
|
output=OperationStartedResponse(
|
||||||
|
message=started_msg,
|
||||||
|
operation_id=operation_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
).model_dump_json(),
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Normal flow: Run tool execution in background task with heartbeats
|
||||||
tool_task = asyncio.create_task(
|
tool_task = asyncio.create_task(
|
||||||
execute_tool(
|
execute_tool(
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
@@ -1335,3 +1777,190 @@ async def _yield_tool_call(
|
|||||||
)
|
)
|
||||||
|
|
||||||
yield tool_execution_response
|
yield tool_execution_response
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_long_running_tool(
|
||||||
|
tool_name: str,
|
||||||
|
parameters: dict[str, Any],
|
||||||
|
tool_call_id: str,
|
||||||
|
operation_id: str,
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a long-running tool in background and update chat history with result.
|
||||||
|
|
||||||
|
This function runs independently of the SSE connection, so the operation
|
||||||
|
survives if the user closes their browser tab.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load fresh session (not stale reference)
|
||||||
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
if not session:
|
||||||
|
logger.error(f"Session {session_id} not found for background tool")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Execute the actual tool
|
||||||
|
result = await execute_tool(
|
||||||
|
tool_name=tool_name,
|
||||||
|
parameters=parameters,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
user_id=user_id,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the pending message with result
|
||||||
|
await _update_pending_operation(
|
||||||
|
session_id=session_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
result=(
|
||||||
|
result.output
|
||||||
|
if isinstance(result.output, str)
|
||||||
|
else orjson.dumps(result.output).decode("utf-8")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Background tool {tool_name} completed for session {session_id}")
|
||||||
|
|
||||||
|
# Generate LLM continuation so user sees response when they poll/refresh
|
||||||
|
await _generate_llm_continuation(session_id=session_id, user_id=user_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message=f"Tool {tool_name} failed: {str(e)}",
|
||||||
|
)
|
||||||
|
await _update_pending_operation(
|
||||||
|
session_id=session_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
result=error_response.model_dump_json(),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_pending_operation(
|
||||||
|
session_id: str,
|
||||||
|
tool_call_id: str,
|
||||||
|
result: str,
|
||||||
|
) -> None:
|
||||||
|
"""Update the pending tool message with final result.
|
||||||
|
|
||||||
|
This is called by background tasks when long-running operations complete.
|
||||||
|
"""
|
||||||
|
# Update the message in database
|
||||||
|
updated = await chat_db.update_tool_message_content(
|
||||||
|
session_id=session_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
new_content=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
if updated:
|
||||||
|
# Invalidate Redis cache so next load gets fresh data
|
||||||
|
# Wrap in try/except to prevent cache failures from triggering error handling
|
||||||
|
# that would overwrite our successful DB update
|
||||||
|
try:
|
||||||
|
await invalidate_session_cache(session_id)
|
||||||
|
except Exception as e:
|
||||||
|
# Non-critical: cache will eventually be refreshed on next load
|
||||||
|
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||||
|
logger.info(
|
||||||
|
f"Updated pending operation for tool_call_id {tool_call_id} "
|
||||||
|
f"in session {session_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to update pending operation for tool_call_id {tool_call_id} "
|
||||||
|
f"in session {session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_llm_continuation(
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Generate an LLM response after a long-running tool completes.
|
||||||
|
|
||||||
|
This is called by background tasks to continue the conversation
|
||||||
|
after a tool result is saved. The response is saved to the database
|
||||||
|
so users see it when they refresh or poll.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load fresh session from DB (bypass cache to get the updated tool result)
|
||||||
|
await invalidate_session_cache(session_id)
|
||||||
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
if not session:
|
||||||
|
logger.error(f"Session {session_id} not found for LLM continuation")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build system prompt
|
||||||
|
system_prompt, _ = await _build_system_prompt(user_id)
|
||||||
|
|
||||||
|
# Build messages in OpenAI format
|
||||||
|
messages = session.to_openai_messages()
|
||||||
|
if system_prompt:
|
||||||
|
from openai.types.chat import ChatCompletionSystemMessageParam
|
||||||
|
|
||||||
|
system_message = ChatCompletionSystemMessageParam(
|
||||||
|
role="system",
|
||||||
|
content=system_prompt,
|
||||||
|
)
|
||||||
|
messages = [system_message] + messages
|
||||||
|
|
||||||
|
# Build extra_body for tracing
|
||||||
|
extra_body: dict[str, Any] = {
|
||||||
|
"posthogProperties": {
|
||||||
|
"environment": settings.config.app_env.value,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if user_id:
|
||||||
|
extra_body["user"] = user_id[:128]
|
||||||
|
extra_body["posthogDistinctId"] = user_id
|
||||||
|
if session_id:
|
||||||
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
|
# Make non-streaming LLM call (no tools - just text response)
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
|
||||||
|
# No tools parameter = text-only response (no tool calls)
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=config.model,
|
||||||
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
|
extra_body=extra_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.choices and response.choices[0].message.content:
|
||||||
|
assistant_content = response.choices[0].message.content
|
||||||
|
|
||||||
|
# Reload session from DB to avoid race condition with user messages
|
||||||
|
# that may have been sent while we were generating the LLM response
|
||||||
|
fresh_session = await get_chat_session(session_id, user_id)
|
||||||
|
if not fresh_session:
|
||||||
|
logger.error(
|
||||||
|
f"Session {session_id} disappeared during LLM continuation"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save assistant message to database
|
||||||
|
assistant_message = ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=assistant_content,
|
||||||
|
)
|
||||||
|
fresh_session.messages.append(assistant_message)
|
||||||
|
|
||||||
|
# Save to database (not cache) to persist the response
|
||||||
|
await upsert_chat_session(fresh_session)
|
||||||
|
|
||||||
|
# Invalidate cache so next poll/refresh gets fresh data
|
||||||
|
await invalidate_session_cache(session_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Generated LLM continuation for session {session_id}, "
|
||||||
|
f"response length: {len(assistant_content)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"LLM continuation returned empty response for {session_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -49,6 +49,11 @@ tools: list[ChatCompletionToolParam] = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool(tool_name: str) -> BaseTool | None:
|
||||||
|
"""Get a tool instance by name."""
|
||||||
|
return TOOL_REGISTRY.get(tool_name)
|
||||||
|
|
||||||
|
|
||||||
async def execute_tool(
|
async def execute_tool(
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
parameters: dict[str, Any],
|
parameters: dict[str, Any],
|
||||||
@@ -57,7 +62,7 @@ async def execute_tool(
|
|||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
) -> "StreamToolOutputAvailable":
|
) -> "StreamToolOutputAvailable":
|
||||||
"""Execute a tool by name."""
|
"""Execute a tool by name."""
|
||||||
tool = TOOL_REGISTRY.get(tool_name)
|
tool = get_tool(tool_name)
|
||||||
if not tool:
|
if not tool:
|
||||||
raise ValueError(f"Tool {tool_name} not found")
|
raise ValueError(f"Tool {tool_name} not found")
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,16 @@ class BaseTool:
|
|||||||
"""Whether this tool requires authentication."""
|
"""Whether this tool requires authentication."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
"""Whether this tool is long-running and should execute in background.
|
||||||
|
|
||||||
|
Long-running tools (like agent generation) are executed via background
|
||||||
|
tasks to survive SSE disconnections. The result is persisted to chat
|
||||||
|
history and visible when the user refreshes.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||||
"""Convert to OpenAI tool format."""
|
"""Convert to OpenAI tool format."""
|
||||||
return ChatCompletionToolParam(
|
return ChatCompletionToolParam(
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ class CreateAgentTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ class EditAgentTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ class ResponseType(str, Enum):
|
|||||||
BLOCK_OUTPUT = "block_output"
|
BLOCK_OUTPUT = "block_output"
|
||||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||||
DOC_PAGE = "doc_page"
|
DOC_PAGE = "doc_page"
|
||||||
|
# Long-running operation types
|
||||||
|
OPERATION_STARTED = "operation_started"
|
||||||
|
OPERATION_PENDING = "operation_pending"
|
||||||
|
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||||
|
|
||||||
|
|
||||||
# Base response model
|
# Base response model
|
||||||
@@ -334,3 +338,39 @@ class BlockOutputResponse(ToolResponseBase):
|
|||||||
block_name: str
|
block_name: str
|
||||||
outputs: dict[str, list[Any]]
|
outputs: dict[str, list[Any]]
|
||||||
success: bool = True
|
success: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# Long-running operation models
|
||||||
|
class OperationStartedResponse(ToolResponseBase):
|
||||||
|
"""Response when a long-running operation has been started in the background.
|
||||||
|
|
||||||
|
This is returned immediately to the client while the operation continues
|
||||||
|
to execute. The user can close the tab and check back later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||||
|
operation_id: str
|
||||||
|
tool_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class OperationPendingResponse(ToolResponseBase):
|
||||||
|
"""Response stored in chat history while a long-running operation is executing.
|
||||||
|
|
||||||
|
This is persisted to the database so users see a pending state when they
|
||||||
|
refresh before the operation completes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_PENDING
|
||||||
|
operation_id: str
|
||||||
|
tool_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class OperationInProgressResponse(ToolResponseBase):
|
||||||
|
"""Response when an operation is already in progress.
|
||||||
|
|
||||||
|
Returned for idempotency when the same tool_call_id is requested again
|
||||||
|
while the background task is still running.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||||
|
tool_call_id: str
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from backend.data.model import CredentialsMetaInput
|
|||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||||
from backend.util.clients import get_scheduler_client
|
from backend.util.clients import get_scheduler_client
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
@@ -64,11 +64,11 @@ async def list_library_agents(
|
|||||||
|
|
||||||
if page < 1 or page_size < 1:
|
if page < 1 or page_size < 1:
|
||||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||||
raise DatabaseError("Invalid pagination input")
|
raise InvalidInputError("Invalid pagination input")
|
||||||
|
|
||||||
if search_term and len(search_term.strip()) > 100:
|
if search_term and len(search_term.strip()) > 100:
|
||||||
logger.warning(f"Search term too long: {repr(search_term)}")
|
logger.warning(f"Search term too long: {repr(search_term)}")
|
||||||
raise DatabaseError("Search term is too long")
|
raise InvalidInputError("Search term is too long")
|
||||||
|
|
||||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
@@ -175,7 +175,7 @@ async def list_favorite_library_agents(
|
|||||||
|
|
||||||
if page < 1 or page_size < 1:
|
if page < 1 or page_size < 1:
|
||||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||||
raise DatabaseError("Invalid pagination input")
|
raise InvalidInputError("Invalid pagination input")
|
||||||
|
|
||||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import autogpt_libs.auth as autogpt_auth_lib
|
import autogpt_libs.auth as autogpt_auth_lib
|
||||||
@@ -6,15 +5,11 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
|||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from prisma.enums import OnboardingStep
|
from prisma.enums import OnboardingStep
|
||||||
|
|
||||||
import backend.api.features.store.exceptions as store_exceptions
|
|
||||||
from backend.data.onboarding import complete_onboarding_step
|
from backend.data.onboarding import complete_onboarding_step
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
|
||||||
|
|
||||||
from .. import db as library_db
|
from .. import db as library_db
|
||||||
from .. import model as library_model
|
from .. import model as library_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/agents",
|
prefix="/agents",
|
||||||
tags=["library", "private"],
|
tags=["library", "private"],
|
||||||
@@ -26,10 +21,6 @@ router = APIRouter(
|
|||||||
"",
|
"",
|
||||||
summary="List Library Agents",
|
summary="List Library Agents",
|
||||||
response_model=library_model.LibraryAgentResponse,
|
response_model=library_model.LibraryAgentResponse,
|
||||||
responses={
|
|
||||||
200: {"description": "List of library agents"},
|
|
||||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def list_library_agents(
|
async def list_library_agents(
|
||||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
@@ -53,22 +44,7 @@ async def list_library_agents(
|
|||||||
) -> library_model.LibraryAgentResponse:
|
) -> library_model.LibraryAgentResponse:
|
||||||
"""
|
"""
|
||||||
Get all agents in the user's library (both created and saved).
|
Get all agents in the user's library (both created and saved).
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
search_term: Optional search term to filter agents by name/description.
|
|
||||||
filter_by: List of filters to apply (favorites, created by user).
|
|
||||||
sort_by: List of sorting criteria (created date, updated date).
|
|
||||||
page: Page number to retrieve.
|
|
||||||
page_size: Number of agents per page.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A LibraryAgentResponse containing agents and pagination metadata.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return await library_db.list_library_agents(
|
return await library_db.list_library_agents(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_term=search_term,
|
search_term=search_term,
|
||||||
@@ -76,20 +52,11 @@ async def list_library_agents(
|
|||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Could not list library agents for user #{user_id}: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/favorites",
|
"/favorites",
|
||||||
summary="List Favorite Library Agents",
|
summary="List Favorite Library Agents",
|
||||||
responses={
|
|
||||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def list_favorite_library_agents(
|
async def list_favorite_library_agents(
|
||||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
@@ -106,30 +73,12 @@ async def list_favorite_library_agents(
|
|||||||
) -> library_model.LibraryAgentResponse:
|
) -> library_model.LibraryAgentResponse:
|
||||||
"""
|
"""
|
||||||
Get all favorite agents in the user's library.
|
Get all favorite agents in the user's library.
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
page: Page number to retrieve.
|
|
||||||
page_size: Number of agents per page.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A LibraryAgentResponse containing favorite agents and pagination metadata.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return await library_db.list_favorite_library_agents(
|
return await library_db.list_favorite_library_agents(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
||||||
@@ -162,10 +111,6 @@ async def get_library_agent_by_graph_id(
|
|||||||
summary="Get Agent By Store ID",
|
summary="Get Agent By Store ID",
|
||||||
tags=["store", "library"],
|
tags=["store", "library"],
|
||||||
response_model=library_model.LibraryAgent | None,
|
response_model=library_model.LibraryAgent | None,
|
||||||
responses={
|
|
||||||
200: {"description": "Library agent found"},
|
|
||||||
404: {"description": "Agent not found"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def get_library_agent_by_store_listing_version_id(
|
async def get_library_agent_by_store_listing_version_id(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
@@ -174,32 +119,15 @@ async def get_library_agent_by_store_listing_version_id(
|
|||||||
"""
|
"""
|
||||||
Get Library Agent from Store Listing Version ID.
|
Get Library Agent from Store Listing Version ID.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return await library_db.get_library_agent_by_store_version_id(
|
return await library_db.get_library_agent_by_store_version_id(
|
||||||
store_listing_version_id, user_id
|
store_listing_version_id, user_id
|
||||||
)
|
)
|
||||||
except NotFoundError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=str(e),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Could not fetch library agent from store version ID: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"",
|
"",
|
||||||
summary="Add Marketplace Agent",
|
summary="Add Marketplace Agent",
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
responses={
|
|
||||||
201: {"description": "Agent added successfully"},
|
|
||||||
404: {"description": "Store listing version not found"},
|
|
||||||
500: {"description": "Server error"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def add_marketplace_agent_to_library(
|
async def add_marketplace_agent_to_library(
|
||||||
store_listing_version_id: str = Body(embed=True),
|
store_listing_version_id: str = Body(embed=True),
|
||||||
@@ -210,59 +138,19 @@ async def add_marketplace_agent_to_library(
|
|||||||
) -> library_model.LibraryAgent:
|
) -> library_model.LibraryAgent:
|
||||||
"""
|
"""
|
||||||
Add an agent from the marketplace to the user's library.
|
Add an agent from the marketplace to the user's library.
|
||||||
|
|
||||||
Args:
|
|
||||||
store_listing_version_id: ID of the store listing version to add.
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
library_model.LibraryAgent: Agent added to the library
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException(404): If the listing version is not found.
|
|
||||||
HTTPException(500): If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
agent = await library_db.add_store_agent_to_library(
|
agent = await library_db.add_store_agent_to_library(
|
||||||
store_listing_version_id=store_listing_version_id,
|
store_listing_version_id=store_listing_version_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if source != "onboarding":
|
if source != "onboarding":
|
||||||
await complete_onboarding_step(
|
await complete_onboarding_step(user_id, OnboardingStep.MARKETPLACE_ADD_AGENT)
|
||||||
user_id, OnboardingStep.MARKETPLACE_ADD_AGENT
|
|
||||||
)
|
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
except store_exceptions.AgentNotFoundError as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Could not find store listing version {store_listing_version_id} "
|
|
||||||
"to add to library"
|
|
||||||
)
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
|
||||||
except DatabaseError as e:
|
|
||||||
logger.error(f"Database error while adding agent to library: {e}", e)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail={"message": str(e), "hint": "Inspect DB logs for details."},
|
|
||||||
) from e
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error while adding agent to library: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail={
|
|
||||||
"message": str(e),
|
|
||||||
"hint": "Check server logs for more information.",
|
|
||||||
},
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
"/{library_agent_id}",
|
"/{library_agent_id}",
|
||||||
summary="Update Library Agent",
|
summary="Update Library Agent",
|
||||||
responses={
|
|
||||||
200: {"description": "Agent updated successfully"},
|
|
||||||
500: {"description": "Server error"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def update_library_agent(
|
async def update_library_agent(
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
@@ -271,16 +159,7 @@ async def update_library_agent(
|
|||||||
) -> library_model.LibraryAgent:
|
) -> library_model.LibraryAgent:
|
||||||
"""
|
"""
|
||||||
Update the library agent with the given fields.
|
Update the library agent with the given fields.
|
||||||
|
|
||||||
Args:
|
|
||||||
library_agent_id: ID of the library agent to update.
|
|
||||||
payload: Fields to update (auto_update_version, is_favorite, etc.).
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException(500): If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return await library_db.update_library_agent(
|
return await library_db.update_library_agent(
|
||||||
library_agent_id=library_agent_id,
|
library_agent_id=library_agent_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -290,33 +169,11 @@ async def update_library_agent(
|
|||||||
is_archived=payload.is_archived,
|
is_archived=payload.is_archived,
|
||||||
settings=payload.settings,
|
settings=payload.settings,
|
||||||
)
|
)
|
||||||
except NotFoundError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
except DatabaseError as e:
|
|
||||||
logger.error(f"Database error while updating library agent: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail={"message": str(e), "hint": "Verify DB connection."},
|
|
||||||
) from e
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error while updating library agent: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail={"message": str(e), "hint": "Check server logs."},
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/{library_agent_id}",
|
"/{library_agent_id}",
|
||||||
summary="Delete Library Agent",
|
summary="Delete Library Agent",
|
||||||
responses={
|
|
||||||
204: {"description": "Agent deleted successfully"},
|
|
||||||
404: {"description": "Agent not found"},
|
|
||||||
500: {"description": "Server error"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def delete_library_agent(
|
async def delete_library_agent(
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
@@ -324,28 +181,11 @@ async def delete_library_agent(
|
|||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Soft-delete the specified library agent.
|
Soft-delete the specified library agent.
|
||||||
|
|
||||||
Args:
|
|
||||||
library_agent_id: ID of the library agent to delete.
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
204 No Content if successful.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException(404): If the agent does not exist.
|
|
||||||
HTTPException(500): If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
await library_db.delete_library_agent(
|
await library_db.delete_library_agent(
|
||||||
library_agent_id=library_agent_id, user_id=user_id
|
library_agent_id=library_agent_id, user_id=user_id
|
||||||
)
|
)
|
||||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||||
except NotFoundError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
|
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
|
||||||
|
|||||||
@@ -118,21 +118,6 @@ async def test_get_library_agents_success(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
|
||||||
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
|
||||||
mock_db_call.side_effect = Exception("Test error")
|
|
||||||
|
|
||||||
response = client.get("/agents?search_term=test")
|
|
||||||
assert response.status_code == 500
|
|
||||||
mock_db_call.assert_called_once_with(
|
|
||||||
user_id=test_user_id,
|
|
||||||
search_term="test",
|
|
||||||
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
|
|
||||||
page=1,
|
|
||||||
page_size=15,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_favorite_library_agents_success(
|
async def test_get_favorite_library_agents_success(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
@@ -190,23 +175,6 @@ async def test_get_favorite_library_agents_success(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_favorite_library_agents_error(
|
|
||||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
|
||||||
):
|
|
||||||
mock_db_call = mocker.patch(
|
|
||||||
"backend.api.features.library.db.list_favorite_library_agents"
|
|
||||||
)
|
|
||||||
mock_db_call.side_effect = Exception("Test error")
|
|
||||||
|
|
||||||
response = client.get("/agents/favorites")
|
|
||||||
assert response.status_code == 500
|
|
||||||
mock_db_call.assert_called_once_with(
|
|
||||||
user_id=test_user_id,
|
|
||||||
page=1,
|
|
||||||
page_size=15,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_agent_to_library_success(
|
def test_add_agent_to_library_success(
|
||||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||||
):
|
):
|
||||||
@@ -258,19 +226,3 @@ def test_add_agent_to_library_success(
|
|||||||
store_listing_version_id="test-version-id", user_id=test_user_id
|
store_listing_version_id="test-version-id", user_id=test_user_id
|
||||||
)
|
)
|
||||||
mock_complete_onboarding.assert_awaited_once()
|
mock_complete_onboarding.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
|
||||||
mock_db_call = mocker.patch(
|
|
||||||
"backend.api.features.library.db.add_store_agent_to_library"
|
|
||||||
)
|
|
||||||
mock_db_call.side_effect = Exception("Test error")
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 500
|
|
||||||
assert "detail" in response.json() # Verify error response structure
|
|
||||||
mock_db_call.assert_called_once_with(
|
|
||||||
store_listing_version_id="test-version-id", user_id=test_user_id
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -454,6 +454,7 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
|||||||
total_processed = 0
|
total_processed = 0
|
||||||
total_success = 0
|
total_success = 0
|
||||||
total_failed = 0
|
total_failed = 0
|
||||||
|
all_errors: dict[str, int] = {} # Aggregate errors across all content types
|
||||||
|
|
||||||
# Process content types in explicit order
|
# Process content types in explicit order
|
||||||
processing_order = [
|
processing_order = [
|
||||||
@@ -499,23 +500,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
|||||||
success = sum(1 for result in results if result is True)
|
success = sum(1 for result in results if result is True)
|
||||||
failed = len(results) - success
|
failed = len(results) - success
|
||||||
|
|
||||||
# Aggregate unique errors to avoid Sentry spam
|
# Aggregate errors across all content types
|
||||||
if failed > 0:
|
if failed > 0:
|
||||||
# Group errors by type and message
|
|
||||||
error_summary: dict[str, int] = {}
|
|
||||||
for result in results:
|
for result in results:
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
error_key = f"{type(result).__name__}: {str(result)}"
|
error_key = f"{type(result).__name__}: {str(result)}"
|
||||||
error_summary[error_key] = error_summary.get(error_key, 0) + 1
|
all_errors[error_key] = all_errors.get(error_key, 0) + 1
|
||||||
|
|
||||||
# Log aggregated error summary
|
|
||||||
error_details = ", ".join(
|
|
||||||
f"{error} ({count}x)" for error, count in error_summary.items()
|
|
||||||
)
|
|
||||||
logger.error(
|
|
||||||
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
|
|
||||||
f"Errors: {error_details}"
|
|
||||||
)
|
|
||||||
|
|
||||||
results_by_type[content_type.value] = {
|
results_by_type[content_type.value] = {
|
||||||
"processed": len(missing_items),
|
"processed": len(missing_items),
|
||||||
@@ -542,6 +532,13 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
|||||||
"error": str(e),
|
"error": str(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Log aggregated errors once at the end
|
||||||
|
if all_errors:
|
||||||
|
error_details = ", ".join(
|
||||||
|
f"{error} ({count}x)" for error, count in all_errors.items()
|
||||||
|
)
|
||||||
|
logger.error(f"Embedding backfill errors: {error_details}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"by_type": results_by_type,
|
"by_type": results_by_type,
|
||||||
"totals": {
|
"totals": {
|
||||||
|
|||||||
@@ -261,14 +261,36 @@ async def get_onboarding_agents(
|
|||||||
return await get_recommended_agents(user_id)
|
return await get_recommended_agents(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||||
|
"""Response for onboarding status check."""
|
||||||
|
|
||||||
|
is_onboarding_enabled: bool
|
||||||
|
is_chat_enabled: bool
|
||||||
|
|
||||||
|
|
||||||
@v1_router.get(
|
@v1_router.get(
|
||||||
"/onboarding/enabled",
|
"/onboarding/enabled",
|
||||||
summary="Is onboarding enabled",
|
summary="Is onboarding enabled",
|
||||||
tags=["onboarding", "public"],
|
tags=["onboarding", "public"],
|
||||||
dependencies=[Security(requires_user)],
|
response_model=OnboardingStatusResponse,
|
||||||
)
|
)
|
||||||
async def is_onboarding_enabled() -> bool:
|
async def is_onboarding_enabled(
|
||||||
return await onboarding_enabled()
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
) -> OnboardingStatusResponse:
|
||||||
|
# Check if chat is enabled for user
|
||||||
|
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
|
||||||
|
|
||||||
|
# If chat is enabled, skip legacy onboarding
|
||||||
|
if is_chat_enabled:
|
||||||
|
return OnboardingStatusResponse(
|
||||||
|
is_onboarding_enabled=False,
|
||||||
|
is_chat_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OnboardingStatusResponse(
|
||||||
|
is_onboarding_enabled=await onboarding_enabled(),
|
||||||
|
is_chat_enabled=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@v1_router.post(
|
@v1_router.post(
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ FrontendOnboardingStep = Literal[
|
|||||||
OnboardingStep.AGENT_NEW_RUN,
|
OnboardingStep.AGENT_NEW_RUN,
|
||||||
OnboardingStep.AGENT_INPUT,
|
OnboardingStep.AGENT_INPUT,
|
||||||
OnboardingStep.CONGRATS,
|
OnboardingStep.CONGRATS,
|
||||||
|
OnboardingStep.VISIT_COPILOT,
|
||||||
OnboardingStep.MARKETPLACE_VISIT,
|
OnboardingStep.MARKETPLACE_VISIT,
|
||||||
OnboardingStep.BUILDER_OPEN,
|
OnboardingStep.BUILDER_OPEN,
|
||||||
]
|
]
|
||||||
@@ -122,6 +123,9 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
|
|||||||
async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
|
async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
|
||||||
reward = 0
|
reward = 0
|
||||||
match step:
|
match step:
|
||||||
|
# Welcome bonus for visiting copilot ($5 = 500 credits)
|
||||||
|
case OnboardingStep.VISIT_COPILOT:
|
||||||
|
reward = 500
|
||||||
# Reward user when they clicked New Run during onboarding
|
# Reward user when they clicked New Run during onboarding
|
||||||
# This is because they need credits before scheduling a run (next step)
|
# This is because they need credits before scheduling a run (next step)
|
||||||
# This is seen as a reward for the GET_RESULTS step in the wallet
|
# This is seen as a reward for the GET_RESULTS step in the wallet
|
||||||
|
|||||||
@@ -135,6 +135,12 @@ class GraphValidationError(ValueError):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidInputError(ValueError):
|
||||||
|
"""Raised when user input validation fails (e.g., search term too long)"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DatabaseError(Exception):
|
class DatabaseError(Exception):
|
||||||
"""Raised when there is an error interacting with the database"""
|
"""Raised when there is an error interacting with the database"""
|
||||||
|
|
||||||
|
|||||||
@@ -359,8 +359,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
description="The port for the Agent Generator service",
|
description="The port for the Agent Generator service",
|
||||||
)
|
)
|
||||||
agentgenerator_timeout: int = Field(
|
agentgenerator_timeout: int = Field(
|
||||||
default=120,
|
default=600,
|
||||||
description="The timeout in seconds for Agent Generator service requests",
|
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
|
||||||
)
|
)
|
||||||
|
|
||||||
enable_example_blocks: bool = Field(
|
enable_example_blocks: bool = Field(
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
-- AlterEnum
|
||||||
|
ALTER TYPE "OnboardingStep" ADD VALUE 'VISIT_COPILOT';
|
||||||
@@ -81,6 +81,7 @@ enum OnboardingStep {
|
|||||||
AGENT_INPUT
|
AGENT_INPUT
|
||||||
CONGRATS
|
CONGRATS
|
||||||
// First Wins
|
// First Wins
|
||||||
|
VISIT_COPILOT
|
||||||
GET_RESULTS
|
GET_RESULTS
|
||||||
MARKETPLACE_VISIT
|
MARKETPLACE_VISIT
|
||||||
MARKETPLACE_ADD_AGENT
|
MARKETPLACE_ADD_AGENT
|
||||||
|
|||||||
@@ -2,8 +2,9 @@
|
|||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
import { resolveResponse, shouldShowOnboarding } from "@/app/api/helpers";
|
import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
|
||||||
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
|
|
||||||
export default function OnboardingPage() {
|
export default function OnboardingPage() {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -11,10 +12,13 @@ export default function OnboardingPage() {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
async function redirectToStep() {
|
async function redirectToStep() {
|
||||||
try {
|
try {
|
||||||
// Check if onboarding is enabled
|
// Check if onboarding is enabled (also gets chat flag for redirect)
|
||||||
const isEnabled = await shouldShowOnboarding();
|
const { shouldShowOnboarding, isChatEnabled } =
|
||||||
if (!isEnabled) {
|
await getOnboardingStatus();
|
||||||
router.replace("/");
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
|
if (!shouldShowOnboarding) {
|
||||||
|
router.replace(homepageRoute);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,7 +26,7 @@ export default function OnboardingPage() {
|
|||||||
|
|
||||||
// Handle completed onboarding
|
// Handle completed onboarding
|
||||||
if (onboarding.completedSteps.includes("GET_RESULTS")) {
|
if (onboarding.completedSteps.includes("GET_RESULTS")) {
|
||||||
router.replace("/");
|
router.replace(homepageRoute);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import BackendAPI from "@/lib/autogpt-server-api";
|
import BackendAPI from "@/lib/autogpt-server-api";
|
||||||
import { NextResponse } from "next/server";
|
import { NextResponse } from "next/server";
|
||||||
import { revalidatePath } from "next/cache";
|
import { revalidatePath } from "next/cache";
|
||||||
import { shouldShowOnboarding } from "@/app/api/helpers";
|
import { getOnboardingStatus } from "@/app/api/helpers";
|
||||||
|
|
||||||
// Handle the callback to complete the user session login
|
// Handle the callback to complete the user session login
|
||||||
export async function GET(request: Request) {
|
export async function GET(request: Request) {
|
||||||
@@ -25,11 +26,15 @@ export async function GET(request: Request) {
|
|||||||
const api = new BackendAPI();
|
const api = new BackendAPI();
|
||||||
await api.createUser();
|
await api.createUser();
|
||||||
|
|
||||||
if (await shouldShowOnboarding()) {
|
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||||
|
const { shouldShowOnboarding, isChatEnabled } =
|
||||||
|
await getOnboardingStatus();
|
||||||
|
if (shouldShowOnboarding) {
|
||||||
next = "/onboarding";
|
next = "/onboarding";
|
||||||
revalidatePath("/onboarding", "layout");
|
revalidatePath("/onboarding", "layout");
|
||||||
} else {
|
} else {
|
||||||
revalidatePath("/", "layout");
|
next = getHomepageRoute(isChatEnabled);
|
||||||
|
revalidatePath(next, "layout");
|
||||||
}
|
}
|
||||||
} catch (createUserError) {
|
} catch (createUserError) {
|
||||||
console.error("Error creating user:", createUserError);
|
console.error("Error creating user:", createUserError);
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
||||||
import type { ReactNode } from "react";
|
import type { ReactNode } from "react";
|
||||||
import { useEffect } from "react";
|
|
||||||
import { useCopilotStore } from "../../copilot-page-store";
|
|
||||||
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
|
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
|
||||||
import { LoadingState } from "./components/LoadingState/LoadingState";
|
|
||||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
||||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
||||||
import { useCopilotShell } from "./useCopilotShell";
|
import { useCopilotShell } from "./useCopilotShell";
|
||||||
@@ -20,38 +18,21 @@ export function CopilotShell({ children }: Props) {
|
|||||||
isMobile,
|
isMobile,
|
||||||
isDrawerOpen,
|
isDrawerOpen,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
isCreatingSession,
|
||||||
isLoggedIn,
|
isLoggedIn,
|
||||||
hasActiveSession,
|
hasActiveSession,
|
||||||
sessions,
|
sessions,
|
||||||
currentSessionId,
|
currentSessionId,
|
||||||
handleSelectSession,
|
|
||||||
handleOpenDrawer,
|
handleOpenDrawer,
|
||||||
handleCloseDrawer,
|
handleCloseDrawer,
|
||||||
handleDrawerOpenChange,
|
handleDrawerOpenChange,
|
||||||
handleNewChat,
|
handleNewChatClick,
|
||||||
|
handleSessionClick,
|
||||||
hasNextPage,
|
hasNextPage,
|
||||||
isFetchingNextPage,
|
isFetchingNextPage,
|
||||||
fetchNextPage,
|
fetchNextPage,
|
||||||
isReadyToShowContent,
|
|
||||||
} = useCopilotShell();
|
} = useCopilotShell();
|
||||||
|
|
||||||
const setNewChatHandler = useCopilotStore((s) => s.setNewChatHandler);
|
|
||||||
const requestNewChat = useCopilotStore((s) => s.requestNewChat);
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function registerNewChatHandler() {
|
|
||||||
setNewChatHandler(handleNewChat);
|
|
||||||
return function cleanup() {
|
|
||||||
setNewChatHandler(null);
|
|
||||||
};
|
|
||||||
},
|
|
||||||
[handleNewChat],
|
|
||||||
);
|
|
||||||
|
|
||||||
function handleNewChatClick() {
|
|
||||||
requestNewChat();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isLoggedIn) {
|
if (!isLoggedIn) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full items-center justify-center">
|
<div className="flex h-full items-center justify-center">
|
||||||
@@ -72,7 +53,7 @@ export function CopilotShell({ children }: Props) {
|
|||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
hasNextPage={hasNextPage}
|
hasNextPage={hasNextPage}
|
||||||
isFetchingNextPage={isFetchingNextPage}
|
isFetchingNextPage={isFetchingNextPage}
|
||||||
onSelectSession={handleSelectSession}
|
onSelectSession={handleSessionClick}
|
||||||
onFetchNextPage={fetchNextPage}
|
onFetchNextPage={fetchNextPage}
|
||||||
onNewChat={handleNewChatClick}
|
onNewChat={handleNewChatClick}
|
||||||
hasActiveSession={Boolean(hasActiveSession)}
|
hasActiveSession={Boolean(hasActiveSession)}
|
||||||
@@ -82,7 +63,18 @@ export function CopilotShell({ children }: Props) {
|
|||||||
<div className="relative flex min-h-0 flex-1 flex-col">
|
<div className="relative flex min-h-0 flex-1 flex-col">
|
||||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||||
<div className="flex min-h-0 flex-1 flex-col">
|
<div className="flex min-h-0 flex-1 flex-col">
|
||||||
{isReadyToShowContent ? children : <LoadingState />}
|
{isCreatingSession ? (
|
||||||
|
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
|
||||||
|
<div className="flex flex-col items-center gap-4">
|
||||||
|
<ChatLoader />
|
||||||
|
<Text variant="body" className="text-zinc-500">
|
||||||
|
Creating your chat...
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
children
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -94,7 +86,7 @@ export function CopilotShell({ children }: Props) {
|
|||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
hasNextPage={hasNextPage}
|
hasNextPage={hasNextPage}
|
||||||
isFetchingNextPage={isFetchingNextPage}
|
isFetchingNextPage={isFetchingNextPage}
|
||||||
onSelectSession={handleSelectSession}
|
onSelectSession={handleSessionClick}
|
||||||
onFetchNextPage={fetchNextPage}
|
onFetchNextPage={fetchNextPage}
|
||||||
onNewChat={handleNewChatClick}
|
onNewChat={handleNewChatClick}
|
||||||
onClose={handleCloseDrawer}
|
onClose={handleCloseDrawer}
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
|
||||||
|
|
||||||
export function LoadingState() {
|
|
||||||
return (
|
|
||||||
<div className="flex flex-1 items-center justify-center">
|
|
||||||
<div className="flex flex-col items-center gap-4">
|
|
||||||
<ChatLoader />
|
|
||||||
<Text variant="body" className="text-zinc-500">
|
|
||||||
Loading your chats...
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -3,17 +3,17 @@ import { useState } from "react";
|
|||||||
export function useMobileDrawer() {
|
export function useMobileDrawer() {
|
||||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
||||||
|
|
||||||
function handleOpenDrawer() {
|
const handleOpenDrawer = () => {
|
||||||
setIsDrawerOpen(true);
|
setIsDrawerOpen(true);
|
||||||
}
|
};
|
||||||
|
|
||||||
function handleCloseDrawer() {
|
const handleCloseDrawer = () => {
|
||||||
setIsDrawerOpen(false);
|
setIsDrawerOpen(false);
|
||||||
}
|
};
|
||||||
|
|
||||||
function handleDrawerOpenChange(open: boolean) {
|
const handleDrawerOpenChange = (open: boolean) => {
|
||||||
setIsDrawerOpen(open);
|
setIsDrawerOpen(open);
|
||||||
}
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isDrawerOpen,
|
isDrawerOpen,
|
||||||
|
|||||||
@@ -1,11 +1,6 @@
|
|||||||
import {
|
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
getGetV2ListSessionsQueryKey,
|
|
||||||
useGetV2ListSessions,
|
|
||||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
|
||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||||
import { okData } from "@/app/api/helpers";
|
import { okData } from "@/app/api/helpers";
|
||||||
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
|
|
||||||
const PAGE_SIZE = 50;
|
const PAGE_SIZE = 50;
|
||||||
@@ -16,12 +11,12 @@ export interface UseSessionsPaginationArgs {
|
|||||||
|
|
||||||
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
||||||
const [offset, setOffset] = useState(0);
|
const [offset, setOffset] = useState(0);
|
||||||
|
|
||||||
const [accumulatedSessions, setAccumulatedSessions] = useState<
|
const [accumulatedSessions, setAccumulatedSessions] = useState<
|
||||||
SessionSummaryResponse[]
|
SessionSummaryResponse[]
|
||||||
>([]);
|
>([]);
|
||||||
|
|
||||||
const [totalCount, setTotalCount] = useState<number | null>(null);
|
const [totalCount, setTotalCount] = useState<number | null>(null);
|
||||||
const queryClient = useQueryClient();
|
|
||||||
const onStreamComplete = useChatStore((state) => state.onStreamComplete);
|
|
||||||
|
|
||||||
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
|
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
|
||||||
{ limit: PAGE_SIZE, offset },
|
{ limit: PAGE_SIZE, offset },
|
||||||
@@ -32,20 +27,7 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
useEffect(function refreshOnStreamComplete() {
|
useEffect(() => {
|
||||||
const unsubscribe = onStreamComplete(function handleStreamComplete() {
|
|
||||||
setOffset(0);
|
|
||||||
setAccumulatedSessions([]);
|
|
||||||
setTotalCount(null);
|
|
||||||
queryClient.invalidateQueries({
|
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
|
||||||
});
|
|
||||||
});
|
|
||||||
return unsubscribe;
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function updateSessionsFromResponse() {
|
|
||||||
const responseData = okData(data);
|
const responseData = okData(data);
|
||||||
if (responseData) {
|
if (responseData) {
|
||||||
const newSessions = responseData.sessions;
|
const newSessions = responseData.sessions;
|
||||||
@@ -61,9 +43,7 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
|||||||
setAccumulatedSessions([]);
|
setAccumulatedSessions([]);
|
||||||
setTotalCount(null);
|
setTotalCount(null);
|
||||||
}
|
}
|
||||||
},
|
}, [data, offset, enabled]);
|
||||||
[data, offset, enabled],
|
|
||||||
);
|
|
||||||
|
|
||||||
const hasNextPage =
|
const hasNextPage =
|
||||||
totalCount !== null && accumulatedSessions.length < totalCount;
|
totalCount !== null && accumulatedSessions.length < totalCount;
|
||||||
@@ -86,17 +66,17 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
|||||||
}
|
}
|
||||||
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
|
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
|
||||||
|
|
||||||
function fetchNextPage() {
|
const fetchNextPage = () => {
|
||||||
if (hasNextPage && !isFetching) {
|
if (hasNextPage && !isFetching) {
|
||||||
setOffset((prev) => prev + PAGE_SIZE);
|
setOffset((prev) => prev + PAGE_SIZE);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
function reset() {
|
const reset = () => {
|
||||||
setOffset(0);
|
setOffset(0);
|
||||||
setAccumulatedSessions([]);
|
setAccumulatedSessions([]);
|
||||||
setTotalCount(null);
|
setTotalCount(null);
|
||||||
}
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
sessions: accumulatedSessions,
|
sessions: accumulatedSessions,
|
||||||
|
|||||||
@@ -104,76 +104,3 @@ export function mergeCurrentSessionIntoList(
|
|||||||
export function getCurrentSessionId(searchParams: URLSearchParams) {
|
export function getCurrentSessionId(searchParams: URLSearchParams) {
|
||||||
return searchParams.get("sessionId");
|
return searchParams.get("sessionId");
|
||||||
}
|
}
|
||||||
|
|
||||||
export function shouldAutoSelectSession(
|
|
||||||
areAllSessionsLoaded: boolean,
|
|
||||||
hasAutoSelectedSession: boolean,
|
|
||||||
paramSessionId: string | null,
|
|
||||||
visibleSessions: SessionSummaryResponse[],
|
|
||||||
accumulatedSessions: SessionSummaryResponse[],
|
|
||||||
isLoading: boolean,
|
|
||||||
totalCount: number | null,
|
|
||||||
) {
|
|
||||||
if (!areAllSessionsLoaded || hasAutoSelectedSession) {
|
|
||||||
return {
|
|
||||||
shouldSelect: false,
|
|
||||||
sessionIdToSelect: null,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (paramSessionId) {
|
|
||||||
return {
|
|
||||||
shouldSelect: false,
|
|
||||||
sessionIdToSelect: null,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (visibleSessions.length > 0) {
|
|
||||||
return {
|
|
||||||
shouldSelect: true,
|
|
||||||
sessionIdToSelect: visibleSessions[0].id,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
|
|
||||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (totalCount === 0) {
|
|
||||||
return {
|
|
||||||
shouldSelect: false,
|
|
||||||
sessionIdToSelect: null,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
|
|
||||||
}
|
|
||||||
|
|
||||||
export function checkReadyToShowContent(
|
|
||||||
areAllSessionsLoaded: boolean,
|
|
||||||
paramSessionId: string | null,
|
|
||||||
accumulatedSessions: SessionSummaryResponse[],
|
|
||||||
isCurrentSessionLoading: boolean,
|
|
||||||
currentSessionData: SessionDetailResponse | null | undefined,
|
|
||||||
hasAutoSelectedSession: boolean,
|
|
||||||
) {
|
|
||||||
if (!areAllSessionsLoaded) return false;
|
|
||||||
|
|
||||||
if (paramSessionId) {
|
|
||||||
const sessionFound = accumulatedSessions.some(
|
|
||||||
(s) => s.id === paramSessionId,
|
|
||||||
);
|
|
||||||
return (
|
|
||||||
sessionFound ||
|
|
||||||
(!isCurrentSessionLoading &&
|
|
||||||
currentSessionData !== undefined &&
|
|
||||||
currentSessionData !== null)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return hasAutoSelectedSession;
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,26 +1,22 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
|
getGetV2GetSessionQueryKey,
|
||||||
getGetV2ListSessionsQueryKey,
|
getGetV2ListSessionsQueryKey,
|
||||||
useGetV2GetSession,
|
useGetV2GetSession,
|
||||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { okData } from "@/app/api/helpers";
|
import { okData } from "@/app/api/helpers";
|
||||||
|
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
||||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { parseAsString, useQueryState } from "nuqs";
|
|
||||||
import { usePathname, useSearchParams } from "next/navigation";
|
import { usePathname, useSearchParams } from "next/navigation";
|
||||||
import { useEffect, useRef, useState } from "react";
|
import { useRef } from "react";
|
||||||
|
import { useCopilotStore } from "../../copilot-page-store";
|
||||||
|
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
||||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||||
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
import { getCurrentSessionId } from "./helpers";
|
||||||
import {
|
import { useShellSessionList } from "./useShellSessionList";
|
||||||
checkReadyToShowContent,
|
|
||||||
convertSessionDetailToSummary,
|
|
||||||
filterVisibleSessions,
|
|
||||||
getCurrentSessionId,
|
|
||||||
mergeCurrentSessionIntoList,
|
|
||||||
} from "./helpers";
|
|
||||||
|
|
||||||
export function useCopilotShell() {
|
export function useCopilotShell() {
|
||||||
const pathname = usePathname();
|
const pathname = usePathname();
|
||||||
@@ -31,7 +27,7 @@ export function useCopilotShell() {
|
|||||||
const isMobile =
|
const isMobile =
|
||||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||||
|
|
||||||
const [, setUrlSessionId] = useQueryState("sessionId", parseAsString);
|
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
|
||||||
|
|
||||||
const isOnHomepage = pathname === "/copilot";
|
const isOnHomepage = pathname === "/copilot";
|
||||||
const paramSessionId = searchParams.get("sessionId");
|
const paramSessionId = searchParams.get("sessionId");
|
||||||
@@ -45,123 +41,80 @@ export function useCopilotShell() {
|
|||||||
|
|
||||||
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
|
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
|
||||||
|
|
||||||
const {
|
|
||||||
sessions: accumulatedSessions,
|
|
||||||
isLoading: isSessionsLoading,
|
|
||||||
isFetching: isSessionsFetching,
|
|
||||||
hasNextPage,
|
|
||||||
areAllSessionsLoaded,
|
|
||||||
fetchNextPage,
|
|
||||||
reset: resetPagination,
|
|
||||||
} = useSessionsPagination({
|
|
||||||
enabled: paginationEnabled,
|
|
||||||
});
|
|
||||||
|
|
||||||
const currentSessionId = getCurrentSessionId(searchParams);
|
const currentSessionId = getCurrentSessionId(searchParams);
|
||||||
|
|
||||||
const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
|
const { data: currentSessionData } = useGetV2GetSession(
|
||||||
useGetV2GetSession(currentSessionId || "", {
|
currentSessionId || "",
|
||||||
|
{
|
||||||
query: {
|
query: {
|
||||||
enabled: !!currentSessionId,
|
enabled: !!currentSessionId,
|
||||||
select: okData,
|
select: okData,
|
||||||
},
|
},
|
||||||
});
|
},
|
||||||
|
|
||||||
const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
|
|
||||||
const hasAutoSelectedRef = useRef(false);
|
|
||||||
const recentlyCreatedSessionsRef = useRef<
|
|
||||||
Map<string, SessionSummaryResponse>
|
|
||||||
>(new Map());
|
|
||||||
|
|
||||||
// Mark as auto-selected when sessionId is in URL
|
|
||||||
useEffect(() => {
|
|
||||||
if (paramSessionId && !hasAutoSelectedRef.current) {
|
|
||||||
hasAutoSelectedRef.current = true;
|
|
||||||
setHasAutoSelectedSession(true);
|
|
||||||
}
|
|
||||||
}, [paramSessionId]);
|
|
||||||
|
|
||||||
// On homepage without sessionId, mark as ready immediately
|
|
||||||
useEffect(() => {
|
|
||||||
if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
|
|
||||||
hasAutoSelectedRef.current = true;
|
|
||||||
setHasAutoSelectedSession(true);
|
|
||||||
}
|
|
||||||
}, [isOnHomepage, paramSessionId]);
|
|
||||||
|
|
||||||
// Invalidate sessions list when navigating to homepage (to show newly created sessions)
|
|
||||||
useEffect(() => {
|
|
||||||
if (isOnHomepage && !paramSessionId) {
|
|
||||||
queryClient.invalidateQueries({
|
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [isOnHomepage, paramSessionId, queryClient]);
|
|
||||||
|
|
||||||
// Track newly created sessions to ensure they stay visible even when switching away
|
|
||||||
useEffect(() => {
|
|
||||||
if (currentSessionId && currentSessionData) {
|
|
||||||
const isNewSession =
|
|
||||||
currentSessionData.updated_at === currentSessionData.created_at;
|
|
||||||
const isNotInAccumulated = !accumulatedSessions.some(
|
|
||||||
(s) => s.id === currentSessionId,
|
|
||||||
);
|
);
|
||||||
if (isNewSession || isNotInAccumulated) {
|
|
||||||
const summary = convertSessionDetailToSummary(currentSessionData);
|
|
||||||
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, [currentSessionId, currentSessionData, accumulatedSessions]);
|
|
||||||
|
|
||||||
// Clean up recently created sessions that are now in the accumulated list
|
const {
|
||||||
useEffect(() => {
|
sessions,
|
||||||
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
|
isLoading,
|
||||||
if (accumulatedSessions.some((s) => s.id === sessionId)) {
|
isSessionsFetching,
|
||||||
recentlyCreatedSessionsRef.current.delete(sessionId);
|
hasNextPage,
|
||||||
}
|
fetchNextPage,
|
||||||
}
|
resetPagination,
|
||||||
}, [accumulatedSessions]);
|
recentlyCreatedSessionsRef,
|
||||||
|
} = useShellSessionList({
|
||||||
// Reset pagination when query becomes disabled
|
paginationEnabled,
|
||||||
const prevPaginationEnabledRef = useRef(paginationEnabled);
|
|
||||||
useEffect(() => {
|
|
||||||
if (prevPaginationEnabledRef.current && !paginationEnabled) {
|
|
||||||
resetPagination();
|
|
||||||
resetAutoSelect();
|
|
||||||
}
|
|
||||||
prevPaginationEnabledRef.current = paginationEnabled;
|
|
||||||
}, [paginationEnabled, resetPagination]);
|
|
||||||
|
|
||||||
const sessions = mergeCurrentSessionIntoList(
|
|
||||||
accumulatedSessions,
|
|
||||||
currentSessionId,
|
currentSessionId,
|
||||||
currentSessionData,
|
currentSessionData,
|
||||||
recentlyCreatedSessionsRef.current,
|
isOnHomepage,
|
||||||
);
|
|
||||||
|
|
||||||
const visibleSessions = filterVisibleSessions(sessions);
|
|
||||||
|
|
||||||
const sidebarSelectedSessionId =
|
|
||||||
isOnHomepage && !paramSessionId ? null : currentSessionId;
|
|
||||||
|
|
||||||
const isReadyToShowContent = isOnHomepage
|
|
||||||
? true
|
|
||||||
: checkReadyToShowContent(
|
|
||||||
areAllSessionsLoaded,
|
|
||||||
paramSessionId,
|
paramSessionId,
|
||||||
accumulatedSessions,
|
});
|
||||||
isCurrentSessionLoading,
|
|
||||||
currentSessionData,
|
|
||||||
hasAutoSelectedSession,
|
|
||||||
);
|
|
||||||
|
|
||||||
function handleSelectSession(sessionId: string) {
|
const stopStream = useChatStore((s) => s.stopStream);
|
||||||
|
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||||
|
const isStreaming = useCopilotStore((s) => s.isStreaming);
|
||||||
|
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
|
||||||
|
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
|
||||||
|
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
|
||||||
|
|
||||||
|
const pendingActionRef = useRef<(() => void) | null>(null);
|
||||||
|
|
||||||
|
async function stopCurrentStream() {
|
||||||
|
if (!currentSessionId) return;
|
||||||
|
|
||||||
|
setIsSwitchingSession(true);
|
||||||
|
await new Promise<void>((resolve) => {
|
||||||
|
const unsubscribe = onStreamComplete((completedId) => {
|
||||||
|
if (completedId === currentSessionId) {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
unsubscribe();
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const timeout = setTimeout(() => {
|
||||||
|
unsubscribe();
|
||||||
|
resolve();
|
||||||
|
}, 3000);
|
||||||
|
stopStream(currentSessionId);
|
||||||
|
});
|
||||||
|
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
|
||||||
|
});
|
||||||
|
setIsSwitchingSession(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
function selectSession(sessionId: string) {
|
||||||
|
if (sessionId === currentSessionId) return;
|
||||||
|
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||||
|
});
|
||||||
|
}
|
||||||
setUrlSessionId(sessionId, { shallow: false });
|
setUrlSessionId(sessionId, { shallow: false });
|
||||||
if (isMobile) handleCloseDrawer();
|
if (isMobile) handleCloseDrawer();
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleNewChat() {
|
function startNewChat() {
|
||||||
resetAutoSelect();
|
|
||||||
resetPagination();
|
resetPagination();
|
||||||
queryClient.invalidateQueries({
|
queryClient.invalidateQueries({
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
@@ -170,12 +123,31 @@ export function useCopilotShell() {
|
|||||||
if (isMobile) handleCloseDrawer();
|
if (isMobile) handleCloseDrawer();
|
||||||
}
|
}
|
||||||
|
|
||||||
function resetAutoSelect() {
|
function handleSessionClick(sessionId: string) {
|
||||||
hasAutoSelectedRef.current = false;
|
if (sessionId === currentSessionId) return;
|
||||||
setHasAutoSelectedSession(false);
|
|
||||||
|
if (isStreaming) {
|
||||||
|
pendingActionRef.current = async () => {
|
||||||
|
await stopCurrentStream();
|
||||||
|
selectSession(sessionId);
|
||||||
|
};
|
||||||
|
openInterruptModal(pendingActionRef.current);
|
||||||
|
} else {
|
||||||
|
selectSession(sessionId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
|
function handleNewChatClick() {
|
||||||
|
if (isStreaming) {
|
||||||
|
pendingActionRef.current = async () => {
|
||||||
|
await stopCurrentStream();
|
||||||
|
startNewChat();
|
||||||
|
};
|
||||||
|
openInterruptModal(pendingActionRef.current);
|
||||||
|
} else {
|
||||||
|
startNewChat();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isMobile,
|
isMobile,
|
||||||
@@ -183,17 +155,17 @@ export function useCopilotShell() {
|
|||||||
isLoggedIn,
|
isLoggedIn,
|
||||||
hasActiveSession:
|
hasActiveSession:
|
||||||
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
|
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
|
||||||
isLoading,
|
isLoading: isLoading || isCreatingSession,
|
||||||
sessions: visibleSessions,
|
isCreatingSession,
|
||||||
currentSessionId: sidebarSelectedSessionId,
|
sessions,
|
||||||
handleSelectSession,
|
currentSessionId: urlSessionId,
|
||||||
handleOpenDrawer,
|
handleOpenDrawer,
|
||||||
handleCloseDrawer,
|
handleCloseDrawer,
|
||||||
handleDrawerOpenChange,
|
handleDrawerOpenChange,
|
||||||
handleNewChat,
|
handleNewChatClick,
|
||||||
|
handleSessionClick,
|
||||||
hasNextPage,
|
hasNextPage,
|
||||||
isFetchingNextPage: isSessionsFetching,
|
isFetchingNextPage: isSessionsFetching,
|
||||||
fetchNextPage,
|
fetchNextPage,
|
||||||
isReadyToShowContent,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,113 @@
|
|||||||
|
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
|
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||||
|
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||||
|
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
||||||
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
|
import { useEffect, useMemo, useRef } from "react";
|
||||||
|
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
||||||
|
import {
|
||||||
|
convertSessionDetailToSummary,
|
||||||
|
filterVisibleSessions,
|
||||||
|
mergeCurrentSessionIntoList,
|
||||||
|
} from "./helpers";
|
||||||
|
|
||||||
|
interface UseShellSessionListArgs {
|
||||||
|
paginationEnabled: boolean;
|
||||||
|
currentSessionId: string | null;
|
||||||
|
currentSessionData: SessionDetailResponse | null | undefined;
|
||||||
|
isOnHomepage: boolean;
|
||||||
|
paramSessionId: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useShellSessionList({
|
||||||
|
paginationEnabled,
|
||||||
|
currentSessionId,
|
||||||
|
currentSessionData,
|
||||||
|
isOnHomepage,
|
||||||
|
paramSessionId,
|
||||||
|
}: UseShellSessionListArgs) {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||||
|
|
||||||
|
const {
|
||||||
|
sessions: accumulatedSessions,
|
||||||
|
isLoading: isSessionsLoading,
|
||||||
|
isFetching: isSessionsFetching,
|
||||||
|
hasNextPage,
|
||||||
|
fetchNextPage,
|
||||||
|
reset: resetPagination,
|
||||||
|
} = useSessionsPagination({
|
||||||
|
enabled: paginationEnabled,
|
||||||
|
});
|
||||||
|
|
||||||
|
const recentlyCreatedSessionsRef = useRef<
|
||||||
|
Map<string, SessionSummaryResponse>
|
||||||
|
>(new Map());
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isOnHomepage && !paramSessionId) {
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [isOnHomepage, paramSessionId, queryClient]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (currentSessionId && currentSessionData) {
|
||||||
|
const isNewSession =
|
||||||
|
currentSessionData.updated_at === currentSessionData.created_at;
|
||||||
|
const isNotInAccumulated = !accumulatedSessions.some(
|
||||||
|
(s) => s.id === currentSessionId,
|
||||||
|
);
|
||||||
|
if (isNewSession || isNotInAccumulated) {
|
||||||
|
const summary = convertSessionDetailToSummary(currentSessionData);
|
||||||
|
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [currentSessionId, currentSessionData, accumulatedSessions]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
|
||||||
|
if (accumulatedSessions.some((s) => s.id === sessionId)) {
|
||||||
|
recentlyCreatedSessionsRef.current.delete(sessionId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [accumulatedSessions]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const unsubscribe = onStreamComplete(() => {
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
|
});
|
||||||
|
});
|
||||||
|
return unsubscribe;
|
||||||
|
}, [onStreamComplete, queryClient]);
|
||||||
|
|
||||||
|
const sessions = useMemo(
|
||||||
|
() =>
|
||||||
|
mergeCurrentSessionIntoList(
|
||||||
|
accumulatedSessions,
|
||||||
|
currentSessionId,
|
||||||
|
currentSessionData,
|
||||||
|
recentlyCreatedSessionsRef.current,
|
||||||
|
),
|
||||||
|
[accumulatedSessions, currentSessionId, currentSessionData],
|
||||||
|
);
|
||||||
|
|
||||||
|
const visibleSessions = useMemo(
|
||||||
|
() => filterVisibleSessions(sessions),
|
||||||
|
[sessions],
|
||||||
|
);
|
||||||
|
|
||||||
|
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
|
||||||
|
|
||||||
|
return {
|
||||||
|
sessions: visibleSessions,
|
||||||
|
isLoading,
|
||||||
|
isSessionsFetching,
|
||||||
|
hasNextPage,
|
||||||
|
fetchNextPage,
|
||||||
|
resetPagination,
|
||||||
|
recentlyCreatedSessionsRef,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -4,51 +4,53 @@ import { create } from "zustand";
|
|||||||
|
|
||||||
interface CopilotStoreState {
|
interface CopilotStoreState {
|
||||||
isStreaming: boolean;
|
isStreaming: boolean;
|
||||||
isNewChatModalOpen: boolean;
|
isSwitchingSession: boolean;
|
||||||
newChatHandler: (() => void) | null;
|
isCreatingSession: boolean;
|
||||||
|
isInterruptModalOpen: boolean;
|
||||||
|
pendingAction: (() => void) | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface CopilotStoreActions {
|
interface CopilotStoreActions {
|
||||||
setIsStreaming: (isStreaming: boolean) => void;
|
setIsStreaming: (isStreaming: boolean) => void;
|
||||||
setNewChatHandler: (handler: (() => void) | null) => void;
|
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
|
||||||
requestNewChat: () => void;
|
setIsCreatingSession: (isCreating: boolean) => void;
|
||||||
confirmNewChat: () => void;
|
openInterruptModal: (onConfirm: () => void) => void;
|
||||||
cancelNewChat: () => void;
|
confirmInterrupt: () => void;
|
||||||
|
cancelInterrupt: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
type CopilotStore = CopilotStoreState & CopilotStoreActions;
|
type CopilotStore = CopilotStoreState & CopilotStoreActions;
|
||||||
|
|
||||||
export const useCopilotStore = create<CopilotStore>((set, get) => ({
|
export const useCopilotStore = create<CopilotStore>((set, get) => ({
|
||||||
isStreaming: false,
|
isStreaming: false,
|
||||||
isNewChatModalOpen: false,
|
isSwitchingSession: false,
|
||||||
newChatHandler: null,
|
isCreatingSession: false,
|
||||||
|
isInterruptModalOpen: false,
|
||||||
|
pendingAction: null,
|
||||||
|
|
||||||
setIsStreaming(isStreaming) {
|
setIsStreaming(isStreaming) {
|
||||||
set({ isStreaming });
|
set({ isStreaming });
|
||||||
},
|
},
|
||||||
|
|
||||||
setNewChatHandler(handler) {
|
setIsSwitchingSession(isSwitchingSession) {
|
||||||
set({ newChatHandler: handler });
|
set({ isSwitchingSession });
|
||||||
},
|
},
|
||||||
|
|
||||||
requestNewChat() {
|
setIsCreatingSession(isCreatingSession) {
|
||||||
const { isStreaming, newChatHandler } = get();
|
set({ isCreatingSession });
|
||||||
if (isStreaming) {
|
|
||||||
set({ isNewChatModalOpen: true });
|
|
||||||
} else if (newChatHandler) {
|
|
||||||
newChatHandler();
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
|
||||||
confirmNewChat() {
|
openInterruptModal(onConfirm) {
|
||||||
const { newChatHandler } = get();
|
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
|
||||||
set({ isNewChatModalOpen: false });
|
|
||||||
if (newChatHandler) {
|
|
||||||
newChatHandler();
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
|
||||||
cancelNewChat() {
|
confirmInterrupt() {
|
||||||
set({ isNewChatModalOpen: false });
|
const { pendingAction } = get();
|
||||||
|
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||||
|
if (pendingAction) pendingAction();
|
||||||
|
},
|
||||||
|
|
||||||
|
cancelInterrupt() {
|
||||||
|
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|||||||
@@ -1,28 +1,5 @@
|
|||||||
import type { User } from "@supabase/supabase-js";
|
import type { User } from "@supabase/supabase-js";
|
||||||
|
|
||||||
export type PageState =
|
|
||||||
| { type: "welcome" }
|
|
||||||
| { type: "newChat" }
|
|
||||||
| { type: "creating"; prompt: string }
|
|
||||||
| { type: "chat"; sessionId: string; initialPrompt?: string };
|
|
||||||
|
|
||||||
export function getInitialPromptFromState(
|
|
||||||
pageState: PageState,
|
|
||||||
storedInitialPrompt: string | undefined,
|
|
||||||
) {
|
|
||||||
if (storedInitialPrompt) return storedInitialPrompt;
|
|
||||||
if (pageState.type === "creating") return pageState.prompt;
|
|
||||||
if (pageState.type === "chat") return pageState.initialPrompt;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function shouldResetToWelcome(pageState: PageState) {
|
|
||||||
return (
|
|
||||||
pageState.type !== "newChat" &&
|
|
||||||
pageState.type !== "creating" &&
|
|
||||||
pageState.type !== "welcome"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getGreetingName(user?: User | null): string {
|
export function getGreetingName(user?: User | null): string {
|
||||||
if (!user) return "there";
|
if (!user) return "there";
|
||||||
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
||||||
|
|||||||
@@ -1,25 +1,25 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
|
|
||||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
|
||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||||
import { useCopilotStore } from "./copilot-page-store";
|
import { useCopilotStore } from "./copilot-page-store";
|
||||||
import { useCopilotPage } from "./useCopilotPage";
|
import { useCopilotPage } from "./useCopilotPage";
|
||||||
|
|
||||||
export default function CopilotPage() {
|
export default function CopilotPage() {
|
||||||
const { state, handlers } = useCopilotPage();
|
const { state, handlers } = useCopilotPage();
|
||||||
const confirmNewChat = useCopilotStore((s) => s.confirmNewChat);
|
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||||
|
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||||
|
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||||
const {
|
const {
|
||||||
greetingName,
|
greetingName,
|
||||||
quickActions,
|
quickActions,
|
||||||
isLoading,
|
isLoading,
|
||||||
pageState,
|
hasSession,
|
||||||
isNewChatModalOpen,
|
initialPrompt,
|
||||||
isReady,
|
isReady,
|
||||||
} = state;
|
} = state;
|
||||||
const {
|
const {
|
||||||
@@ -27,20 +27,16 @@ export default function CopilotPage() {
|
|||||||
startChatWithPrompt,
|
startChatWithPrompt,
|
||||||
handleSessionNotFound,
|
handleSessionNotFound,
|
||||||
handleStreamingChange,
|
handleStreamingChange,
|
||||||
handleCancelNewChat,
|
|
||||||
handleNewChatModalOpen,
|
|
||||||
} = handlers;
|
} = handlers;
|
||||||
|
|
||||||
if (!isReady) return null;
|
if (!isReady) return null;
|
||||||
|
|
||||||
if (pageState.type === "chat") {
|
if (hasSession) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full flex-col">
|
<div className="flex h-full flex-col">
|
||||||
<Chat
|
<Chat
|
||||||
key={pageState.sessionId ?? "welcome"}
|
|
||||||
className="flex-1"
|
className="flex-1"
|
||||||
urlSessionId={pageState.sessionId}
|
initialPrompt={initialPrompt}
|
||||||
initialPrompt={pageState.initialPrompt}
|
|
||||||
onSessionNotFound={handleSessionNotFound}
|
onSessionNotFound={handleSessionNotFound}
|
||||||
onStreamingChange={handleStreamingChange}
|
onStreamingChange={handleStreamingChange}
|
||||||
/>
|
/>
|
||||||
@@ -48,31 +44,33 @@ export default function CopilotPage() {
|
|||||||
title="Interrupt current chat?"
|
title="Interrupt current chat?"
|
||||||
styling={{ maxWidth: 300, width: "100%" }}
|
styling={{ maxWidth: 300, width: "100%" }}
|
||||||
controlled={{
|
controlled={{
|
||||||
isOpen: isNewChatModalOpen,
|
isOpen: isInterruptModalOpen,
|
||||||
set: handleNewChatModalOpen,
|
set: (open) => {
|
||||||
|
if (!open) cancelInterrupt();
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
onClose={handleCancelNewChat}
|
onClose={cancelInterrupt}
|
||||||
>
|
>
|
||||||
<Dialog.Content>
|
<Dialog.Content>
|
||||||
<div className="flex flex-col gap-4">
|
<div className="flex flex-col gap-4">
|
||||||
<Text variant="body">
|
<Text variant="body">
|
||||||
The current chat response will be interrupted. Are you sure you
|
The current chat response will be interrupted. Are you sure you
|
||||||
want to start a new chat?
|
want to continue?
|
||||||
</Text>
|
</Text>
|
||||||
<Dialog.Footer>
|
<Dialog.Footer>
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
variant="outline"
|
variant="outline"
|
||||||
onClick={handleCancelNewChat}
|
onClick={cancelInterrupt}
|
||||||
>
|
>
|
||||||
Cancel
|
Cancel
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
variant="primary"
|
variant="primary"
|
||||||
onClick={confirmNewChat}
|
onClick={confirmInterrupt}
|
||||||
>
|
>
|
||||||
Start new chat
|
Continue
|
||||||
</Button>
|
</Button>
|
||||||
</Dialog.Footer>
|
</Dialog.Footer>
|
||||||
</div>
|
</div>
|
||||||
@@ -82,19 +80,6 @@ export default function CopilotPage() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pageState.type === "newChat" || pageState.type === "creating") {
|
|
||||||
return (
|
|
||||||
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
|
|
||||||
<div className="flex flex-col items-center gap-4">
|
|
||||||
<ChatLoader />
|
|
||||||
<Text variant="body" className="text-zinc-500">
|
|
||||||
Loading your chats...
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
||||||
<div className="w-full text-center">
|
<div className="w-full text-center">
|
||||||
|
|||||||
@@ -5,79 +5,40 @@ import {
|
|||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
import { getHomepageRoute } from "@/lib/constants";
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
|
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
||||||
import {
|
import {
|
||||||
Flag,
|
Flag,
|
||||||
type FlagValues,
|
type FlagValues,
|
||||||
useGetFlag,
|
useGetFlag,
|
||||||
} from "@/services/feature-flags/use-get-flag";
|
} from "@/services/feature-flags/use-get-flag";
|
||||||
|
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
||||||
import * as Sentry from "@sentry/nextjs";
|
import * as Sentry from "@sentry/nextjs";
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { useFlags } from "launchdarkly-react-client-sdk";
|
import { useFlags } from "launchdarkly-react-client-sdk";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { useEffect, useReducer } from "react";
|
import { useEffect } from "react";
|
||||||
import { useCopilotStore } from "./copilot-page-store";
|
import { useCopilotStore } from "./copilot-page-store";
|
||||||
import { getGreetingName, getQuickActions, type PageState } from "./helpers";
|
import { getGreetingName, getQuickActions } from "./helpers";
|
||||||
import { useCopilotURLState } from "./useCopilotURLState";
|
import { useCopilotSessionId } from "./useCopilotSessionId";
|
||||||
|
|
||||||
type CopilotState = {
|
|
||||||
pageState: PageState;
|
|
||||||
initialPrompts: Record<string, string>;
|
|
||||||
previousSessionId: string | null;
|
|
||||||
};
|
|
||||||
|
|
||||||
type CopilotAction =
|
|
||||||
| { type: "setPageState"; pageState: PageState }
|
|
||||||
| { type: "setInitialPrompt"; sessionId: string; prompt: string }
|
|
||||||
| { type: "setPreviousSessionId"; sessionId: string | null };
|
|
||||||
|
|
||||||
function isSamePageState(next: PageState, current: PageState) {
|
|
||||||
if (next.type !== current.type) return false;
|
|
||||||
if (next.type === "creating" && current.type === "creating") {
|
|
||||||
return next.prompt === current.prompt;
|
|
||||||
}
|
|
||||||
if (next.type === "chat" && current.type === "chat") {
|
|
||||||
return (
|
|
||||||
next.sessionId === current.sessionId &&
|
|
||||||
next.initialPrompt === current.initialPrompt
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
function copilotReducer(
|
|
||||||
state: CopilotState,
|
|
||||||
action: CopilotAction,
|
|
||||||
): CopilotState {
|
|
||||||
if (action.type === "setPageState") {
|
|
||||||
if (isSamePageState(action.pageState, state.pageState)) return state;
|
|
||||||
return { ...state, pageState: action.pageState };
|
|
||||||
}
|
|
||||||
if (action.type === "setInitialPrompt") {
|
|
||||||
if (state.initialPrompts[action.sessionId] === action.prompt) return state;
|
|
||||||
return {
|
|
||||||
...state,
|
|
||||||
initialPrompts: {
|
|
||||||
...state.initialPrompts,
|
|
||||||
[action.sessionId]: action.prompt,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
if (action.type === "setPreviousSessionId") {
|
|
||||||
if (state.previousSessionId === action.sessionId) return state;
|
|
||||||
return { ...state, previousSessionId: action.sessionId };
|
|
||||||
}
|
|
||||||
return state;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useCopilotPage() {
|
export function useCopilotPage() {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
||||||
const { toast } = useToast();
|
const { toast } = useToast();
|
||||||
|
const { completeStep } = useOnboarding();
|
||||||
|
|
||||||
const isNewChatModalOpen = useCopilotStore((s) => s.isNewChatModalOpen);
|
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
|
||||||
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
|
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
|
||||||
const cancelNewChat = useCopilotStore((s) => s.cancelNewChat);
|
const isCreating = useCopilotStore((s) => s.isCreatingSession);
|
||||||
|
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
|
||||||
|
|
||||||
|
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
|
||||||
|
useEffect(() => {
|
||||||
|
if (isLoggedIn) {
|
||||||
|
completeStep("VISIT_COPILOT");
|
||||||
|
}
|
||||||
|
}, [completeStep, isLoggedIn]);
|
||||||
|
|
||||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
const flags = useFlags<FlagValues>();
|
const flags = useFlags<FlagValues>();
|
||||||
@@ -88,72 +49,27 @@ export function useCopilotPage() {
|
|||||||
const isFlagReady =
|
const isFlagReady =
|
||||||
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
||||||
|
|
||||||
const [state, dispatch] = useReducer(copilotReducer, {
|
|
||||||
pageState: { type: "welcome" },
|
|
||||||
initialPrompts: {},
|
|
||||||
previousSessionId: null,
|
|
||||||
});
|
|
||||||
|
|
||||||
const greetingName = getGreetingName(user);
|
const greetingName = getGreetingName(user);
|
||||||
const quickActions = getQuickActions();
|
const quickActions = getQuickActions();
|
||||||
|
|
||||||
function setPageState(pageState: PageState) {
|
const hasSession = Boolean(urlSessionId);
|
||||||
dispatch({ type: "setPageState", pageState });
|
const initialPrompt = urlSessionId
|
||||||
}
|
? getInitialPrompt(urlSessionId)
|
||||||
|
: undefined;
|
||||||
|
|
||||||
function setInitialPrompt(sessionId: string, prompt: string) {
|
useEffect(() => {
|
||||||
dispatch({ type: "setInitialPrompt", sessionId, prompt });
|
|
||||||
}
|
|
||||||
|
|
||||||
function setPreviousSessionId(sessionId: string | null) {
|
|
||||||
dispatch({ type: "setPreviousSessionId", sessionId });
|
|
||||||
}
|
|
||||||
|
|
||||||
const { setUrlSessionId } = useCopilotURLState({
|
|
||||||
pageState: state.pageState,
|
|
||||||
initialPrompts: state.initialPrompts,
|
|
||||||
previousSessionId: state.previousSessionId,
|
|
||||||
setPageState,
|
|
||||||
setInitialPrompt,
|
|
||||||
setPreviousSessionId,
|
|
||||||
});
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function transitionNewChatToWelcome() {
|
|
||||||
if (state.pageState.type === "newChat") {
|
|
||||||
function setWelcomeState() {
|
|
||||||
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
|
|
||||||
}
|
|
||||||
|
|
||||||
const timer = setTimeout(setWelcomeState, 300);
|
|
||||||
|
|
||||||
return function cleanup() {
|
|
||||||
clearTimeout(timer);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[state.pageState.type],
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function ensureAccess() {
|
|
||||||
if (!isFlagReady) return;
|
if (!isFlagReady) return;
|
||||||
if (isChatEnabled === false) {
|
if (isChatEnabled === false) {
|
||||||
router.replace(homepageRoute);
|
router.replace(homepageRoute);
|
||||||
}
|
}
|
||||||
},
|
}, [homepageRoute, isChatEnabled, isFlagReady, router]);
|
||||||
[homepageRoute, isChatEnabled, isFlagReady, router],
|
|
||||||
);
|
|
||||||
|
|
||||||
async function startChatWithPrompt(prompt: string) {
|
async function startChatWithPrompt(prompt: string) {
|
||||||
if (!prompt?.trim()) return;
|
if (!prompt?.trim()) return;
|
||||||
if (state.pageState.type === "creating") return;
|
if (isCreating) return;
|
||||||
|
|
||||||
const trimmedPrompt = prompt.trim();
|
const trimmedPrompt = prompt.trim();
|
||||||
dispatch({
|
setIsCreating(true);
|
||||||
type: "setPageState",
|
|
||||||
pageState: { type: "creating", prompt: trimmedPrompt },
|
|
||||||
});
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const sessionResponse = await postV2CreateSession({
|
const sessionResponse = await postV2CreateSession({
|
||||||
@@ -165,27 +81,19 @@ export function useCopilotPage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const sessionId = sessionResponse.data.id;
|
const sessionId = sessionResponse.data.id;
|
||||||
|
setInitialPrompt(sessionId, trimmedPrompt);
|
||||||
dispatch({
|
|
||||||
type: "setInitialPrompt",
|
|
||||||
sessionId,
|
|
||||||
prompt: trimmedPrompt,
|
|
||||||
});
|
|
||||||
|
|
||||||
await queryClient.invalidateQueries({
|
await queryClient.invalidateQueries({
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
});
|
});
|
||||||
|
|
||||||
await setUrlSessionId(sessionId, { shallow: false });
|
await setUrlSessionId(sessionId, { shallow: true });
|
||||||
dispatch({
|
|
||||||
type: "setPageState",
|
|
||||||
pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt },
|
|
||||||
});
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("[CopilotPage] Failed to start chat:", error);
|
console.error("[CopilotPage] Failed to start chat:", error);
|
||||||
toast({ title: "Failed to start chat", variant: "destructive" });
|
toast({ title: "Failed to start chat", variant: "destructive" });
|
||||||
Sentry.captureException(error);
|
Sentry.captureException(error);
|
||||||
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
|
} finally {
|
||||||
|
setIsCreating(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,21 +109,13 @@ export function useCopilotPage() {
|
|||||||
setIsStreaming(isStreamingValue);
|
setIsStreaming(isStreamingValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleCancelNewChat() {
|
|
||||||
cancelNewChat();
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleNewChatModalOpen(isOpen: boolean) {
|
|
||||||
if (!isOpen) cancelNewChat();
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
state: {
|
state: {
|
||||||
greetingName,
|
greetingName,
|
||||||
quickActions,
|
quickActions,
|
||||||
isLoading: isUserLoading,
|
isLoading: isUserLoading,
|
||||||
pageState: state.pageState,
|
hasSession,
|
||||||
isNewChatModalOpen,
|
initialPrompt,
|
||||||
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
|
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
|
||||||
},
|
},
|
||||||
handlers: {
|
handlers: {
|
||||||
@@ -223,8 +123,32 @@ export function useCopilotPage() {
|
|||||||
startChatWithPrompt,
|
startChatWithPrompt,
|
||||||
handleSessionNotFound,
|
handleSessionNotFound,
|
||||||
handleStreamingChange,
|
handleStreamingChange,
|
||||||
handleCancelNewChat,
|
|
||||||
handleNewChatModalOpen,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getInitialPrompt(sessionId: string): string | undefined {
|
||||||
|
try {
|
||||||
|
const prompts = JSON.parse(
|
||||||
|
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
|
||||||
|
);
|
||||||
|
return prompts[sessionId];
|
||||||
|
} catch {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function setInitialPrompt(sessionId: string, prompt: string): void {
|
||||||
|
try {
|
||||||
|
const prompts = JSON.parse(
|
||||||
|
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
|
||||||
|
);
|
||||||
|
prompts[sessionId] = prompt;
|
||||||
|
sessionStorage.set(
|
||||||
|
SessionKey.CHAT_INITIAL_PROMPTS,
|
||||||
|
JSON.stringify(prompts),
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
// Ignore storage errors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
import { parseAsString, useQueryState } from "nuqs";
|
||||||
|
|
||||||
|
export function useCopilotSessionId() {
|
||||||
|
const [urlSessionId, setUrlSessionId] = useQueryState(
|
||||||
|
"sessionId",
|
||||||
|
parseAsString,
|
||||||
|
);
|
||||||
|
|
||||||
|
return { urlSessionId, setUrlSessionId };
|
||||||
|
}
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
import { parseAsString, useQueryState } from "nuqs";
|
|
||||||
import { useLayoutEffect } from "react";
|
|
||||||
import {
|
|
||||||
getInitialPromptFromState,
|
|
||||||
type PageState,
|
|
||||||
shouldResetToWelcome,
|
|
||||||
} from "./helpers";
|
|
||||||
|
|
||||||
interface UseCopilotUrlStateArgs {
|
|
||||||
pageState: PageState;
|
|
||||||
initialPrompts: Record<string, string>;
|
|
||||||
previousSessionId: string | null;
|
|
||||||
setPageState: (pageState: PageState) => void;
|
|
||||||
setInitialPrompt: (sessionId: string, prompt: string) => void;
|
|
||||||
setPreviousSessionId: (sessionId: string | null) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useCopilotURLState({
|
|
||||||
pageState,
|
|
||||||
initialPrompts,
|
|
||||||
previousSessionId,
|
|
||||||
setPageState,
|
|
||||||
setInitialPrompt,
|
|
||||||
setPreviousSessionId,
|
|
||||||
}: UseCopilotUrlStateArgs) {
|
|
||||||
const [urlSessionId, setUrlSessionId] = useQueryState(
|
|
||||||
"sessionId",
|
|
||||||
parseAsString,
|
|
||||||
);
|
|
||||||
|
|
||||||
function syncSessionFromUrl() {
|
|
||||||
if (urlSessionId) {
|
|
||||||
if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
|
|
||||||
setPreviousSessionId(urlSessionId);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const storedInitialPrompt = initialPrompts[urlSessionId];
|
|
||||||
const currentInitialPrompt = getInitialPromptFromState(
|
|
||||||
pageState,
|
|
||||||
storedInitialPrompt,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (currentInitialPrompt) {
|
|
||||||
setInitialPrompt(urlSessionId, currentInitialPrompt);
|
|
||||||
}
|
|
||||||
|
|
||||||
setPageState({
|
|
||||||
type: "chat",
|
|
||||||
sessionId: urlSessionId,
|
|
||||||
initialPrompt: currentInitialPrompt,
|
|
||||||
});
|
|
||||||
setPreviousSessionId(urlSessionId);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const wasInChat = previousSessionId !== null && pageState.type === "chat";
|
|
||||||
setPreviousSessionId(null);
|
|
||||||
if (wasInChat) {
|
|
||||||
setPageState({ type: "newChat" });
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shouldResetToWelcome(pageState)) {
|
|
||||||
setPageState({ type: "welcome" });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
useLayoutEffect(syncSessionFromUrl, [
|
|
||||||
urlSessionId,
|
|
||||||
pageState.type,
|
|
||||||
previousSessionId,
|
|
||||||
initialPrompts,
|
|
||||||
]);
|
|
||||||
|
|
||||||
return {
|
|
||||||
urlSessionId,
|
|
||||||
setUrlSessionId,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
"use server";
|
"use server";
|
||||||
|
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import BackendAPI from "@/lib/autogpt-server-api";
|
import BackendAPI from "@/lib/autogpt-server-api";
|
||||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||||
import { loginFormSchema } from "@/types/auth";
|
import { loginFormSchema } from "@/types/auth";
|
||||||
import * as Sentry from "@sentry/nextjs";
|
import * as Sentry from "@sentry/nextjs";
|
||||||
import { shouldShowOnboarding } from "../../api/helpers";
|
import { getOnboardingStatus } from "../../api/helpers";
|
||||||
|
|
||||||
export async function login(email: string, password: string) {
|
export async function login(email: string, password: string) {
|
||||||
try {
|
try {
|
||||||
@@ -36,11 +37,15 @@ export async function login(email: string, password: string) {
|
|||||||
const api = new BackendAPI();
|
const api = new BackendAPI();
|
||||||
await api.createUser();
|
await api.createUser();
|
||||||
|
|
||||||
const onboarding = await shouldShowOnboarding();
|
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||||
|
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
|
||||||
|
const next = shouldShowOnboarding
|
||||||
|
? "/onboarding"
|
||||||
|
: getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
success: true,
|
success: true,
|
||||||
onboarding,
|
next,
|
||||||
};
|
};
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
Sentry.captureException(err);
|
Sentry.captureException(err);
|
||||||
|
|||||||
@@ -97,13 +97,8 @@ export function useLoginPage() {
|
|||||||
throw new Error(result.error || "Login failed");
|
throw new Error(result.error || "Login failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nextUrl) {
|
// Prefer URL's next parameter, then use backend-determined route
|
||||||
router.replace(nextUrl);
|
router.replace(nextUrl || result.next || homepageRoute);
|
||||||
} else if (result.onboarding) {
|
|
||||||
router.replace("/onboarding");
|
|
||||||
} else {
|
|
||||||
router.replace(homepageRoute);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
toast({
|
toast({
|
||||||
title:
|
title:
|
||||||
|
|||||||
@@ -5,14 +5,13 @@ import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
|||||||
import { signupFormSchema } from "@/types/auth";
|
import { signupFormSchema } from "@/types/auth";
|
||||||
import * as Sentry from "@sentry/nextjs";
|
import * as Sentry from "@sentry/nextjs";
|
||||||
import { isWaitlistError, logWaitlistError } from "../../api/auth/utils";
|
import { isWaitlistError, logWaitlistError } from "../../api/auth/utils";
|
||||||
import { shouldShowOnboarding } from "../../api/helpers";
|
import { getOnboardingStatus } from "../../api/helpers";
|
||||||
|
|
||||||
export async function signup(
|
export async function signup(
|
||||||
email: string,
|
email: string,
|
||||||
password: string,
|
password: string,
|
||||||
confirmPassword: string,
|
confirmPassword: string,
|
||||||
agreeToTerms: boolean,
|
agreeToTerms: boolean,
|
||||||
isChatEnabled: boolean,
|
|
||||||
) {
|
) {
|
||||||
try {
|
try {
|
||||||
const parsed = signupFormSchema.safeParse({
|
const parsed = signupFormSchema.safeParse({
|
||||||
@@ -59,8 +58,9 @@ export async function signup(
|
|||||||
await supabase.auth.setSession(data.session);
|
await supabase.auth.setSession(data.session);
|
||||||
}
|
}
|
||||||
|
|
||||||
const isOnboardingEnabled = await shouldShowOnboarding();
|
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||||
const next = isOnboardingEnabled
|
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
|
||||||
|
const next = shouldShowOnboarding
|
||||||
? "/onboarding"
|
? "/onboarding"
|
||||||
: getHomepageRoute(isChatEnabled);
|
: getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
|
|||||||
@@ -108,7 +108,6 @@ export function useSignupPage() {
|
|||||||
data.password,
|
data.password,
|
||||||
data.confirmPassword,
|
data.confirmPassword,
|
||||||
data.agreeToTerms,
|
data.agreeToTerms,
|
||||||
isChatEnabled === true,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
setIsLoading(false);
|
setIsLoading(false);
|
||||||
|
|||||||
@@ -175,9 +175,12 @@ export async function resolveResponse<
|
|||||||
return res.data;
|
return res.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function shouldShowOnboarding() {
|
export async function getOnboardingStatus() {
|
||||||
const isEnabled = await resolveResponse(getV1IsOnboardingEnabled());
|
const status = await resolveResponse(getV1IsOnboardingEnabled());
|
||||||
const onboarding = await resolveResponse(getV1OnboardingState());
|
const onboarding = await resolveResponse(getV1OnboardingState());
|
||||||
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
|
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
|
||||||
return isEnabled && !isCompleted;
|
return {
|
||||||
|
shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
|
||||||
|
isChatEnabled: status.is_chat_enabled,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3339,7 +3339,7 @@
|
|||||||
"get": {
|
"get": {
|
||||||
"tags": ["v2", "library", "private"],
|
"tags": ["v2", "library", "private"],
|
||||||
"summary": "List Library Agents",
|
"summary": "List Library Agents",
|
||||||
"description": "Get all agents in the user's library (both created and saved).\n\nArgs:\n user_id: ID of the authenticated user.\n search_term: Optional search term to filter agents by name/description.\n filter_by: List of filters to apply (favorites, created by user).\n sort_by: List of sorting criteria (created date, updated date).\n page: Page number to retrieve.\n page_size: Number of agents per page.\n\nReturns:\n A LibraryAgentResponse containing agents and pagination metadata.\n\nRaises:\n HTTPException: If a server/database error occurs.",
|
"description": "Get all agents in the user's library (both created and saved).",
|
||||||
"operationId": "getV2List library agents",
|
"operationId": "getV2List library agents",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"parameters": [
|
"parameters": [
|
||||||
@@ -3394,7 +3394,7 @@
|
|||||||
],
|
],
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "List of library agents",
|
"description": "Successful Response",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": {
|
||||||
@@ -3413,17 +3413,13 @@
|
|||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"500": {
|
|
||||||
"description": "Server error",
|
|
||||||
"content": { "application/json": {} }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"post": {
|
"post": {
|
||||||
"tags": ["v2", "library", "private"],
|
"tags": ["v2", "library", "private"],
|
||||||
"summary": "Add Marketplace Agent",
|
"summary": "Add Marketplace Agent",
|
||||||
"description": "Add an agent from the marketplace to the user's library.\n\nArgs:\n store_listing_version_id: ID of the store listing version to add.\n user_id: ID of the authenticated user.\n\nReturns:\n library_model.LibraryAgent: Agent added to the library\n\nRaises:\n HTTPException(404): If the listing version is not found.\n HTTPException(500): If a server/database error occurs.",
|
"description": "Add an agent from the marketplace to the user's library.",
|
||||||
"operationId": "postV2Add marketplace agent",
|
"operationId": "postV2Add marketplace agent",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"requestBody": {
|
"requestBody": {
|
||||||
@@ -3438,7 +3434,7 @@
|
|||||||
},
|
},
|
||||||
"responses": {
|
"responses": {
|
||||||
"201": {
|
"201": {
|
||||||
"description": "Agent added successfully",
|
"description": "Successful Response",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": { "$ref": "#/components/schemas/LibraryAgent" }
|
"schema": { "$ref": "#/components/schemas/LibraryAgent" }
|
||||||
@@ -3448,7 +3444,6 @@
|
|||||||
"401": {
|
"401": {
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||||
},
|
},
|
||||||
"404": { "description": "Store listing version not found" },
|
|
||||||
"422": {
|
"422": {
|
||||||
"description": "Validation Error",
|
"description": "Validation Error",
|
||||||
"content": {
|
"content": {
|
||||||
@@ -3456,8 +3451,7 @@
|
|||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"500": { "description": "Server error" }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -3511,7 +3505,7 @@
|
|||||||
"get": {
|
"get": {
|
||||||
"tags": ["v2", "library", "private"],
|
"tags": ["v2", "library", "private"],
|
||||||
"summary": "List Favorite Library Agents",
|
"summary": "List Favorite Library Agents",
|
||||||
"description": "Get all favorite agents in the user's library.\n\nArgs:\n user_id: ID of the authenticated user.\n page: Page number to retrieve.\n page_size: Number of agents per page.\n\nReturns:\n A LibraryAgentResponse containing favorite agents and pagination metadata.\n\nRaises:\n HTTPException: If a server/database error occurs.",
|
"description": "Get all favorite agents in the user's library.",
|
||||||
"operationId": "getV2List favorite library agents",
|
"operationId": "getV2List favorite library agents",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"parameters": [
|
"parameters": [
|
||||||
@@ -3563,10 +3557,6 @@
|
|||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"500": {
|
|
||||||
"description": "Server error",
|
|
||||||
"content": { "application/json": {} }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3588,7 +3578,7 @@
|
|||||||
],
|
],
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "Library agent found",
|
"description": "Successful Response",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": {
|
||||||
@@ -3604,7 +3594,6 @@
|
|||||||
"401": {
|
"401": {
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||||
},
|
},
|
||||||
"404": { "description": "Agent not found" },
|
|
||||||
"422": {
|
"422": {
|
||||||
"description": "Validation Error",
|
"description": "Validation Error",
|
||||||
"content": {
|
"content": {
|
||||||
@@ -3620,7 +3609,7 @@
|
|||||||
"delete": {
|
"delete": {
|
||||||
"tags": ["v2", "library", "private"],
|
"tags": ["v2", "library", "private"],
|
||||||
"summary": "Delete Library Agent",
|
"summary": "Delete Library Agent",
|
||||||
"description": "Soft-delete the specified library agent.\n\nArgs:\n library_agent_id: ID of the library agent to delete.\n user_id: ID of the authenticated user.\n\nReturns:\n 204 No Content if successful.\n\nRaises:\n HTTPException(404): If the agent does not exist.\n HTTPException(500): If a server/database error occurs.",
|
"description": "Soft-delete the specified library agent.",
|
||||||
"operationId": "deleteV2Delete library agent",
|
"operationId": "deleteV2Delete library agent",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"parameters": [
|
"parameters": [
|
||||||
@@ -3636,11 +3625,9 @@
|
|||||||
"description": "Successful Response",
|
"description": "Successful Response",
|
||||||
"content": { "application/json": { "schema": {} } }
|
"content": { "application/json": { "schema": {} } }
|
||||||
},
|
},
|
||||||
"204": { "description": "Agent deleted successfully" },
|
|
||||||
"401": {
|
"401": {
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||||
},
|
},
|
||||||
"404": { "description": "Agent not found" },
|
|
||||||
"422": {
|
"422": {
|
||||||
"description": "Validation Error",
|
"description": "Validation Error",
|
||||||
"content": {
|
"content": {
|
||||||
@@ -3648,8 +3635,7 @@
|
|||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"500": { "description": "Server error" }
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"get": {
|
"get": {
|
||||||
@@ -3690,7 +3676,7 @@
|
|||||||
"patch": {
|
"patch": {
|
||||||
"tags": ["v2", "library", "private"],
|
"tags": ["v2", "library", "private"],
|
||||||
"summary": "Update Library Agent",
|
"summary": "Update Library Agent",
|
||||||
"description": "Update the library agent with the given fields.\n\nArgs:\n library_agent_id: ID of the library agent to update.\n payload: Fields to update (auto_update_version, is_favorite, etc.).\n user_id: ID of the authenticated user.\n\nRaises:\n HTTPException(500): If a server/database error occurs.",
|
"description": "Update the library agent with the given fields.",
|
||||||
"operationId": "patchV2Update library agent",
|
"operationId": "patchV2Update library agent",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"parameters": [
|
"parameters": [
|
||||||
@@ -3713,7 +3699,7 @@
|
|||||||
},
|
},
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "Agent updated successfully",
|
"description": "Successful Response",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": { "$ref": "#/components/schemas/LibraryAgent" }
|
"schema": { "$ref": "#/components/schemas/LibraryAgent" }
|
||||||
@@ -3730,8 +3716,7 @@
|
|||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"500": { "description": "Server error" }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -4540,8 +4525,7 @@
|
|||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": {
|
||||||
"type": "boolean",
|
"$ref": "#/components/schemas/OnboardingStatusResponse"
|
||||||
"title": "Response Getv1Is Onboarding Enabled"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -4594,6 +4578,7 @@
|
|||||||
"AGENT_NEW_RUN",
|
"AGENT_NEW_RUN",
|
||||||
"AGENT_INPUT",
|
"AGENT_INPUT",
|
||||||
"CONGRATS",
|
"CONGRATS",
|
||||||
|
"VISIT_COPILOT",
|
||||||
"MARKETPLACE_VISIT",
|
"MARKETPLACE_VISIT",
|
||||||
"BUILDER_OPEN"
|
"BUILDER_OPEN"
|
||||||
],
|
],
|
||||||
@@ -8744,6 +8729,19 @@
|
|||||||
"title": "OAuthApplicationPublicInfo",
|
"title": "OAuthApplicationPublicInfo",
|
||||||
"description": "Public information about an OAuth application (for consent screen)"
|
"description": "Public information about an OAuth application (for consent screen)"
|
||||||
},
|
},
|
||||||
|
"OnboardingStatusResponse": {
|
||||||
|
"properties": {
|
||||||
|
"is_onboarding_enabled": {
|
||||||
|
"type": "boolean",
|
||||||
|
"title": "Is Onboarding Enabled"
|
||||||
|
},
|
||||||
|
"is_chat_enabled": { "type": "boolean", "title": "Is Chat Enabled" }
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"required": ["is_onboarding_enabled", "is_chat_enabled"],
|
||||||
|
"title": "OnboardingStatusResponse",
|
||||||
|
"description": "Response for onboarding status check."
|
||||||
|
},
|
||||||
"OnboardingStep": {
|
"OnboardingStep": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": [
|
"enum": [
|
||||||
@@ -8754,6 +8752,7 @@
|
|||||||
"AGENT_NEW_RUN",
|
"AGENT_NEW_RUN",
|
||||||
"AGENT_INPUT",
|
"AGENT_INPUT",
|
||||||
"CONGRATS",
|
"CONGRATS",
|
||||||
|
"VISIT_COPILOT",
|
||||||
"GET_RESULTS",
|
"GET_RESULTS",
|
||||||
"MARKETPLACE_VISIT",
|
"MARKETPLACE_VISIT",
|
||||||
"MARKETPLACE_ADD_AGENT",
|
"MARKETPLACE_ADD_AGENT",
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
|
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
|
||||||
|
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
|
||||||
|
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { useEffect, useRef } from "react";
|
import { useEffect, useRef } from "react";
|
||||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||||
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
|
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
|
||||||
import { ChatLoader } from "./components/ChatLoader/ChatLoader";
|
|
||||||
import { useChat } from "./useChat";
|
import { useChat } from "./useChat";
|
||||||
|
|
||||||
export interface ChatProps {
|
export interface ChatProps {
|
||||||
className?: string;
|
className?: string;
|
||||||
urlSessionId?: string | null;
|
|
||||||
initialPrompt?: string;
|
initialPrompt?: string;
|
||||||
onSessionNotFound?: () => void;
|
onSessionNotFound?: () => void;
|
||||||
onStreamingChange?: (isStreaming: boolean) => void;
|
onStreamingChange?: (isStreaming: boolean) => void;
|
||||||
@@ -18,12 +19,13 @@ export interface ChatProps {
|
|||||||
|
|
||||||
export function Chat({
|
export function Chat({
|
||||||
className,
|
className,
|
||||||
urlSessionId,
|
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
onSessionNotFound,
|
onSessionNotFound,
|
||||||
onStreamingChange,
|
onStreamingChange,
|
||||||
}: ChatProps) {
|
}: ChatProps) {
|
||||||
|
const { urlSessionId } = useCopilotSessionId();
|
||||||
const hasHandledNotFoundRef = useRef(false);
|
const hasHandledNotFoundRef = useRef(false);
|
||||||
|
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
|
||||||
const {
|
const {
|
||||||
messages,
|
messages,
|
||||||
isLoading,
|
isLoading,
|
||||||
@@ -33,49 +35,59 @@ export function Chat({
|
|||||||
sessionId,
|
sessionId,
|
||||||
createSession,
|
createSession,
|
||||||
showLoader,
|
showLoader,
|
||||||
|
startPollingForOperation,
|
||||||
} = useChat({ urlSessionId });
|
} = useChat({ urlSessionId });
|
||||||
|
|
||||||
useEffect(
|
useEffect(() => {
|
||||||
function handleMissingSession() {
|
|
||||||
if (!onSessionNotFound) return;
|
if (!onSessionNotFound) return;
|
||||||
if (!urlSessionId) return;
|
if (!urlSessionId) return;
|
||||||
if (!isSessionNotFound || isLoading || isCreating) return;
|
if (!isSessionNotFound || isLoading || isCreating) return;
|
||||||
if (hasHandledNotFoundRef.current) return;
|
if (hasHandledNotFoundRef.current) return;
|
||||||
hasHandledNotFoundRef.current = true;
|
hasHandledNotFoundRef.current = true;
|
||||||
onSessionNotFound();
|
onSessionNotFound();
|
||||||
},
|
}, [
|
||||||
[onSessionNotFound, urlSessionId, isSessionNotFound, isLoading, isCreating],
|
onSessionNotFound,
|
||||||
);
|
urlSessionId,
|
||||||
|
isSessionNotFound,
|
||||||
|
isLoading,
|
||||||
|
isCreating,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const shouldShowLoader =
|
||||||
|
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={cn("flex h-full flex-col", className)}>
|
<div className={cn("flex h-full flex-col", className)}>
|
||||||
{/* Main Content */}
|
{/* Main Content */}
|
||||||
<main className="flex min-h-0 w-full flex-1 flex-col overflow-hidden bg-[#f8f8f9]">
|
<main className="flex min-h-0 w-full flex-1 flex-col overflow-hidden bg-[#f8f8f9]">
|
||||||
{/* Loading State */}
|
{/* Loading State */}
|
||||||
{showLoader && (isLoading || isCreating) && (
|
{shouldShowLoader && (
|
||||||
<div className="flex flex-1 items-center justify-center">
|
<div className="flex flex-1 items-center justify-center">
|
||||||
<div className="flex flex-col items-center gap-4">
|
<div className="flex flex-col items-center gap-3">
|
||||||
<ChatLoader />
|
<LoadingSpinner size="large" className="text-neutral-400" />
|
||||||
<Text variant="body" className="text-zinc-500">
|
<Text variant="body" className="text-zinc-500">
|
||||||
Loading your chats...
|
{isSwitchingSession
|
||||||
|
? "Switching chat..."
|
||||||
|
: "Loading your chat..."}
|
||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Error State */}
|
{/* Error State */}
|
||||||
{error && !isLoading && (
|
{error && !isLoading && !isSwitchingSession && (
|
||||||
<ChatErrorState error={error} onRetry={createSession} />
|
<ChatErrorState error={error} onRetry={createSession} />
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Session Content */}
|
{/* Session Content */}
|
||||||
{sessionId && !isLoading && !error && (
|
{sessionId && !isLoading && !error && !isSwitchingSession && (
|
||||||
<ChatContainer
|
<ChatContainer
|
||||||
sessionId={sessionId}
|
sessionId={sessionId}
|
||||||
initialMessages={messages}
|
initialMessages={messages}
|
||||||
initialPrompt={initialPrompt}
|
initialPrompt={initialPrompt}
|
||||||
className="flex-1"
|
className="flex-1"
|
||||||
onStreamingChange={onStreamingChange}
|
onStreamingChange={onStreamingChange}
|
||||||
|
onOperationStarted={startPollingForOperation}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</main>
|
</main>
|
||||||
|
|||||||
@@ -58,39 +58,17 @@ function notifyStreamComplete(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function cleanupCompletedStreams(completedStreams: Map<string, StreamResult>) {
|
function cleanupExpiredStreams(
|
||||||
const now = Date.now();
|
|
||||||
for (const [sessionId, result] of completedStreams) {
|
|
||||||
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
|
|
||||||
completedStreams.delete(sessionId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function moveToCompleted(
|
|
||||||
activeStreams: Map<string, ActiveStream>,
|
|
||||||
completedStreams: Map<string, StreamResult>,
|
completedStreams: Map<string, StreamResult>,
|
||||||
streamCompleteCallbacks: Set<StreamCompleteCallback>,
|
): Map<string, StreamResult> {
|
||||||
sessionId: string,
|
const now = Date.now();
|
||||||
) {
|
const cleaned = new Map(completedStreams);
|
||||||
const stream = activeStreams.get(sessionId);
|
for (const [sessionId, result] of cleaned) {
|
||||||
if (!stream) return;
|
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
|
||||||
|
cleaned.delete(sessionId);
|
||||||
const result: StreamResult = {
|
|
||||||
sessionId,
|
|
||||||
status: stream.status,
|
|
||||||
chunks: stream.chunks,
|
|
||||||
completedAt: Date.now(),
|
|
||||||
error: stream.error,
|
|
||||||
};
|
|
||||||
|
|
||||||
completedStreams.set(sessionId, result);
|
|
||||||
activeStreams.delete(sessionId);
|
|
||||||
cleanupCompletedStreams(completedStreams);
|
|
||||||
|
|
||||||
if (stream.status === "completed" || stream.status === "error") {
|
|
||||||
notifyStreamComplete(streamCompleteCallbacks, sessionId);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return cleaned;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useChatStore = create<ChatStore>((set, get) => ({
|
export const useChatStore = create<ChatStore>((set, get) => ({
|
||||||
@@ -106,17 +84,31 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
context,
|
context,
|
||||||
onChunk,
|
onChunk,
|
||||||
) {
|
) {
|
||||||
const { activeStreams, completedStreams, streamCompleteCallbacks } = get();
|
const state = get();
|
||||||
|
const newActiveStreams = new Map(state.activeStreams);
|
||||||
|
let newCompletedStreams = new Map(state.completedStreams);
|
||||||
|
const callbacks = state.streamCompleteCallbacks;
|
||||||
|
|
||||||
const existingStream = activeStreams.get(sessionId);
|
const existingStream = newActiveStreams.get(sessionId);
|
||||||
if (existingStream) {
|
if (existingStream) {
|
||||||
existingStream.abortController.abort();
|
existingStream.abortController.abort();
|
||||||
moveToCompleted(
|
const normalizedStatus =
|
||||||
activeStreams,
|
existingStream.status === "streaming"
|
||||||
completedStreams,
|
? "completed"
|
||||||
streamCompleteCallbacks,
|
: existingStream.status;
|
||||||
|
const result: StreamResult = {
|
||||||
sessionId,
|
sessionId,
|
||||||
);
|
status: normalizedStatus,
|
||||||
|
chunks: existingStream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: existingStream.error,
|
||||||
|
};
|
||||||
|
newCompletedStreams.set(sessionId, result);
|
||||||
|
newActiveStreams.delete(sessionId);
|
||||||
|
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||||
|
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||||
|
notifyStreamComplete(callbacks, sessionId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
@@ -132,36 +124,76 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
onChunkCallbacks: initialCallbacks,
|
onChunkCallbacks: initialCallbacks,
|
||||||
};
|
};
|
||||||
|
|
||||||
activeStreams.set(sessionId, stream);
|
newActiveStreams.set(sessionId, stream);
|
||||||
|
set({
|
||||||
|
activeStreams: newActiveStreams,
|
||||||
|
completedStreams: newCompletedStreams,
|
||||||
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await executeStream(stream, message, isUserMessage, context);
|
await executeStream(stream, message, isUserMessage, context);
|
||||||
} finally {
|
} finally {
|
||||||
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||||
if (stream.status !== "streaming") {
|
if (stream.status !== "streaming") {
|
||||||
moveToCompleted(
|
const currentState = get();
|
||||||
activeStreams,
|
const finalActiveStreams = new Map(currentState.activeStreams);
|
||||||
completedStreams,
|
let finalCompletedStreams = new Map(currentState.completedStreams);
|
||||||
streamCompleteCallbacks,
|
|
||||||
|
const storedStream = finalActiveStreams.get(sessionId);
|
||||||
|
if (storedStream === stream) {
|
||||||
|
const result: StreamResult = {
|
||||||
|
sessionId,
|
||||||
|
status: stream.status,
|
||||||
|
chunks: stream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: stream.error,
|
||||||
|
};
|
||||||
|
finalCompletedStreams.set(sessionId, result);
|
||||||
|
finalActiveStreams.delete(sessionId);
|
||||||
|
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
||||||
|
set({
|
||||||
|
activeStreams: finalActiveStreams,
|
||||||
|
completedStreams: finalCompletedStreams,
|
||||||
|
});
|
||||||
|
if (stream.status === "completed" || stream.status === "error") {
|
||||||
|
notifyStreamComplete(
|
||||||
|
currentState.streamCompleteCallbacks,
|
||||||
sessionId,
|
sessionId,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
stopStream: function stopStream(sessionId) {
|
stopStream: function stopStream(sessionId) {
|
||||||
const { activeStreams, completedStreams, streamCompleteCallbacks } = get();
|
const state = get();
|
||||||
const stream = activeStreams.get(sessionId);
|
const stream = state.activeStreams.get(sessionId);
|
||||||
if (stream) {
|
if (!stream) return;
|
||||||
|
|
||||||
stream.abortController.abort();
|
stream.abortController.abort();
|
||||||
stream.status = "completed";
|
stream.status = "completed";
|
||||||
moveToCompleted(
|
|
||||||
activeStreams,
|
const newActiveStreams = new Map(state.activeStreams);
|
||||||
completedStreams,
|
let newCompletedStreams = new Map(state.completedStreams);
|
||||||
streamCompleteCallbacks,
|
|
||||||
|
const result: StreamResult = {
|
||||||
sessionId,
|
sessionId,
|
||||||
);
|
status: stream.status,
|
||||||
}
|
chunks: stream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: stream.error,
|
||||||
|
};
|
||||||
|
newCompletedStreams.set(sessionId, result);
|
||||||
|
newActiveStreams.delete(sessionId);
|
||||||
|
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||||
|
|
||||||
|
set({
|
||||||
|
activeStreams: newActiveStreams,
|
||||||
|
completedStreams: newCompletedStreams,
|
||||||
|
});
|
||||||
|
|
||||||
|
notifyStreamComplete(state.streamCompleteCallbacks, sessionId);
|
||||||
},
|
},
|
||||||
|
|
||||||
subscribeToStream: function subscribeToStream(
|
subscribeToStream: function subscribeToStream(
|
||||||
@@ -169,16 +201,18 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
onChunk,
|
onChunk,
|
||||||
skipReplay = false,
|
skipReplay = false,
|
||||||
) {
|
) {
|
||||||
const { activeStreams } = get();
|
const state = get();
|
||||||
|
const stream = state.activeStreams.get(sessionId);
|
||||||
|
|
||||||
const stream = activeStreams.get(sessionId);
|
|
||||||
if (stream) {
|
if (stream) {
|
||||||
if (!skipReplay) {
|
if (!skipReplay) {
|
||||||
for (const chunk of stream.chunks) {
|
for (const chunk of stream.chunks) {
|
||||||
onChunk(chunk);
|
onChunk(chunk);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stream.onChunkCallbacks.add(onChunk);
|
stream.onChunkCallbacks.add(onChunk);
|
||||||
|
|
||||||
return function unsubscribe() {
|
return function unsubscribe() {
|
||||||
stream.onChunkCallbacks.delete(onChunk);
|
stream.onChunkCallbacks.delete(onChunk);
|
||||||
};
|
};
|
||||||
@@ -204,7 +238,12 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
},
|
},
|
||||||
|
|
||||||
clearCompletedStream: function clearCompletedStream(sessionId) {
|
clearCompletedStream: function clearCompletedStream(sessionId) {
|
||||||
get().completedStreams.delete(sessionId);
|
const state = get();
|
||||||
|
if (!state.completedStreams.has(sessionId)) return;
|
||||||
|
|
||||||
|
const newCompletedStreams = new Map(state.completedStreams);
|
||||||
|
newCompletedStreams.delete(sessionId);
|
||||||
|
set({ completedStreams: newCompletedStreams });
|
||||||
},
|
},
|
||||||
|
|
||||||
isStreaming: function isStreaming(sessionId) {
|
isStreaming: function isStreaming(sessionId) {
|
||||||
@@ -213,11 +252,21 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
},
|
},
|
||||||
|
|
||||||
registerActiveSession: function registerActiveSession(sessionId) {
|
registerActiveSession: function registerActiveSession(sessionId) {
|
||||||
get().activeSessions.add(sessionId);
|
const state = get();
|
||||||
|
if (state.activeSessions.has(sessionId)) return;
|
||||||
|
|
||||||
|
const newActiveSessions = new Set(state.activeSessions);
|
||||||
|
newActiveSessions.add(sessionId);
|
||||||
|
set({ activeSessions: newActiveSessions });
|
||||||
},
|
},
|
||||||
|
|
||||||
unregisterActiveSession: function unregisterActiveSession(sessionId) {
|
unregisterActiveSession: function unregisterActiveSession(sessionId) {
|
||||||
get().activeSessions.delete(sessionId);
|
const state = get();
|
||||||
|
if (!state.activeSessions.has(sessionId)) return;
|
||||||
|
|
||||||
|
const newActiveSessions = new Set(state.activeSessions);
|
||||||
|
newActiveSessions.delete(sessionId);
|
||||||
|
set({ activeSessions: newActiveSessions });
|
||||||
},
|
},
|
||||||
|
|
||||||
isSessionActive: function isSessionActive(sessionId) {
|
isSessionActive: function isSessionActive(sessionId) {
|
||||||
@@ -225,10 +274,16 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
},
|
},
|
||||||
|
|
||||||
onStreamComplete: function onStreamComplete(callback) {
|
onStreamComplete: function onStreamComplete(callback) {
|
||||||
const { streamCompleteCallbacks } = get();
|
const state = get();
|
||||||
streamCompleteCallbacks.add(callback);
|
const newCallbacks = new Set(state.streamCompleteCallbacks);
|
||||||
|
newCallbacks.add(callback);
|
||||||
|
set({ streamCompleteCallbacks: newCallbacks });
|
||||||
|
|
||||||
return function unsubscribe() {
|
return function unsubscribe() {
|
||||||
streamCompleteCallbacks.delete(callback);
|
const currentState = get();
|
||||||
|
const cleanedCallbacks = new Set(currentState.streamCompleteCallbacks);
|
||||||
|
cleanedCallbacks.delete(callback);
|
||||||
|
set({ streamCompleteCallbacks: cleanedCallbacks });
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ export interface ChatContainerProps {
|
|||||||
initialPrompt?: string;
|
initialPrompt?: string;
|
||||||
className?: string;
|
className?: string;
|
||||||
onStreamingChange?: (isStreaming: boolean) => void;
|
onStreamingChange?: (isStreaming: boolean) => void;
|
||||||
|
onOperationStarted?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatContainer({
|
export function ChatContainer({
|
||||||
@@ -24,6 +25,7 @@ export function ChatContainer({
|
|||||||
initialPrompt,
|
initialPrompt,
|
||||||
className,
|
className,
|
||||||
onStreamingChange,
|
onStreamingChange,
|
||||||
|
onOperationStarted,
|
||||||
}: ChatContainerProps) {
|
}: ChatContainerProps) {
|
||||||
const {
|
const {
|
||||||
messages,
|
messages,
|
||||||
@@ -38,6 +40,7 @@ export function ChatContainer({
|
|||||||
sessionId,
|
sessionId,
|
||||||
initialMessages,
|
initialMessages,
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
|
onOperationStarted,
|
||||||
});
|
});
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ export interface HandlerDependencies {
|
|||||||
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
||||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
||||||
sessionId: string;
|
sessionId: string;
|
||||||
|
onOperationStarted?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
||||||
@@ -48,6 +49,15 @@ export function handleTextEnded(
|
|||||||
const completedText = deps.streamingChunksRef.current.join("");
|
const completedText = deps.streamingChunksRef.current.join("");
|
||||||
if (completedText.trim()) {
|
if (completedText.trim()) {
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
|
// Check if this exact message already exists to prevent duplicates
|
||||||
|
const exists = prev.some(
|
||||||
|
(msg) =>
|
||||||
|
msg.type === "message" &&
|
||||||
|
msg.role === "assistant" &&
|
||||||
|
msg.content === completedText,
|
||||||
|
);
|
||||||
|
if (exists) return prev;
|
||||||
|
|
||||||
const assistantMessage: ChatMessageData = {
|
const assistantMessage: ChatMessageData = {
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
@@ -154,6 +164,11 @@ export function handleToolResponse(
|
|||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// Trigger polling when operation_started is received
|
||||||
|
if (responseMessage.type === "operation_started") {
|
||||||
|
deps.onOperationStarted?.();
|
||||||
|
}
|
||||||
|
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
const toolCallIndex = prev.findIndex(
|
const toolCallIndex = prev.findIndex(
|
||||||
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
||||||
@@ -203,13 +218,24 @@ export function handleStreamEnd(
|
|||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
if (completedContent.trim()) {
|
if (completedContent.trim()) {
|
||||||
|
deps.setMessages((prev) => {
|
||||||
|
// Check if this exact message already exists to prevent duplicates
|
||||||
|
const exists = prev.some(
|
||||||
|
(msg) =>
|
||||||
|
msg.type === "message" &&
|
||||||
|
msg.role === "assistant" &&
|
||||||
|
msg.content === completedContent,
|
||||||
|
);
|
||||||
|
if (exists) return prev;
|
||||||
|
|
||||||
const assistantMessage: ChatMessageData = {
|
const assistantMessage: ChatMessageData = {
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: completedContent,
|
content: completedContent,
|
||||||
timestamp: new Date(),
|
timestamp: new Date(),
|
||||||
};
|
};
|
||||||
deps.setMessages((prev) => [...prev, assistantMessage]);
|
return [...prev, assistantMessage];
|
||||||
|
});
|
||||||
}
|
}
|
||||||
deps.setStreamingChunks([]);
|
deps.setStreamingChunks([]);
|
||||||
deps.streamingChunksRef.current = [];
|
deps.streamingChunksRef.current = [];
|
||||||
|
|||||||
@@ -304,6 +304,7 @@ export function parseToolResponse(
|
|||||||
if (isAgentArray(agentsData)) {
|
if (isAgentArray(agentsData)) {
|
||||||
return {
|
return {
|
||||||
type: "agent_carousel",
|
type: "agent_carousel",
|
||||||
|
toolId,
|
||||||
toolName: "agent_carousel",
|
toolName: "agent_carousel",
|
||||||
agents: agentsData,
|
agents: agentsData,
|
||||||
totalCount: parsedResult.total_count as number | undefined,
|
totalCount: parsedResult.total_count as number | undefined,
|
||||||
@@ -316,6 +317,7 @@ export function parseToolResponse(
|
|||||||
if (responseType === "execution_started") {
|
if (responseType === "execution_started") {
|
||||||
return {
|
return {
|
||||||
type: "execution_started",
|
type: "execution_started",
|
||||||
|
toolId,
|
||||||
toolName: "execution_started",
|
toolName: "execution_started",
|
||||||
executionId: (parsedResult.execution_id as string) || "",
|
executionId: (parsedResult.execution_id as string) || "",
|
||||||
agentName: (parsedResult.graph_name as string) || undefined,
|
agentName: (parsedResult.graph_name as string) || undefined,
|
||||||
@@ -341,6 +343,41 @@ export function parseToolResponse(
|
|||||||
timestamp: timestamp || new Date(),
|
timestamp: timestamp || new Date(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
if (responseType === "operation_started") {
|
||||||
|
return {
|
||||||
|
type: "operation_started",
|
||||||
|
toolName: (parsedResult.tool_name as string) || toolName,
|
||||||
|
toolId,
|
||||||
|
operationId: (parsedResult.operation_id as string) || "",
|
||||||
|
message:
|
||||||
|
(parsedResult.message as string) ||
|
||||||
|
"Operation started. You can close this tab.",
|
||||||
|
timestamp: timestamp || new Date(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if (responseType === "operation_pending") {
|
||||||
|
return {
|
||||||
|
type: "operation_pending",
|
||||||
|
toolName: (parsedResult.tool_name as string) || toolName,
|
||||||
|
toolId,
|
||||||
|
operationId: (parsedResult.operation_id as string) || "",
|
||||||
|
message:
|
||||||
|
(parsedResult.message as string) ||
|
||||||
|
"Operation in progress. Please wait...",
|
||||||
|
timestamp: timestamp || new Date(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if (responseType === "operation_in_progress") {
|
||||||
|
return {
|
||||||
|
type: "operation_in_progress",
|
||||||
|
toolName: (parsedResult.tool_name as string) || toolName,
|
||||||
|
toolCallId: (parsedResult.tool_call_id as string) || toolId,
|
||||||
|
message:
|
||||||
|
(parsedResult.message as string) ||
|
||||||
|
"Operation already in progress. Please wait...",
|
||||||
|
timestamp: timestamp || new Date(),
|
||||||
|
};
|
||||||
|
}
|
||||||
if (responseType === "need_login") {
|
if (responseType === "need_login") {
|
||||||
return {
|
return {
|
||||||
type: "login_needed",
|
type: "login_needed",
|
||||||
|
|||||||
@@ -14,16 +14,40 @@ import {
|
|||||||
processInitialMessages,
|
processInitialMessages,
|
||||||
} from "./helpers";
|
} from "./helpers";
|
||||||
|
|
||||||
|
// Helper to generate deduplication key for a message
|
||||||
|
function getMessageKey(msg: ChatMessageData): string {
|
||||||
|
if (msg.type === "message") {
|
||||||
|
// Don't include timestamp - dedupe by role + content only
|
||||||
|
// This handles the case where local and server timestamps differ
|
||||||
|
// Server messages are authoritative, so duplicates from local state are filtered
|
||||||
|
return `msg:${msg.role}:${msg.content}`;
|
||||||
|
} else if (msg.type === "tool_call") {
|
||||||
|
return `toolcall:${msg.toolId}`;
|
||||||
|
} else if (msg.type === "tool_response") {
|
||||||
|
return `toolresponse:${(msg as any).toolId}`;
|
||||||
|
} else if (
|
||||||
|
msg.type === "operation_started" ||
|
||||||
|
msg.type === "operation_pending" ||
|
||||||
|
msg.type === "operation_in_progress"
|
||||||
|
) {
|
||||||
|
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
|
||||||
|
} else {
|
||||||
|
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
interface Args {
|
interface Args {
|
||||||
sessionId: string | null;
|
sessionId: string | null;
|
||||||
initialMessages: SessionDetailResponse["messages"];
|
initialMessages: SessionDetailResponse["messages"];
|
||||||
initialPrompt?: string;
|
initialPrompt?: string;
|
||||||
|
onOperationStarted?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useChatContainer({
|
export function useChatContainer({
|
||||||
sessionId,
|
sessionId,
|
||||||
initialMessages,
|
initialMessages,
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
|
onOperationStarted,
|
||||||
}: Args) {
|
}: Args) {
|
||||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||||
@@ -73,20 +97,102 @@ export function useChatContainer({
|
|||||||
setIsRegionBlockedModalOpen,
|
setIsRegionBlockedModalOpen,
|
||||||
sessionId,
|
sessionId,
|
||||||
setIsStreamingInitiated,
|
setIsStreamingInitiated,
|
||||||
|
onOperationStarted,
|
||||||
});
|
});
|
||||||
|
|
||||||
setIsStreamingInitiated(true);
|
setIsStreamingInitiated(true);
|
||||||
const skipReplay = initialMessages.length > 0;
|
const skipReplay = initialMessages.length > 0;
|
||||||
return subscribeToStream(sessionId, dispatcher, skipReplay);
|
return subscribeToStream(sessionId, dispatcher, skipReplay);
|
||||||
},
|
},
|
||||||
[sessionId, stopStreaming, activeStreams, subscribeToStream],
|
[
|
||||||
|
sessionId,
|
||||||
|
stopStreaming,
|
||||||
|
activeStreams,
|
||||||
|
subscribeToStream,
|
||||||
|
onOperationStarted,
|
||||||
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
const allMessages = useMemo(
|
// Collect toolIds from completed tool results in initialMessages
|
||||||
() => [...processInitialMessages(initialMessages), ...messages],
|
// Used to filter out operation messages when their results arrive
|
||||||
[initialMessages, messages],
|
const completedToolIds = useMemo(() => {
|
||||||
|
const processedInitial = processInitialMessages(initialMessages);
|
||||||
|
const ids = new Set<string>();
|
||||||
|
for (const msg of processedInitial) {
|
||||||
|
if (
|
||||||
|
msg.type === "tool_response" ||
|
||||||
|
msg.type === "agent_carousel" ||
|
||||||
|
msg.type === "execution_started"
|
||||||
|
) {
|
||||||
|
const toolId = (msg as any).toolId;
|
||||||
|
if (toolId) {
|
||||||
|
ids.add(toolId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids;
|
||||||
|
}, [initialMessages]);
|
||||||
|
|
||||||
|
// Clean up local operation messages when their completed results arrive from polling
|
||||||
|
// This effect runs when completedToolIds changes (i.e., when polling brings new results)
|
||||||
|
useEffect(
|
||||||
|
function cleanupCompletedOperations() {
|
||||||
|
if (completedToolIds.size === 0) return;
|
||||||
|
|
||||||
|
setMessages((prev) => {
|
||||||
|
const filtered = prev.filter((msg) => {
|
||||||
|
if (
|
||||||
|
msg.type === "operation_started" ||
|
||||||
|
msg.type === "operation_pending" ||
|
||||||
|
msg.type === "operation_in_progress"
|
||||||
|
) {
|
||||||
|
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||||
|
if (toolId && completedToolIds.has(toolId)) {
|
||||||
|
return false; // Remove - operation completed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
// Only update state if something was actually filtered
|
||||||
|
return filtered.length === prev.length ? prev : filtered;
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[completedToolIds],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Combine initial messages from backend with local streaming messages,
|
||||||
|
// Server messages maintain correct order; only append truly new local messages
|
||||||
|
const allMessages = useMemo(() => {
|
||||||
|
const processedInitial = processInitialMessages(initialMessages);
|
||||||
|
|
||||||
|
// Build a set of keys from server messages for deduplication
|
||||||
|
const serverKeys = new Set<string>();
|
||||||
|
for (const msg of processedInitial) {
|
||||||
|
serverKeys.add(getMessageKey(msg));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter local messages: remove duplicates and completed operation messages
|
||||||
|
const newLocalMessages = messages.filter((msg) => {
|
||||||
|
// Remove operation messages for completed tools
|
||||||
|
if (
|
||||||
|
msg.type === "operation_started" ||
|
||||||
|
msg.type === "operation_pending" ||
|
||||||
|
msg.type === "operation_in_progress"
|
||||||
|
) {
|
||||||
|
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||||
|
if (toolId && completedToolIds.has(toolId)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Remove messages that already exist in server data
|
||||||
|
const key = getMessageKey(msg);
|
||||||
|
return !serverKeys.has(key);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Server messages first (correct order), then new local messages
|
||||||
|
return [...processedInitial, ...newLocalMessages];
|
||||||
|
}, [initialMessages, messages, completedToolIds]);
|
||||||
|
|
||||||
async function sendMessage(
|
async function sendMessage(
|
||||||
content: string,
|
content: string,
|
||||||
isUserMessage: boolean = true,
|
isUserMessage: boolean = true,
|
||||||
@@ -118,6 +224,7 @@ export function useChatContainer({
|
|||||||
setIsRegionBlockedModalOpen,
|
setIsRegionBlockedModalOpen,
|
||||||
sessionId,
|
sessionId,
|
||||||
setIsStreamingInitiated,
|
setIsStreamingInitiated,
|
||||||
|
onOperationStarted,
|
||||||
});
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
|
|||||||
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
|
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
|
||||||
import { ClarificationQuestionsWidget } from "../ClarificationQuestionsWidget/ClarificationQuestionsWidget";
|
import { ClarificationQuestionsWidget } from "../ClarificationQuestionsWidget/ClarificationQuestionsWidget";
|
||||||
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
|
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
|
||||||
|
import { PendingOperationWidget } from "../PendingOperationWidget/PendingOperationWidget";
|
||||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||||
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
|
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
|
||||||
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
|
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
|
||||||
@@ -71,6 +72,9 @@ export function ChatMessage({
|
|||||||
isLoginNeeded,
|
isLoginNeeded,
|
||||||
isCredentialsNeeded,
|
isCredentialsNeeded,
|
||||||
isClarificationNeeded,
|
isClarificationNeeded,
|
||||||
|
isOperationStarted,
|
||||||
|
isOperationPending,
|
||||||
|
isOperationInProgress,
|
||||||
} = useChatMessage(message);
|
} = useChatMessage(message);
|
||||||
const displayContent = getDisplayContent(message, isUser);
|
const displayContent = getDisplayContent(message, isUser);
|
||||||
|
|
||||||
@@ -126,10 +130,6 @@ export function ChatMessage({
|
|||||||
[displayContent, message],
|
[displayContent, message],
|
||||||
);
|
);
|
||||||
|
|
||||||
function isLongResponse(content: string): boolean {
|
|
||||||
return content.split("\n").length > 5;
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleTryAgain = useCallback(() => {
|
const handleTryAgain = useCallback(() => {
|
||||||
if (message.type !== "message" || !onSendMessage) return;
|
if (message.type !== "message" || !onSendMessage) return;
|
||||||
onSendMessage(message.content, message.role === "user");
|
onSendMessage(message.content, message.role === "user");
|
||||||
@@ -294,6 +294,42 @@ export function ChatMessage({
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Render operation_started messages (long-running background operations)
|
||||||
|
if (isOperationStarted && message.type === "operation_started") {
|
||||||
|
return (
|
||||||
|
<PendingOperationWidget
|
||||||
|
status="started"
|
||||||
|
message={message.message}
|
||||||
|
toolName={message.toolName}
|
||||||
|
className={className}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render operation_pending messages (operations in progress when refreshing)
|
||||||
|
if (isOperationPending && message.type === "operation_pending") {
|
||||||
|
return (
|
||||||
|
<PendingOperationWidget
|
||||||
|
status="pending"
|
||||||
|
message={message.message}
|
||||||
|
toolName={message.toolName}
|
||||||
|
className={className}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render operation_in_progress messages (duplicate request while operation running)
|
||||||
|
if (isOperationInProgress && message.type === "operation_in_progress") {
|
||||||
|
return (
|
||||||
|
<PendingOperationWidget
|
||||||
|
status="in_progress"
|
||||||
|
message={message.message}
|
||||||
|
toolName={message.toolName}
|
||||||
|
className={className}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
|
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
|
||||||
if (isToolResponse && message.type === "tool_response") {
|
if (isToolResponse && message.type === "tool_response") {
|
||||||
return (
|
return (
|
||||||
@@ -358,7 +394,7 @@ export function ChatMessage({
|
|||||||
<ArrowsClockwiseIcon className="size-4 text-zinc-600" />
|
<ArrowsClockwiseIcon className="size-4 text-zinc-600" />
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
{!isUser && isFinalMessage && isLongResponse(displayContent) && (
|
{!isUser && isFinalMessage && !isStreaming && (
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ export type ChatMessageData =
|
|||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
type: "agent_carousel";
|
type: "agent_carousel";
|
||||||
|
toolId: string;
|
||||||
toolName: string;
|
toolName: string;
|
||||||
agents: Array<{
|
agents: Array<{
|
||||||
id: string;
|
id: string;
|
||||||
@@ -74,6 +75,7 @@ export type ChatMessageData =
|
|||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
type: "execution_started";
|
type: "execution_started";
|
||||||
|
toolId: string;
|
||||||
toolName: string;
|
toolName: string;
|
||||||
executionId: string;
|
executionId: string;
|
||||||
agentName?: string;
|
agentName?: string;
|
||||||
@@ -103,6 +105,29 @@ export type ChatMessageData =
|
|||||||
message: string;
|
message: string;
|
||||||
sessionId: string;
|
sessionId: string;
|
||||||
timestamp?: string | Date;
|
timestamp?: string | Date;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: "operation_started";
|
||||||
|
toolName: string;
|
||||||
|
toolId: string;
|
||||||
|
operationId: string;
|
||||||
|
message: string;
|
||||||
|
timestamp?: string | Date;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: "operation_pending";
|
||||||
|
toolName: string;
|
||||||
|
toolId: string;
|
||||||
|
operationId: string;
|
||||||
|
message: string;
|
||||||
|
timestamp?: string | Date;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: "operation_in_progress";
|
||||||
|
toolName: string;
|
||||||
|
toolCallId: string;
|
||||||
|
message: string;
|
||||||
|
timestamp?: string | Date;
|
||||||
};
|
};
|
||||||
|
|
||||||
export function useChatMessage(message: ChatMessageData) {
|
export function useChatMessage(message: ChatMessageData) {
|
||||||
@@ -124,5 +149,8 @@ export function useChatMessage(message: ChatMessageData) {
|
|||||||
isExecutionStarted: message.type === "execution_started",
|
isExecutionStarted: message.type === "execution_started",
|
||||||
isInputsNeeded: message.type === "inputs_needed",
|
isInputsNeeded: message.type === "inputs_needed",
|
||||||
isClarificationNeeded: message.type === "clarification_needed",
|
isClarificationNeeded: message.type === "clarification_needed",
|
||||||
|
isOperationStarted: message.type === "operation_started",
|
||||||
|
isOperationPending: message.type === "operation_pending",
|
||||||
|
isOperationInProgress: message.type === "operation_in_progress",
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ export function ClarificationQuestionsWidget({
|
|||||||
className,
|
className,
|
||||||
}: Props) {
|
}: Props) {
|
||||||
const [answers, setAnswers] = useState<Record<string, string>>({});
|
const [answers, setAnswers] = useState<Record<string, string>>({});
|
||||||
|
const [isSubmitted, setIsSubmitted] = useState(false);
|
||||||
|
|
||||||
function handleAnswerChange(keyword: string, value: string) {
|
function handleAnswerChange(keyword: string, value: string) {
|
||||||
setAnswers((prev) => ({ ...prev, [keyword]: value }));
|
setAnswers((prev) => ({ ...prev, [keyword]: value }));
|
||||||
@@ -41,11 +42,42 @@ export function ClarificationQuestionsWidget({
|
|||||||
if (!allAnswered) {
|
if (!allAnswered) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
setIsSubmitted(true);
|
||||||
onSubmitAnswers(answers);
|
onSubmitAnswers(answers);
|
||||||
}
|
}
|
||||||
|
|
||||||
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
|
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
|
||||||
|
|
||||||
|
// Show submitted state after answers are submitted
|
||||||
|
if (isSubmitted) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"group relative flex w-full justify-start gap-3 px-4 py-3",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="flex w-full max-w-3xl gap-3">
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-green-500">
|
||||||
|
<CheckCircleIcon className="h-4 w-4 text-white" weight="bold" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex min-w-0 flex-1 flex-col">
|
||||||
|
<Card className="p-4">
|
||||||
|
<Text variant="h4" className="mb-1 text-slate-900">
|
||||||
|
Answers submitted
|
||||||
|
</Text>
|
||||||
|
<Text variant="small" className="text-slate-600">
|
||||||
|
Processing your responses...
|
||||||
|
</Text>
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
|
|||||||
@@ -0,0 +1,109 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { Card } from "@/components/atoms/Card/Card";
|
||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { CircleNotch, CheckCircle, XCircle } from "@phosphor-icons/react";
|
||||||
|
|
||||||
|
type OperationStatus =
|
||||||
|
| "pending"
|
||||||
|
| "started"
|
||||||
|
| "in_progress"
|
||||||
|
| "completed"
|
||||||
|
| "error";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
status: OperationStatus;
|
||||||
|
message: string;
|
||||||
|
toolName?: string;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getOperationTitle(toolName?: string): string {
|
||||||
|
if (!toolName) return "Operation";
|
||||||
|
// Convert tool name to human-readable format
|
||||||
|
// e.g., "create_agent" -> "Creating Agent", "edit_agent" -> "Editing Agent"
|
||||||
|
if (toolName === "create_agent") return "Creating Agent";
|
||||||
|
if (toolName === "edit_agent") return "Editing Agent";
|
||||||
|
// Default: capitalize and format tool name
|
||||||
|
return toolName
|
||||||
|
.split("_")
|
||||||
|
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
||||||
|
.join(" ");
|
||||||
|
}
|
||||||
|
|
||||||
|
export function PendingOperationWidget({
|
||||||
|
status,
|
||||||
|
message,
|
||||||
|
toolName,
|
||||||
|
className,
|
||||||
|
}: Props) {
|
||||||
|
const isPending =
|
||||||
|
status === "pending" || status === "started" || status === "in_progress";
|
||||||
|
const isCompleted = status === "completed";
|
||||||
|
const isError = status === "error";
|
||||||
|
|
||||||
|
const operationTitle = getOperationTitle(toolName);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"group relative flex w-full justify-start gap-3 px-4 py-3",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="flex w-full max-w-3xl gap-3">
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex h-7 w-7 items-center justify-center rounded-lg",
|
||||||
|
isPending && "bg-blue-500",
|
||||||
|
isCompleted && "bg-green-500",
|
||||||
|
isError && "bg-red-500",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isPending && (
|
||||||
|
<CircleNotch
|
||||||
|
className="h-4 w-4 animate-spin text-white"
|
||||||
|
weight="bold"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{isCompleted && (
|
||||||
|
<CheckCircle className="h-4 w-4 text-white" weight="bold" />
|
||||||
|
)}
|
||||||
|
{isError && (
|
||||||
|
<XCircle className="h-4 w-4 text-white" weight="bold" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex min-w-0 flex-1 flex-col">
|
||||||
|
<Card className="space-y-2 p-4">
|
||||||
|
<div>
|
||||||
|
<Text variant="h4" className="mb-1 text-slate-900">
|
||||||
|
{isPending && operationTitle}
|
||||||
|
{isCompleted && `${operationTitle} Complete`}
|
||||||
|
{isError && `${operationTitle} Failed`}
|
||||||
|
</Text>
|
||||||
|
<Text variant="small" className="text-slate-600">
|
||||||
|
{message}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{isPending && (
|
||||||
|
<Text variant="small" className="italic text-slate-500">
|
||||||
|
Check your library in a few minutes.
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{toolName && (
|
||||||
|
<Text variant="small" className="text-slate-400">
|
||||||
|
Tool: {toolName}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ export function UserChatBubble({ children, className }: UserChatBubbleProps) {
|
|||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
"group relative min-w-20 overflow-hidden rounded-xl bg-purple-100 px-3 text-right text-[1rem] leading-relaxed transition-all duration-500 ease-in-out",
|
"group relative min-w-20 overflow-hidden rounded-xl bg-purple-100 px-3 text-left text-[1rem] leading-relaxed transition-all duration-500 ease-in-out",
|
||||||
className,
|
className,
|
||||||
)}
|
)}
|
||||||
style={{
|
style={{
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
|||||||
claimSession,
|
claimSession,
|
||||||
clearSession: clearSessionBase,
|
clearSession: clearSessionBase,
|
||||||
loadSession,
|
loadSession,
|
||||||
|
startPollingForOperation,
|
||||||
} = useChatSession({
|
} = useChatSession({
|
||||||
urlSessionId,
|
urlSessionId,
|
||||||
autoCreate: false,
|
autoCreate: false,
|
||||||
@@ -94,5 +95,6 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
|||||||
loadSession,
|
loadSession,
|
||||||
sessionId: sessionIdFromHook,
|
sessionId: sessionIdFromHook,
|
||||||
showLoader,
|
showLoader,
|
||||||
|
startPollingForOperation,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ export function useChatSession({
|
|||||||
query: {
|
query: {
|
||||||
enabled: !!sessionId,
|
enabled: !!sessionId,
|
||||||
select: okData,
|
select: okData,
|
||||||
|
staleTime: 0,
|
||||||
retry: shouldRetrySessionLoad,
|
retry: shouldRetrySessionLoad,
|
||||||
retryDelay: getSessionRetryDelay,
|
retryDelay: getSessionRetryDelay,
|
||||||
},
|
},
|
||||||
@@ -102,15 +103,123 @@ export function useChatSession({
|
|||||||
}
|
}
|
||||||
}, [createError, loadError]);
|
}, [createError, loadError]);
|
||||||
|
|
||||||
|
// Track if we should be polling (set by external callers when they receive operation_started via SSE)
|
||||||
|
const [forcePolling, setForcePolling] = useState(false);
|
||||||
|
// Track if we've seen server acknowledge the pending operation (to avoid clearing forcePolling prematurely)
|
||||||
|
const hasSeenServerPendingRef = useRef(false);
|
||||||
|
|
||||||
|
// Check if there are any pending operations in the messages
|
||||||
|
// Must check all operation types: operation_pending, operation_started, operation_in_progress
|
||||||
|
const hasPendingOperationsFromServer = useMemo(() => {
|
||||||
|
if (!messages || messages.length === 0) return false;
|
||||||
|
const pendingTypes = new Set([
|
||||||
|
"operation_pending",
|
||||||
|
"operation_in_progress",
|
||||||
|
"operation_started",
|
||||||
|
]);
|
||||||
|
return messages.some((msg) => {
|
||||||
|
if (msg.role !== "tool" || !msg.content) return false;
|
||||||
|
try {
|
||||||
|
const content =
|
||||||
|
typeof msg.content === "string"
|
||||||
|
? JSON.parse(msg.content)
|
||||||
|
: msg.content;
|
||||||
|
return pendingTypes.has(content?.type);
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}, [messages]);
|
||||||
|
|
||||||
|
// Track when server has acknowledged the pending operation
|
||||||
|
useEffect(() => {
|
||||||
|
if (hasPendingOperationsFromServer) {
|
||||||
|
hasSeenServerPendingRef.current = true;
|
||||||
|
}
|
||||||
|
}, [hasPendingOperationsFromServer]);
|
||||||
|
|
||||||
|
// Combined: poll if server has pending ops OR if we received operation_started via SSE
|
||||||
|
const hasPendingOperations = hasPendingOperationsFromServer || forcePolling;
|
||||||
|
|
||||||
|
// Clear forcePolling only after server has acknowledged AND completed the operation
|
||||||
|
useEffect(() => {
|
||||||
|
if (
|
||||||
|
forcePolling &&
|
||||||
|
!hasPendingOperationsFromServer &&
|
||||||
|
hasSeenServerPendingRef.current
|
||||||
|
) {
|
||||||
|
// Server acknowledged the operation and it's now complete
|
||||||
|
setForcePolling(false);
|
||||||
|
hasSeenServerPendingRef.current = false;
|
||||||
|
}
|
||||||
|
}, [forcePolling, hasPendingOperationsFromServer]);
|
||||||
|
|
||||||
|
// Function to trigger polling (called when operation_started is received via SSE)
|
||||||
|
function startPollingForOperation() {
|
||||||
|
setForcePolling(true);
|
||||||
|
hasSeenServerPendingRef.current = false; // Reset for new operation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh sessions list when a pending operation completes
|
||||||
|
// (hasPendingOperations transitions from true to false)
|
||||||
|
const prevHasPendingOperationsRef = useRef(hasPendingOperations);
|
||||||
useEffect(
|
useEffect(
|
||||||
function refreshSessionsListOnLoad() {
|
function refreshSessionsListOnOperationComplete() {
|
||||||
if (sessionId && sessionData && !isLoadingSession) {
|
const wasHasPending = prevHasPendingOperationsRef.current;
|
||||||
|
prevHasPendingOperationsRef.current = hasPendingOperations;
|
||||||
|
|
||||||
|
// Only invalidate when transitioning from pending to not pending
|
||||||
|
if (wasHasPending && !hasPendingOperations && sessionId) {
|
||||||
queryClient.invalidateQueries({
|
queryClient.invalidateQueries({
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[sessionId, sessionData, isLoadingSession, queryClient],
|
[hasPendingOperations, sessionId, queryClient],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Poll for updates when there are pending operations
|
||||||
|
// Backoff: 2s, 4s, 6s, 8s, 10s, ... up to 30s max
|
||||||
|
const pollAttemptRef = useRef(0);
|
||||||
|
const hasPendingOperationsRef = useRef(hasPendingOperations);
|
||||||
|
hasPendingOperationsRef.current = hasPendingOperations;
|
||||||
|
|
||||||
|
useEffect(
|
||||||
|
function pollForPendingOperations() {
|
||||||
|
if (!sessionId || !hasPendingOperations) {
|
||||||
|
pollAttemptRef.current = 0;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cancelled = false;
|
||||||
|
let timeoutId: ReturnType<typeof setTimeout> | null = null;
|
||||||
|
|
||||||
|
function schedule() {
|
||||||
|
// 2s, 4s, 6s, 8s, 10s, ... 30s (max)
|
||||||
|
const delay = Math.min((pollAttemptRef.current + 1) * 2000, 30000);
|
||||||
|
timeoutId = setTimeout(async () => {
|
||||||
|
if (cancelled) return;
|
||||||
|
pollAttemptRef.current += 1;
|
||||||
|
try {
|
||||||
|
await refetch();
|
||||||
|
} catch (err) {
|
||||||
|
console.error("[useChatSession] Poll failed:", err);
|
||||||
|
} finally {
|
||||||
|
if (!cancelled && hasPendingOperationsRef.current) {
|
||||||
|
schedule();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, delay);
|
||||||
|
}
|
||||||
|
|
||||||
|
schedule();
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
cancelled = true;
|
||||||
|
if (timeoutId) clearTimeout(timeoutId);
|
||||||
|
};
|
||||||
|
},
|
||||||
|
[sessionId, hasPendingOperations, refetch],
|
||||||
);
|
);
|
||||||
|
|
||||||
async function createSession() {
|
async function createSession() {
|
||||||
@@ -239,11 +348,13 @@ export function useChatSession({
|
|||||||
isCreating,
|
isCreating,
|
||||||
error,
|
error,
|
||||||
isSessionNotFound: isNotFoundError(loadError),
|
isSessionNotFound: isNotFoundError(loadError),
|
||||||
|
hasPendingOperations,
|
||||||
createSession,
|
createSession,
|
||||||
loadSession,
|
loadSession,
|
||||||
refreshSession,
|
refreshSession,
|
||||||
claimSession,
|
claimSession,
|
||||||
clearSession,
|
clearSession,
|
||||||
|
startPollingForOperation,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -255,13 +255,18 @@ export function Wallet() {
|
|||||||
(notification: WebSocketNotification) => {
|
(notification: WebSocketNotification) => {
|
||||||
if (
|
if (
|
||||||
notification.type !== "onboarding" ||
|
notification.type !== "onboarding" ||
|
||||||
notification.event !== "step_completed" ||
|
notification.event !== "step_completed"
|
||||||
!walletRef.current
|
|
||||||
) {
|
) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only trigger confetti for tasks that are in groups
|
// Always refresh credits when any onboarding step completes
|
||||||
|
fetchCredits();
|
||||||
|
|
||||||
|
// Only trigger confetti for tasks that are in displayed groups
|
||||||
|
if (!walletRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const taskIds = groups
|
const taskIds = groups
|
||||||
.flatMap((group) => group.tasks)
|
.flatMap((group) => group.tasks)
|
||||||
.map((task) => task.id);
|
.map((task) => task.id);
|
||||||
@@ -274,7 +279,6 @@ export function Wallet() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
fetchCredits();
|
|
||||||
party.confetti(walletRef.current, {
|
party.confetti(walletRef.current, {
|
||||||
count: 30,
|
count: 30,
|
||||||
spread: 120,
|
spread: 120,
|
||||||
@@ -284,7 +288,7 @@ export function Wallet() {
|
|||||||
modules: [fadeOut],
|
modules: [fadeOut],
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[fetchCredits, fadeOut],
|
[fetchCredits, fadeOut, groups],
|
||||||
);
|
);
|
||||||
|
|
||||||
// WebSocket setup for onboarding notifications
|
// WebSocket setup for onboarding notifications
|
||||||
|
|||||||
@@ -1003,6 +1003,7 @@ export type OnboardingStep =
|
|||||||
| "AGENT_INPUT"
|
| "AGENT_INPUT"
|
||||||
| "CONGRATS"
|
| "CONGRATS"
|
||||||
// First Wins
|
// First Wins
|
||||||
|
| "VISIT_COPILOT"
|
||||||
| "GET_RESULTS"
|
| "GET_RESULTS"
|
||||||
| "MARKETPLACE_VISIT"
|
| "MARKETPLACE_VISIT"
|
||||||
| "MARKETPLACE_ADD_AGENT"
|
| "MARKETPLACE_ADD_AGENT"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import { environment } from "../environment";
|
|||||||
|
|
||||||
export enum SessionKey {
|
export enum SessionKey {
|
||||||
CHAT_SENT_INITIAL_PROMPTS = "chat_sent_initial_prompts",
|
CHAT_SENT_INITIAL_PROMPTS = "chat_sent_initial_prompts",
|
||||||
|
CHAT_INITIAL_PROMPTS = "chat_initial_prompts",
|
||||||
}
|
}
|
||||||
|
|
||||||
function get(key: SessionKey) {
|
function get(key: SessionKey) {
|
||||||
|
|||||||
@@ -37,9 +37,13 @@ export class LoginPage {
|
|||||||
this.page.on("load", (page) => console.log(`ℹ️ Now at URL: ${page.url()}`));
|
this.page.on("load", (page) => console.log(`ℹ️ Now at URL: ${page.url()}`));
|
||||||
|
|
||||||
// Start waiting for navigation before clicking
|
// Start waiting for navigation before clicking
|
||||||
|
// Wait for redirect to marketplace, onboarding, library, or copilot (new landing pages)
|
||||||
const leaveLoginPage = this.page
|
const leaveLoginPage = this.page
|
||||||
.waitForURL(
|
.waitForURL(
|
||||||
(url) => /^\/(marketplace|onboarding(\/.*)?)?$/.test(url.pathname),
|
(url: URL) =>
|
||||||
|
/^\/(marketplace|onboarding(\/.*)?|library|copilot)?$/.test(
|
||||||
|
url.pathname,
|
||||||
|
),
|
||||||
{ timeout: 10_000 },
|
{ timeout: 10_000 },
|
||||||
)
|
)
|
||||||
.catch((reason) => {
|
.catch((reason) => {
|
||||||
|
|||||||
@@ -36,14 +36,16 @@ export async function signupTestUser(
|
|||||||
const signupButton = getButton("Sign up");
|
const signupButton = getButton("Sign up");
|
||||||
await signupButton.click();
|
await signupButton.click();
|
||||||
|
|
||||||
// Wait for successful signup - could redirect to onboarding or marketplace
|
// Wait for successful signup - could redirect to various pages depending on onboarding state
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Wait for either onboarding or marketplace redirect
|
// Wait for redirect to onboarding, marketplace, copilot, or library
|
||||||
await Promise.race([
|
// Use a single waitForURL with a callback to avoid Promise.race race conditions
|
||||||
page.waitForURL(/\/onboarding/, { timeout: 15000 }),
|
await page.waitForURL(
|
||||||
page.waitForURL(/\/marketplace/, { timeout: 15000 }),
|
(url: URL) =>
|
||||||
]);
|
/\/(onboarding|marketplace|copilot|library)/.test(url.pathname),
|
||||||
|
{ timeout: 15000 },
|
||||||
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(
|
console.error(
|
||||||
"❌ Timeout waiting for redirect, current URL:",
|
"❌ Timeout waiting for redirect, current URL:",
|
||||||
@@ -54,14 +56,19 @@ export async function signupTestUser(
|
|||||||
|
|
||||||
const currentUrl = page.url();
|
const currentUrl = page.url();
|
||||||
|
|
||||||
// Handle onboarding or marketplace redirect
|
// Handle onboarding redirect if needed
|
||||||
if (currentUrl.includes("/onboarding") && ignoreOnboarding) {
|
if (currentUrl.includes("/onboarding") && ignoreOnboarding) {
|
||||||
await page.goto("http://localhost:3000/marketplace");
|
await page.goto("http://localhost:3000/marketplace");
|
||||||
await page.waitForLoadState("domcontentloaded", { timeout: 10000 });
|
await page.waitForLoadState("domcontentloaded", { timeout: 10000 });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify we're on the expected final page
|
// Verify we're on an expected final page and user is authenticated
|
||||||
if (ignoreOnboarding || currentUrl.includes("/marketplace")) {
|
if (currentUrl.includes("/copilot") || currentUrl.includes("/library")) {
|
||||||
|
// For copilot/library landing pages, just verify user is authenticated
|
||||||
|
await page
|
||||||
|
.getByTestId("profile-popout-menu-trigger")
|
||||||
|
.waitFor({ state: "visible", timeout: 10000 });
|
||||||
|
} else if (ignoreOnboarding || currentUrl.includes("/marketplace")) {
|
||||||
// Verify we're on marketplace
|
// Verify we're on marketplace
|
||||||
await page
|
await page
|
||||||
.getByText(
|
.getByText(
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
# Video editing blocks
|
|
||||||
Reference in New Issue
Block a user