mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-28 00:18:25 -05:00
Compare commits
8 Commits
feat/text-
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0058cd3ba6 | ||
|
|
ea035224bc | ||
|
|
62813a1ea6 | ||
|
|
67405f7eb9 | ||
|
|
171ff6e776 | ||
|
|
349b1f9c79 | ||
|
|
277b0537e9 | ||
|
|
071b3bb5cd |
@@ -33,9 +33,15 @@ class ChatConfig(BaseSettings):
|
|||||||
|
|
||||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||||
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
|
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||||
max_agent_schedules: int = Field(
|
max_agent_schedules: int = Field(
|
||||||
default=3, description="Maximum number of agent schedules"
|
default=30, description="Maximum number of agent schedules"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Long-running operation configuration
|
||||||
|
long_running_operation_ttl: int = Field(
|
||||||
|
default=600,
|
||||||
|
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Langfuse Prompt Management Configuration
|
# Langfuse Prompt Management Configuration
|
||||||
|
|||||||
@@ -247,3 +247,45 @@ async def get_chat_session_message_count(session_id: str) -> int:
|
|||||||
"""Get the number of messages in a chat session."""
|
"""Get the number of messages in a chat session."""
|
||||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
async def update_tool_message_content(
|
||||||
|
session_id: str,
|
||||||
|
tool_call_id: str,
|
||||||
|
new_content: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Update the content of a tool message in chat history.
|
||||||
|
|
||||||
|
Used by background tasks to update pending operation messages with final results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The chat session ID.
|
||||||
|
tool_call_id: The tool call ID to find the message.
|
||||||
|
new_content: The new content to set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a message was updated, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await PrismaChatMessage.prisma().update_many(
|
||||||
|
where={
|
||||||
|
"sessionId": session_id,
|
||||||
|
"toolCallId": tool_call_id,
|
||||||
|
},
|
||||||
|
data={
|
||||||
|
"content": new_content,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if result == 0:
|
||||||
|
logger.warning(
|
||||||
|
f"No message found to update for session {session_id}, "
|
||||||
|
f"tool_call_id {tool_call_id}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to update tool message for session {session_id}, "
|
||||||
|
f"tool_call_id {tool_call_id}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|||||||
@@ -295,6 +295,21 @@ async def cache_chat_session(session: ChatSession) -> None:
|
|||||||
await _cache_session(session)
|
await _cache_session(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def invalidate_session_cache(session_id: str) -> None:
|
||||||
|
"""Invalidate a chat session from Redis cache.
|
||||||
|
|
||||||
|
Used by background tasks to ensure fresh data is loaded on next access.
|
||||||
|
This is best-effort - Redis failures are logged but don't fail the operation.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
redis_key = _get_session_cache_key(session_id)
|
||||||
|
async_redis = await get_redis_async()
|
||||||
|
await async_redis.delete(redis_key)
|
||||||
|
except Exception as e:
|
||||||
|
# Best-effort: log but don't fail - cache will expire naturally
|
||||||
|
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||||
"""Get a chat session from the database."""
|
"""Get a chat session from the database."""
|
||||||
prisma_session = await chat_db.get_chat_session(session_id)
|
prisma_session = await chat_db.get_chat_session(session_id)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from openai import (
|
|||||||
)
|
)
|
||||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||||
|
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
from backend.data.understanding import (
|
from backend.data.understanding import (
|
||||||
format_understanding_for_prompt,
|
format_understanding_for_prompt,
|
||||||
get_business_understanding,
|
get_business_understanding,
|
||||||
@@ -24,6 +25,7 @@ from backend.data.understanding import (
|
|||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
from . import db as chat_db
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import (
|
from .model import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
@@ -31,6 +33,7 @@ from .model import (
|
|||||||
Usage,
|
Usage,
|
||||||
cache_chat_session,
|
cache_chat_session,
|
||||||
get_chat_session,
|
get_chat_session,
|
||||||
|
invalidate_session_cache,
|
||||||
update_session_title,
|
update_session_title,
|
||||||
upsert_chat_session,
|
upsert_chat_session,
|
||||||
)
|
)
|
||||||
@@ -48,8 +51,13 @@ from .response_model import (
|
|||||||
StreamToolOutputAvailable,
|
StreamToolOutputAvailable,
|
||||||
StreamUsage,
|
StreamUsage,
|
||||||
)
|
)
|
||||||
from .tools import execute_tool, tools
|
from .tools import execute_tool, get_tool, tools
|
||||||
from .tools.models import ErrorResponse
|
from .tools.models import (
|
||||||
|
ErrorResponse,
|
||||||
|
OperationInProgressResponse,
|
||||||
|
OperationPendingResponse,
|
||||||
|
OperationStartedResponse,
|
||||||
|
)
|
||||||
from .tracking import track_user_message
|
from .tracking import track_user_message
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -61,11 +69,126 @@ client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
|||||||
|
|
||||||
langfuse = get_client()
|
langfuse = get_client()
|
||||||
|
|
||||||
|
# Redis key prefix for tracking running long-running operations
|
||||||
|
# Used for idempotency across Kubernetes pods - prevents duplicate executions on browser refresh
|
||||||
|
RUNNING_OPERATION_PREFIX = "chat:running_operation:"
|
||||||
|
|
||||||
class LangfuseNotConfiguredError(Exception):
|
# Default system prompt used when Langfuse is not configured
|
||||||
"""Raised when Langfuse is required but not configured."""
|
# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11)
|
||||||
|
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
|
||||||
|
|
||||||
pass
|
Here is everything you know about the current user from previous interactions:
|
||||||
|
|
||||||
|
<users_information>
|
||||||
|
{users_information}
|
||||||
|
</users_information>
|
||||||
|
|
||||||
|
## YOUR CORE MANDATE
|
||||||
|
|
||||||
|
You are action-oriented. Your success is measured by:
|
||||||
|
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
|
||||||
|
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
|
||||||
|
- **Time Saved**: Focus on tangible efficiency gains
|
||||||
|
- **Quality Output**: Deliver results that meet or exceed expectations
|
||||||
|
|
||||||
|
## YOUR WORKFLOW
|
||||||
|
|
||||||
|
Adapt flexibly to the conversation context. Not every interaction requires all stages:
|
||||||
|
|
||||||
|
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
|
||||||
|
|
||||||
|
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
|
||||||
|
|
||||||
|
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
|
||||||
|
|
||||||
|
4. **Discover or Create Agents**:
|
||||||
|
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
|
||||||
|
- Search the marketplace with `find_agent` for pre-built automations
|
||||||
|
- Find reusable components with `find_block`
|
||||||
|
- Create custom solutions with `create_agent` if nothing suitable exists
|
||||||
|
- Modify existing library agents with `edit_agent`
|
||||||
|
|
||||||
|
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
|
||||||
|
|
||||||
|
6. **Show Results**: Display outputs using `agent_output`.
|
||||||
|
|
||||||
|
## AVAILABLE TOOLS
|
||||||
|
|
||||||
|
**Understanding & Discovery:**
|
||||||
|
- `add_understanding`: Create a memory about the user's business or use cases for future sessions
|
||||||
|
- `search_docs`: Search platform documentation for specific technical information
|
||||||
|
- `get_doc_page`: Retrieve full text of a specific documentation page
|
||||||
|
|
||||||
|
**Agent Discovery:**
|
||||||
|
- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized)
|
||||||
|
- `find_agent`: Search the marketplace for pre-built automations
|
||||||
|
- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks)
|
||||||
|
|
||||||
|
**Agent Creation & Editing:**
|
||||||
|
- `create_agent`: Create a new automation agent
|
||||||
|
- `edit_agent`: Modify an agent in the user's library
|
||||||
|
|
||||||
|
**Execution & Output:**
|
||||||
|
- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger
|
||||||
|
- `run_block`: Test or run a specific block independently
|
||||||
|
- `agent_output`: View results from previous agent runs
|
||||||
|
|
||||||
|
## BEHAVIORAL GUIDELINES
|
||||||
|
|
||||||
|
**Be Concise:**
|
||||||
|
- Target 2-5 short lines maximum
|
||||||
|
- Make every word count—no repetition or filler
|
||||||
|
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
|
||||||
|
- Avoid jargon (blocks, slugs, cron) unless the user asks
|
||||||
|
|
||||||
|
**Be Proactive:**
|
||||||
|
- Suggest next steps before being asked
|
||||||
|
- Anticipate needs based on conversation context and user information
|
||||||
|
- Look for opportunities to expand scope when relevant
|
||||||
|
- Reveal capabilities through action, not explanation
|
||||||
|
|
||||||
|
**Use Tools Effectively:**
|
||||||
|
- Select the right tool for each task
|
||||||
|
- **Always check `find_library_agent` before searching the marketplace**
|
||||||
|
- Use `add_understanding` to capture valuable business context
|
||||||
|
- When tool calls fail, try alternative approaches
|
||||||
|
|
||||||
|
## CRITICAL REMINDER
|
||||||
|
|
||||||
|
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
|
||||||
|
|
||||||
|
# Module-level set to hold strong references to background tasks.
|
||||||
|
# This prevents asyncio from garbage collecting tasks before they complete.
|
||||||
|
# Tasks are automatically removed on completion via done_callback.
|
||||||
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
|
async def _mark_operation_started(tool_call_id: str) -> bool:
|
||||||
|
"""Mark a long-running operation as started (Redis-based).
|
||||||
|
|
||||||
|
Returns True if successfully marked (operation was not already running),
|
||||||
|
False if operation was already running (lost race condition).
|
||||||
|
Raises exception if Redis is unavailable (fail-closed).
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
|
||||||
|
# SETNX with TTL - atomic "set if not exists"
|
||||||
|
result = await redis.set(key, "1", ex=config.long_running_operation_ttl, nx=True)
|
||||||
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def _mark_operation_completed(tool_call_id: str) -> None:
|
||||||
|
"""Mark a long-running operation as completed (remove Redis key).
|
||||||
|
|
||||||
|
This is best-effort - if Redis fails, the TTL will eventually clean up.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
|
||||||
|
await redis.delete(key)
|
||||||
|
except Exception as e:
|
||||||
|
# Non-critical: TTL will clean up eventually
|
||||||
|
logger.warning(f"Failed to delete running operation key {tool_call_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _is_langfuse_configured() -> bool:
|
def _is_langfuse_configured() -> bool:
|
||||||
@@ -75,6 +198,30 @@ def _is_langfuse_configured() -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_system_prompt_template(context: str) -> str:
|
||||||
|
"""Get the system prompt, trying Langfuse first with fallback to default.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The user context/information to compile into the prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The compiled system prompt string.
|
||||||
|
"""
|
||||||
|
if _is_langfuse_configured():
|
||||||
|
try:
|
||||||
|
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||||
|
# Use asyncio.to_thread to avoid blocking the event loop
|
||||||
|
prompt = await asyncio.to_thread(
|
||||||
|
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
|
||||||
|
)
|
||||||
|
return prompt.compile(users_information=context)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch prompt from Langfuse, using default: {e}")
|
||||||
|
|
||||||
|
# Fallback to default prompt
|
||||||
|
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
||||||
|
|
||||||
|
|
||||||
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||||
"""Build the full system prompt including business understanding if available.
|
"""Build the full system prompt including business understanding if available.
|
||||||
|
|
||||||
@@ -83,12 +230,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
|||||||
If "default" and this is the user's first session, will use "onboarding" instead.
|
If "default" and this is the user's first session, will use "onboarding" instead.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (compiled prompt string, Langfuse prompt object for tracing)
|
Tuple of (compiled prompt string, business understanding object)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
|
||||||
prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
|
|
||||||
|
|
||||||
# If user is authenticated, try to fetch their business understanding
|
# If user is authenticated, try to fetch their business understanding
|
||||||
understanding = None
|
understanding = None
|
||||||
if user_id:
|
if user_id:
|
||||||
@@ -97,12 +240,13 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to fetch business understanding: {e}")
|
logger.warning(f"Failed to fetch business understanding: {e}")
|
||||||
understanding = None
|
understanding = None
|
||||||
|
|
||||||
if understanding:
|
if understanding:
|
||||||
context = format_understanding_for_prompt(understanding)
|
context = format_understanding_for_prompt(understanding)
|
||||||
else:
|
else:
|
||||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||||
|
|
||||||
compiled = prompt.compile(users_information=context)
|
compiled = await _get_system_prompt_template(context)
|
||||||
return compiled, understanding
|
return compiled, understanding
|
||||||
|
|
||||||
|
|
||||||
@@ -210,16 +354,6 @@ async def stream_chat_completion(
|
|||||||
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if Langfuse is configured - required for chat functionality
|
|
||||||
if not _is_langfuse_configured():
|
|
||||||
logger.error("Chat request failed: Langfuse is not configured")
|
|
||||||
yield StreamError(
|
|
||||||
errorText="Chat service is not available. Langfuse must be configured "
|
|
||||||
"with LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
|
||||||
return
|
|
||||||
|
|
||||||
# Only fetch from Redis if session not provided (initial call)
|
# Only fetch from Redis if session not provided (initial call)
|
||||||
if session is None:
|
if session is None:
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
@@ -315,6 +449,7 @@ async def stream_chat_completion(
|
|||||||
has_yielded_end = False
|
has_yielded_end = False
|
||||||
has_yielded_error = False
|
has_yielded_error = False
|
||||||
has_done_tool_call = False
|
has_done_tool_call = False
|
||||||
|
has_long_running_tool_call = False # Track if we had a long-running tool call
|
||||||
has_received_text = False
|
has_received_text = False
|
||||||
text_streaming_ended = False
|
text_streaming_ended = False
|
||||||
tool_response_messages: list[ChatMessage] = []
|
tool_response_messages: list[ChatMessage] = []
|
||||||
@@ -336,7 +471,6 @@ async def stream_chat_completion(
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
text_block_id=text_block_id,
|
text_block_id=text_block_id,
|
||||||
):
|
):
|
||||||
|
|
||||||
if isinstance(chunk, StreamTextStart):
|
if isinstance(chunk, StreamTextStart):
|
||||||
# Emit text-start before first text delta
|
# Emit text-start before first text delta
|
||||||
if not has_received_text:
|
if not has_received_text:
|
||||||
@@ -394,13 +528,34 @@ async def stream_chat_completion(
|
|||||||
if isinstance(chunk.output, str)
|
if isinstance(chunk.output, str)
|
||||||
else orjson.dumps(chunk.output).decode("utf-8")
|
else orjson.dumps(chunk.output).decode("utf-8")
|
||||||
)
|
)
|
||||||
tool_response_messages.append(
|
# Skip saving long-running operation responses - messages already saved in _yield_tool_call
|
||||||
ChatMessage(
|
# Use JSON parsing instead of substring matching to avoid false positives
|
||||||
role="tool",
|
is_long_running_response = False
|
||||||
content=result_content,
|
try:
|
||||||
tool_call_id=chunk.toolCallId,
|
parsed = orjson.loads(result_content)
|
||||||
|
if isinstance(parsed, dict) and parsed.get("type") in (
|
||||||
|
"operation_started",
|
||||||
|
"operation_in_progress",
|
||||||
|
):
|
||||||
|
is_long_running_response = True
|
||||||
|
except (orjson.JSONDecodeError, TypeError):
|
||||||
|
pass # Not JSON or not a dict - treat as regular response
|
||||||
|
if is_long_running_response:
|
||||||
|
# Remove from accumulated_tool_calls since assistant message was already saved
|
||||||
|
accumulated_tool_calls[:] = [
|
||||||
|
tc
|
||||||
|
for tc in accumulated_tool_calls
|
||||||
|
if tc["id"] != chunk.toolCallId
|
||||||
|
]
|
||||||
|
has_long_running_tool_call = True
|
||||||
|
else:
|
||||||
|
tool_response_messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="tool",
|
||||||
|
content=result_content,
|
||||||
|
tool_call_id=chunk.toolCallId,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
has_done_tool_call = True
|
has_done_tool_call = True
|
||||||
# Track if any tool execution failed
|
# Track if any tool execution failed
|
||||||
if not chunk.success:
|
if not chunk.success:
|
||||||
@@ -576,7 +731,14 @@ async def stream_chat_completion(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Extended session messages, new message_count={len(session.messages)}"
|
f"Extended session messages, new message_count={len(session.messages)}"
|
||||||
)
|
)
|
||||||
if messages_to_save or has_appended_streaming_message:
|
# Save if there are regular (non-long-running) tool responses or streaming message.
|
||||||
|
# Long-running tools save their own state, but we still need to save regular tools
|
||||||
|
# that may be in the same response.
|
||||||
|
has_regular_tool_responses = len(tool_response_messages) > 0
|
||||||
|
if has_regular_tool_responses or (
|
||||||
|
not has_long_running_tool_call
|
||||||
|
and (messages_to_save or has_appended_streaming_message)
|
||||||
|
):
|
||||||
await upsert_chat_session(session)
|
await upsert_chat_session(session)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -585,7 +747,9 @@ async def stream_chat_completion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If we did a tool call, stream the chat completion again to get the next response
|
# If we did a tool call, stream the chat completion again to get the next response
|
||||||
if has_done_tool_call:
|
# Skip only if ALL tools were long-running (they handle their own completion)
|
||||||
|
has_regular_tools = len(tool_response_messages) > 0
|
||||||
|
if has_done_tool_call and (has_regular_tools or not has_long_running_tool_call):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Tool call executed, streaming chat completion again to get assistant response"
|
"Tool call executed, streaming chat completion again to get assistant response"
|
||||||
)
|
)
|
||||||
@@ -725,6 +889,114 @@ async def _summarize_messages(
|
|||||||
return summary or "No summary available."
|
return summary or "No summary available."
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_tool_pairs_intact(
|
||||||
|
recent_messages: list[dict],
|
||||||
|
all_messages: list[dict],
|
||||||
|
start_index: int,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Ensure tool_call/tool_response pairs stay together after slicing.
|
||||||
|
|
||||||
|
When slicing messages for context compaction, a naive slice can separate
|
||||||
|
an assistant message containing tool_calls from its corresponding tool
|
||||||
|
response messages. This causes API validation errors (e.g., Anthropic's
|
||||||
|
"unexpected tool_use_id found in tool_result blocks").
|
||||||
|
|
||||||
|
This function checks for orphan tool responses in the slice and extends
|
||||||
|
backwards to include their corresponding assistant messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recent_messages: The sliced messages to validate
|
||||||
|
all_messages: The complete message list (for looking up missing assistants)
|
||||||
|
start_index: The index in all_messages where recent_messages begins
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A potentially extended list of messages with tool pairs intact
|
||||||
|
"""
|
||||||
|
if not recent_messages:
|
||||||
|
return recent_messages
|
||||||
|
|
||||||
|
# Collect all tool_call_ids from assistant messages in the slice
|
||||||
|
available_tool_call_ids: set[str] = set()
|
||||||
|
for msg in recent_messages:
|
||||||
|
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||||
|
for tc in msg["tool_calls"]:
|
||||||
|
tc_id = tc.get("id")
|
||||||
|
if tc_id:
|
||||||
|
available_tool_call_ids.add(tc_id)
|
||||||
|
|
||||||
|
# Find orphan tool responses (tool messages whose tool_call_id is missing)
|
||||||
|
orphan_tool_call_ids: set[str] = set()
|
||||||
|
for msg in recent_messages:
|
||||||
|
if msg.get("role") == "tool":
|
||||||
|
tc_id = msg.get("tool_call_id")
|
||||||
|
if tc_id and tc_id not in available_tool_call_ids:
|
||||||
|
orphan_tool_call_ids.add(tc_id)
|
||||||
|
|
||||||
|
if not orphan_tool_call_ids:
|
||||||
|
# No orphans, slice is valid
|
||||||
|
return recent_messages
|
||||||
|
|
||||||
|
# Find the assistant messages that contain the orphan tool_call_ids
|
||||||
|
# Search backwards from start_index in all_messages
|
||||||
|
messages_to_prepend: list[dict] = []
|
||||||
|
for i in range(start_index - 1, -1, -1):
|
||||||
|
msg = all_messages[i]
|
||||||
|
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||||
|
msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
|
||||||
|
if msg_tool_ids & orphan_tool_call_ids:
|
||||||
|
# This assistant message has tool_calls we need
|
||||||
|
# Also collect its contiguous tool responses that follow it
|
||||||
|
assistant_and_responses: list[dict] = [msg]
|
||||||
|
|
||||||
|
# Scan forward from this assistant to collect tool responses
|
||||||
|
for j in range(i + 1, start_index):
|
||||||
|
following_msg = all_messages[j]
|
||||||
|
if following_msg.get("role") == "tool":
|
||||||
|
tool_id = following_msg.get("tool_call_id")
|
||||||
|
if tool_id and tool_id in msg_tool_ids:
|
||||||
|
assistant_and_responses.append(following_msg)
|
||||||
|
else:
|
||||||
|
# Stop at first non-tool message
|
||||||
|
break
|
||||||
|
|
||||||
|
# Prepend the assistant and its tool responses (maintain order)
|
||||||
|
messages_to_prepend = assistant_and_responses + messages_to_prepend
|
||||||
|
# Mark these as found
|
||||||
|
orphan_tool_call_ids -= msg_tool_ids
|
||||||
|
# Also add this assistant's tool_call_ids to available set
|
||||||
|
available_tool_call_ids |= msg_tool_ids
|
||||||
|
|
||||||
|
if not orphan_tool_call_ids:
|
||||||
|
# Found all missing assistants
|
||||||
|
break
|
||||||
|
|
||||||
|
if orphan_tool_call_ids:
|
||||||
|
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||||
|
# This shouldn't happen in normal operation but handles edge cases
|
||||||
|
logger.warning(
|
||||||
|
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
||||||
|
"Removing orphan tool responses."
|
||||||
|
)
|
||||||
|
recent_messages = [
|
||||||
|
msg
|
||||||
|
for msg in recent_messages
|
||||||
|
if not (
|
||||||
|
msg.get("role") == "tool"
|
||||||
|
and msg.get("tool_call_id") in orphan_tool_call_ids
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if messages_to_prepend:
|
||||||
|
logger.info(
|
||||||
|
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
||||||
|
f"tool_call/tool_response pairs"
|
||||||
|
)
|
||||||
|
return messages_to_prepend + recent_messages
|
||||||
|
|
||||||
|
return recent_messages
|
||||||
|
|
||||||
|
|
||||||
async def _stream_chat_chunks(
|
async def _stream_chat_chunks(
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
tools: list[ChatCompletionToolParam],
|
tools: list[ChatCompletionToolParam],
|
||||||
@@ -816,7 +1088,15 @@ async def _stream_chat_chunks(
|
|||||||
# Always attempt mitigation when over limit, even with few messages
|
# Always attempt mitigation when over limit, even with few messages
|
||||||
if messages:
|
if messages:
|
||||||
# Split messages based on whether system prompt exists
|
# Split messages based on whether system prompt exists
|
||||||
recent_messages = messages[-KEEP_RECENT:]
|
# Calculate start index for the slice
|
||||||
|
slice_start = max(0, len(messages_dict) - KEEP_RECENT)
|
||||||
|
recent_messages = messages_dict[-KEEP_RECENT:]
|
||||||
|
|
||||||
|
# Ensure tool_call/tool_response pairs stay together
|
||||||
|
# This prevents API errors from orphan tool responses
|
||||||
|
recent_messages = _ensure_tool_pairs_intact(
|
||||||
|
recent_messages, messages_dict, slice_start
|
||||||
|
)
|
||||||
|
|
||||||
if has_system_prompt:
|
if has_system_prompt:
|
||||||
# Keep system prompt separate, summarize everything between system and recent
|
# Keep system prompt separate, summarize everything between system and recent
|
||||||
@@ -903,6 +1183,13 @@ async def _stream_chat_chunks(
|
|||||||
if len(recent_messages) >= keep_count
|
if len(recent_messages) >= keep_count
|
||||||
else recent_messages
|
else recent_messages
|
||||||
)
|
)
|
||||||
|
# Ensure tool pairs stay intact in the reduced slice
|
||||||
|
reduced_slice_start = max(
|
||||||
|
0, len(recent_messages) - keep_count
|
||||||
|
)
|
||||||
|
reduced_recent = _ensure_tool_pairs_intact(
|
||||||
|
reduced_recent, recent_messages, reduced_slice_start
|
||||||
|
)
|
||||||
if has_system_prompt:
|
if has_system_prompt:
|
||||||
messages = [
|
messages = [
|
||||||
system_msg,
|
system_msg,
|
||||||
@@ -961,7 +1248,10 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# Create a base list excluding system prompt to avoid duplication
|
# Create a base list excluding system prompt to avoid duplication
|
||||||
# This is the pool of messages we'll slice from in the loop
|
# This is the pool of messages we'll slice from in the loop
|
||||||
base_msgs = messages[1:] if has_system_prompt else messages
|
# Use messages_dict for type consistency with _ensure_tool_pairs_intact
|
||||||
|
base_msgs = (
|
||||||
|
messages_dict[1:] if has_system_prompt else messages_dict
|
||||||
|
)
|
||||||
|
|
||||||
# Try progressively smaller keep counts
|
# Try progressively smaller keep counts
|
||||||
new_token_count = token_count # Initialize with current count
|
new_token_count = token_count # Initialize with current count
|
||||||
@@ -984,6 +1274,12 @@ async def _stream_chat_chunks(
|
|||||||
# Slice from base_msgs to get recent messages (without system prompt)
|
# Slice from base_msgs to get recent messages (without system prompt)
|
||||||
recent_messages = base_msgs[-keep_count:]
|
recent_messages = base_msgs[-keep_count:]
|
||||||
|
|
||||||
|
# Ensure tool pairs stay intact in the reduced slice
|
||||||
|
reduced_slice_start = max(0, len(base_msgs) - keep_count)
|
||||||
|
recent_messages = _ensure_tool_pairs_intact(
|
||||||
|
recent_messages, base_msgs, reduced_slice_start
|
||||||
|
)
|
||||||
|
|
||||||
if has_system_prompt:
|
if has_system_prompt:
|
||||||
messages = [system_msg] + recent_messages
|
messages = [system_msg] + recent_messages
|
||||||
else:
|
else:
|
||||||
@@ -1260,17 +1556,19 @@ async def _yield_tool_call(
|
|||||||
"""
|
"""
|
||||||
Yield a tool call and its execution result.
|
Yield a tool call and its execution result.
|
||||||
|
|
||||||
For long-running tools, yields heartbeat events every 15 seconds to keep
|
For tools marked with `is_long_running=True` (like agent generation), spawns a
|
||||||
the SSE connection alive through proxies and load balancers.
|
background task so the operation survives SSE disconnections. For other tools,
|
||||||
|
yields heartbeat events every 15 seconds to keep the SSE connection alive.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
|
orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
|
||||||
KeyError: If expected tool call fields are missing
|
KeyError: If expected tool call fields are missing
|
||||||
TypeError: If tool call structure is invalid
|
TypeError: If tool call structure is invalid
|
||||||
"""
|
"""
|
||||||
|
import uuid as uuid_module
|
||||||
|
|
||||||
tool_name = tool_calls[yield_idx]["function"]["name"]
|
tool_name = tool_calls[yield_idx]["function"]["name"]
|
||||||
tool_call_id = tool_calls[yield_idx]["id"]
|
tool_call_id = tool_calls[yield_idx]["id"]
|
||||||
logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
|
|
||||||
|
|
||||||
# Parse tool call arguments - handle empty arguments gracefully
|
# Parse tool call arguments - handle empty arguments gracefully
|
||||||
raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
|
raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
|
||||||
@@ -1285,7 +1583,151 @@ async def _yield_tool_call(
|
|||||||
input=arguments,
|
input=arguments,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run tool execution in background task with heartbeats to keep connection alive
|
# Check if this tool is long-running (survives SSE disconnection)
|
||||||
|
tool = get_tool(tool_name)
|
||||||
|
if tool and tool.is_long_running:
|
||||||
|
# Atomic check-and-set: returns False if operation already running (lost race)
|
||||||
|
if not await _mark_operation_started(tool_call_id):
|
||||||
|
logger.info(
|
||||||
|
f"Tool call {tool_call_id} already in progress, returning status"
|
||||||
|
)
|
||||||
|
# Build dynamic message based on tool name
|
||||||
|
if tool_name == "create_agent":
|
||||||
|
in_progress_msg = "Agent creation already in progress. Please wait..."
|
||||||
|
elif tool_name == "edit_agent":
|
||||||
|
in_progress_msg = "Agent edit already in progress. Please wait..."
|
||||||
|
else:
|
||||||
|
in_progress_msg = f"{tool_name} already in progress. Please wait..."
|
||||||
|
|
||||||
|
yield StreamToolOutputAvailable(
|
||||||
|
toolCallId=tool_call_id,
|
||||||
|
toolName=tool_name,
|
||||||
|
output=OperationInProgressResponse(
|
||||||
|
message=in_progress_msg,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
).model_dump_json(),
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Generate operation ID
|
||||||
|
operation_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
|
# Build a user-friendly message based on tool and arguments
|
||||||
|
if tool_name == "create_agent":
|
||||||
|
agent_desc = arguments.get("description", "")
|
||||||
|
# Truncate long descriptions for the message
|
||||||
|
desc_preview = (
|
||||||
|
(agent_desc[:100] + "...") if len(agent_desc) > 100 else agent_desc
|
||||||
|
)
|
||||||
|
pending_msg = (
|
||||||
|
f"Creating your agent: {desc_preview}"
|
||||||
|
if desc_preview
|
||||||
|
else "Creating agent... This may take a few minutes."
|
||||||
|
)
|
||||||
|
started_msg = (
|
||||||
|
"Agent creation started. You can close this tab - "
|
||||||
|
"check your library in a few minutes."
|
||||||
|
)
|
||||||
|
elif tool_name == "edit_agent":
|
||||||
|
changes = arguments.get("changes", "")
|
||||||
|
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
|
||||||
|
pending_msg = (
|
||||||
|
f"Editing agent: {changes_preview}"
|
||||||
|
if changes_preview
|
||||||
|
else "Editing agent... This may take a few minutes."
|
||||||
|
)
|
||||||
|
started_msg = (
|
||||||
|
"Agent edit started. You can close this tab - "
|
||||||
|
"check your library in a few minutes."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pending_msg = f"Running {tool_name}... This may take a few minutes."
|
||||||
|
started_msg = (
|
||||||
|
f"{tool_name} started. You can close this tab - "
|
||||||
|
"check back in a few minutes."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track appended messages for rollback on failure
|
||||||
|
assistant_message: ChatMessage | None = None
|
||||||
|
pending_message: ChatMessage | None = None
|
||||||
|
|
||||||
|
# Wrap session save and task creation in try-except to release lock on failure
|
||||||
|
try:
|
||||||
|
# Save assistant message with tool_call FIRST (required by LLM)
|
||||||
|
assistant_message = ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
tool_calls=[tool_calls[yield_idx]],
|
||||||
|
)
|
||||||
|
session.messages.append(assistant_message)
|
||||||
|
|
||||||
|
# Then save pending tool result
|
||||||
|
pending_message = ChatMessage(
|
||||||
|
role="tool",
|
||||||
|
content=OperationPendingResponse(
|
||||||
|
message=pending_msg,
|
||||||
|
operation_id=operation_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
).model_dump_json(),
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
)
|
||||||
|
session.messages.append(pending_message)
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
logger.info(
|
||||||
|
f"Saved pending operation {operation_id} for tool {tool_name} "
|
||||||
|
f"in session {session.session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store task reference in module-level set to prevent GC before completion
|
||||||
|
task = asyncio.create_task(
|
||||||
|
_execute_long_running_tool(
|
||||||
|
tool_name=tool_name,
|
||||||
|
parameters=arguments,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
operation_id=operation_id,
|
||||||
|
session_id=session.session_id,
|
||||||
|
user_id=session.user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
except Exception as e:
|
||||||
|
# Roll back appended messages to prevent data corruption on subsequent saves
|
||||||
|
if (
|
||||||
|
pending_message
|
||||||
|
and session.messages
|
||||||
|
and session.messages[-1] == pending_message
|
||||||
|
):
|
||||||
|
session.messages.pop()
|
||||||
|
if (
|
||||||
|
assistant_message
|
||||||
|
and session.messages
|
||||||
|
and session.messages[-1] == assistant_message
|
||||||
|
):
|
||||||
|
session.messages.pop()
|
||||||
|
|
||||||
|
# Release the Redis lock since the background task won't be spawned
|
||||||
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
logger.error(
|
||||||
|
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Return immediately - don't wait for completion
|
||||||
|
yield StreamToolOutputAvailable(
|
||||||
|
toolCallId=tool_call_id,
|
||||||
|
toolName=tool_name,
|
||||||
|
output=OperationStartedResponse(
|
||||||
|
message=started_msg,
|
||||||
|
operation_id=operation_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
).model_dump_json(),
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Normal flow: Run tool execution in background task with heartbeats
|
||||||
tool_task = asyncio.create_task(
|
tool_task = asyncio.create_task(
|
||||||
execute_tool(
|
execute_tool(
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
@@ -1335,3 +1777,190 @@ async def _yield_tool_call(
|
|||||||
)
|
)
|
||||||
|
|
||||||
yield tool_execution_response
|
yield tool_execution_response
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_long_running_tool(
|
||||||
|
tool_name: str,
|
||||||
|
parameters: dict[str, Any],
|
||||||
|
tool_call_id: str,
|
||||||
|
operation_id: str,
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a long-running tool in background and update chat history with result.
|
||||||
|
|
||||||
|
This function runs independently of the SSE connection, so the operation
|
||||||
|
survives if the user closes their browser tab.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load fresh session (not stale reference)
|
||||||
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
if not session:
|
||||||
|
logger.error(f"Session {session_id} not found for background tool")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Execute the actual tool
|
||||||
|
result = await execute_tool(
|
||||||
|
tool_name=tool_name,
|
||||||
|
parameters=parameters,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
user_id=user_id,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the pending message with result
|
||||||
|
await _update_pending_operation(
|
||||||
|
session_id=session_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
result=(
|
||||||
|
result.output
|
||||||
|
if isinstance(result.output, str)
|
||||||
|
else orjson.dumps(result.output).decode("utf-8")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Background tool {tool_name} completed for session {session_id}")
|
||||||
|
|
||||||
|
# Generate LLM continuation so user sees response when they poll/refresh
|
||||||
|
await _generate_llm_continuation(session_id=session_id, user_id=user_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message=f"Tool {tool_name} failed: {str(e)}",
|
||||||
|
)
|
||||||
|
await _update_pending_operation(
|
||||||
|
session_id=session_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
result=error_response.model_dump_json(),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_pending_operation(
|
||||||
|
session_id: str,
|
||||||
|
tool_call_id: str,
|
||||||
|
result: str,
|
||||||
|
) -> None:
|
||||||
|
"""Update the pending tool message with final result.
|
||||||
|
|
||||||
|
This is called by background tasks when long-running operations complete.
|
||||||
|
"""
|
||||||
|
# Update the message in database
|
||||||
|
updated = await chat_db.update_tool_message_content(
|
||||||
|
session_id=session_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
new_content=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
if updated:
|
||||||
|
# Invalidate Redis cache so next load gets fresh data
|
||||||
|
# Wrap in try/except to prevent cache failures from triggering error handling
|
||||||
|
# that would overwrite our successful DB update
|
||||||
|
try:
|
||||||
|
await invalidate_session_cache(session_id)
|
||||||
|
except Exception as e:
|
||||||
|
# Non-critical: cache will eventually be refreshed on next load
|
||||||
|
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||||
|
logger.info(
|
||||||
|
f"Updated pending operation for tool_call_id {tool_call_id} "
|
||||||
|
f"in session {session_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to update pending operation for tool_call_id {tool_call_id} "
|
||||||
|
f"in session {session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_llm_continuation(
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Generate an LLM response after a long-running tool completes.
|
||||||
|
|
||||||
|
This is called by background tasks to continue the conversation
|
||||||
|
after a tool result is saved. The response is saved to the database
|
||||||
|
so users see it when they refresh or poll.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load fresh session from DB (bypass cache to get the updated tool result)
|
||||||
|
await invalidate_session_cache(session_id)
|
||||||
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
if not session:
|
||||||
|
logger.error(f"Session {session_id} not found for LLM continuation")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build system prompt
|
||||||
|
system_prompt, _ = await _build_system_prompt(user_id)
|
||||||
|
|
||||||
|
# Build messages in OpenAI format
|
||||||
|
messages = session.to_openai_messages()
|
||||||
|
if system_prompt:
|
||||||
|
from openai.types.chat import ChatCompletionSystemMessageParam
|
||||||
|
|
||||||
|
system_message = ChatCompletionSystemMessageParam(
|
||||||
|
role="system",
|
||||||
|
content=system_prompt,
|
||||||
|
)
|
||||||
|
messages = [system_message] + messages
|
||||||
|
|
||||||
|
# Build extra_body for tracing
|
||||||
|
extra_body: dict[str, Any] = {
|
||||||
|
"posthogProperties": {
|
||||||
|
"environment": settings.config.app_env.value,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if user_id:
|
||||||
|
extra_body["user"] = user_id[:128]
|
||||||
|
extra_body["posthogDistinctId"] = user_id
|
||||||
|
if session_id:
|
||||||
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
|
# Make non-streaming LLM call (no tools - just text response)
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
|
||||||
|
# No tools parameter = text-only response (no tool calls)
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=config.model,
|
||||||
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
|
extra_body=extra_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.choices and response.choices[0].message.content:
|
||||||
|
assistant_content = response.choices[0].message.content
|
||||||
|
|
||||||
|
# Reload session from DB to avoid race condition with user messages
|
||||||
|
# that may have been sent while we were generating the LLM response
|
||||||
|
fresh_session = await get_chat_session(session_id, user_id)
|
||||||
|
if not fresh_session:
|
||||||
|
logger.error(
|
||||||
|
f"Session {session_id} disappeared during LLM continuation"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save assistant message to database
|
||||||
|
assistant_message = ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=assistant_content,
|
||||||
|
)
|
||||||
|
fresh_session.messages.append(assistant_message)
|
||||||
|
|
||||||
|
# Save to database (not cache) to persist the response
|
||||||
|
await upsert_chat_session(fresh_session)
|
||||||
|
|
||||||
|
# Invalidate cache so next poll/refresh gets fresh data
|
||||||
|
await invalidate_session_cache(session_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Generated LLM continuation for session {session_id}, "
|
||||||
|
f"response length: {len(assistant_content)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"LLM continuation returned empty response for {session_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -49,6 +49,11 @@ tools: list[ChatCompletionToolParam] = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool(tool_name: str) -> BaseTool | None:
|
||||||
|
"""Get a tool instance by name."""
|
||||||
|
return TOOL_REGISTRY.get(tool_name)
|
||||||
|
|
||||||
|
|
||||||
async def execute_tool(
|
async def execute_tool(
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
parameters: dict[str, Any],
|
parameters: dict[str, Any],
|
||||||
@@ -57,7 +62,7 @@ async def execute_tool(
|
|||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
) -> "StreamToolOutputAvailable":
|
) -> "StreamToolOutputAvailable":
|
||||||
"""Execute a tool by name."""
|
"""Execute a tool by name."""
|
||||||
tool = TOOL_REGISTRY.get(tool_name)
|
tool = get_tool(tool_name)
|
||||||
if not tool:
|
if not tool:
|
||||||
raise ValueError(f"Tool {tool_name} not found")
|
raise ValueError(f"Tool {tool_name} not found")
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,16 @@ class BaseTool:
|
|||||||
"""Whether this tool requires authentication."""
|
"""Whether this tool requires authentication."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
"""Whether this tool is long-running and should execute in background.
|
||||||
|
|
||||||
|
Long-running tools (like agent generation) are executed via background
|
||||||
|
tasks to survive SSE disconnections. The result is persisted to chat
|
||||||
|
history and visible when the user refreshes.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||||
"""Convert to OpenAI tool format."""
|
"""Convert to OpenAI tool format."""
|
||||||
return ChatCompletionToolParam(
|
return ChatCompletionToolParam(
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ class CreateAgentTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ class EditAgentTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ class ResponseType(str, Enum):
|
|||||||
BLOCK_OUTPUT = "block_output"
|
BLOCK_OUTPUT = "block_output"
|
||||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||||
DOC_PAGE = "doc_page"
|
DOC_PAGE = "doc_page"
|
||||||
|
# Long-running operation types
|
||||||
|
OPERATION_STARTED = "operation_started"
|
||||||
|
OPERATION_PENDING = "operation_pending"
|
||||||
|
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||||
|
|
||||||
|
|
||||||
# Base response model
|
# Base response model
|
||||||
@@ -334,3 +338,39 @@ class BlockOutputResponse(ToolResponseBase):
|
|||||||
block_name: str
|
block_name: str
|
||||||
outputs: dict[str, list[Any]]
|
outputs: dict[str, list[Any]]
|
||||||
success: bool = True
|
success: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# Long-running operation models
|
||||||
|
class OperationStartedResponse(ToolResponseBase):
|
||||||
|
"""Response when a long-running operation has been started in the background.
|
||||||
|
|
||||||
|
This is returned immediately to the client while the operation continues
|
||||||
|
to execute. The user can close the tab and check back later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||||
|
operation_id: str
|
||||||
|
tool_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class OperationPendingResponse(ToolResponseBase):
|
||||||
|
"""Response stored in chat history while a long-running operation is executing.
|
||||||
|
|
||||||
|
This is persisted to the database so users see a pending state when they
|
||||||
|
refresh before the operation completes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_PENDING
|
||||||
|
operation_id: str
|
||||||
|
tool_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class OperationInProgressResponse(ToolResponseBase):
|
||||||
|
"""Response when an operation is already in progress.
|
||||||
|
|
||||||
|
Returned for idempotency when the same tool_call_id is requested again
|
||||||
|
while the background task is still running.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||||
|
tool_call_id: str
|
||||||
|
|||||||
@@ -359,8 +359,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
description="The port for the Agent Generator service",
|
description="The port for the Agent Generator service",
|
||||||
)
|
)
|
||||||
agentgenerator_timeout: int = Field(
|
agentgenerator_timeout: int = Field(
|
||||||
default=120,
|
default=600,
|
||||||
description="The timeout in seconds for Agent Generator service requests",
|
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
|
||||||
)
|
)
|
||||||
|
|
||||||
enable_example_blocks: bool = Field(
|
enable_example_blocks: bool = Field(
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
||||||
import type { ReactNode } from "react";
|
import type { ReactNode } from "react";
|
||||||
import { useEffect } from "react";
|
|
||||||
import { useCopilotStore } from "../../copilot-page-store";
|
|
||||||
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
|
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
|
||||||
import { LoadingState } from "./components/LoadingState/LoadingState";
|
|
||||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
||||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
||||||
import { useCopilotShell } from "./useCopilotShell";
|
import { useCopilotShell } from "./useCopilotShell";
|
||||||
@@ -20,38 +18,21 @@ export function CopilotShell({ children }: Props) {
|
|||||||
isMobile,
|
isMobile,
|
||||||
isDrawerOpen,
|
isDrawerOpen,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
isCreatingSession,
|
||||||
isLoggedIn,
|
isLoggedIn,
|
||||||
hasActiveSession,
|
hasActiveSession,
|
||||||
sessions,
|
sessions,
|
||||||
currentSessionId,
|
currentSessionId,
|
||||||
handleSelectSession,
|
|
||||||
handleOpenDrawer,
|
handleOpenDrawer,
|
||||||
handleCloseDrawer,
|
handleCloseDrawer,
|
||||||
handleDrawerOpenChange,
|
handleDrawerOpenChange,
|
||||||
handleNewChat,
|
handleNewChatClick,
|
||||||
|
handleSessionClick,
|
||||||
hasNextPage,
|
hasNextPage,
|
||||||
isFetchingNextPage,
|
isFetchingNextPage,
|
||||||
fetchNextPage,
|
fetchNextPage,
|
||||||
isReadyToShowContent,
|
|
||||||
} = useCopilotShell();
|
} = useCopilotShell();
|
||||||
|
|
||||||
const setNewChatHandler = useCopilotStore((s) => s.setNewChatHandler);
|
|
||||||
const requestNewChat = useCopilotStore((s) => s.requestNewChat);
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function registerNewChatHandler() {
|
|
||||||
setNewChatHandler(handleNewChat);
|
|
||||||
return function cleanup() {
|
|
||||||
setNewChatHandler(null);
|
|
||||||
};
|
|
||||||
},
|
|
||||||
[handleNewChat],
|
|
||||||
);
|
|
||||||
|
|
||||||
function handleNewChatClick() {
|
|
||||||
requestNewChat();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isLoggedIn) {
|
if (!isLoggedIn) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full items-center justify-center">
|
<div className="flex h-full items-center justify-center">
|
||||||
@@ -72,7 +53,7 @@ export function CopilotShell({ children }: Props) {
|
|||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
hasNextPage={hasNextPage}
|
hasNextPage={hasNextPage}
|
||||||
isFetchingNextPage={isFetchingNextPage}
|
isFetchingNextPage={isFetchingNextPage}
|
||||||
onSelectSession={handleSelectSession}
|
onSelectSession={handleSessionClick}
|
||||||
onFetchNextPage={fetchNextPage}
|
onFetchNextPage={fetchNextPage}
|
||||||
onNewChat={handleNewChatClick}
|
onNewChat={handleNewChatClick}
|
||||||
hasActiveSession={Boolean(hasActiveSession)}
|
hasActiveSession={Boolean(hasActiveSession)}
|
||||||
@@ -82,7 +63,18 @@ export function CopilotShell({ children }: Props) {
|
|||||||
<div className="relative flex min-h-0 flex-1 flex-col">
|
<div className="relative flex min-h-0 flex-1 flex-col">
|
||||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||||
<div className="flex min-h-0 flex-1 flex-col">
|
<div className="flex min-h-0 flex-1 flex-col">
|
||||||
{isReadyToShowContent ? children : <LoadingState />}
|
{isCreatingSession ? (
|
||||||
|
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
|
||||||
|
<div className="flex flex-col items-center gap-4">
|
||||||
|
<ChatLoader />
|
||||||
|
<Text variant="body" className="text-zinc-500">
|
||||||
|
Creating your chat...
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
children
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -94,7 +86,7 @@ export function CopilotShell({ children }: Props) {
|
|||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
hasNextPage={hasNextPage}
|
hasNextPage={hasNextPage}
|
||||||
isFetchingNextPage={isFetchingNextPage}
|
isFetchingNextPage={isFetchingNextPage}
|
||||||
onSelectSession={handleSelectSession}
|
onSelectSession={handleSessionClick}
|
||||||
onFetchNextPage={fetchNextPage}
|
onFetchNextPage={fetchNextPage}
|
||||||
onNewChat={handleNewChatClick}
|
onNewChat={handleNewChatClick}
|
||||||
onClose={handleCloseDrawer}
|
onClose={handleCloseDrawer}
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
|
||||||
|
|
||||||
export function LoadingState() {
|
|
||||||
return (
|
|
||||||
<div className="flex flex-1 items-center justify-center">
|
|
||||||
<div className="flex flex-col items-center gap-4">
|
|
||||||
<ChatLoader />
|
|
||||||
<Text variant="body" className="text-zinc-500">
|
|
||||||
Loading your chats...
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -3,17 +3,17 @@ import { useState } from "react";
|
|||||||
export function useMobileDrawer() {
|
export function useMobileDrawer() {
|
||||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
||||||
|
|
||||||
function handleOpenDrawer() {
|
const handleOpenDrawer = () => {
|
||||||
setIsDrawerOpen(true);
|
setIsDrawerOpen(true);
|
||||||
}
|
};
|
||||||
|
|
||||||
function handleCloseDrawer() {
|
const handleCloseDrawer = () => {
|
||||||
setIsDrawerOpen(false);
|
setIsDrawerOpen(false);
|
||||||
}
|
};
|
||||||
|
|
||||||
function handleDrawerOpenChange(open: boolean) {
|
const handleDrawerOpenChange = (open: boolean) => {
|
||||||
setIsDrawerOpen(open);
|
setIsDrawerOpen(open);
|
||||||
}
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isDrawerOpen,
|
isDrawerOpen,
|
||||||
|
|||||||
@@ -1,11 +1,6 @@
|
|||||||
import {
|
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
getGetV2ListSessionsQueryKey,
|
|
||||||
useGetV2ListSessions,
|
|
||||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
|
||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||||
import { okData } from "@/app/api/helpers";
|
import { okData } from "@/app/api/helpers";
|
||||||
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
|
|
||||||
const PAGE_SIZE = 50;
|
const PAGE_SIZE = 50;
|
||||||
@@ -16,12 +11,12 @@ export interface UseSessionsPaginationArgs {
|
|||||||
|
|
||||||
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
||||||
const [offset, setOffset] = useState(0);
|
const [offset, setOffset] = useState(0);
|
||||||
|
|
||||||
const [accumulatedSessions, setAccumulatedSessions] = useState<
|
const [accumulatedSessions, setAccumulatedSessions] = useState<
|
||||||
SessionSummaryResponse[]
|
SessionSummaryResponse[]
|
||||||
>([]);
|
>([]);
|
||||||
|
|
||||||
const [totalCount, setTotalCount] = useState<number | null>(null);
|
const [totalCount, setTotalCount] = useState<number | null>(null);
|
||||||
const queryClient = useQueryClient();
|
|
||||||
const onStreamComplete = useChatStore((state) => state.onStreamComplete);
|
|
||||||
|
|
||||||
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
|
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
|
||||||
{ limit: PAGE_SIZE, offset },
|
{ limit: PAGE_SIZE, offset },
|
||||||
@@ -32,38 +27,23 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
useEffect(function refreshOnStreamComplete() {
|
useEffect(() => {
|
||||||
const unsubscribe = onStreamComplete(function handleStreamComplete() {
|
const responseData = okData(data);
|
||||||
setOffset(0);
|
if (responseData) {
|
||||||
|
const newSessions = responseData.sessions;
|
||||||
|
const total = responseData.total;
|
||||||
|
setTotalCount(total);
|
||||||
|
|
||||||
|
if (offset === 0) {
|
||||||
|
setAccumulatedSessions(newSessions);
|
||||||
|
} else {
|
||||||
|
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
|
||||||
|
}
|
||||||
|
} else if (!enabled) {
|
||||||
setAccumulatedSessions([]);
|
setAccumulatedSessions([]);
|
||||||
setTotalCount(null);
|
setTotalCount(null);
|
||||||
queryClient.invalidateQueries({
|
}
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
}, [data, offset, enabled]);
|
||||||
});
|
|
||||||
});
|
|
||||||
return unsubscribe;
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function updateSessionsFromResponse() {
|
|
||||||
const responseData = okData(data);
|
|
||||||
if (responseData) {
|
|
||||||
const newSessions = responseData.sessions;
|
|
||||||
const total = responseData.total;
|
|
||||||
setTotalCount(total);
|
|
||||||
|
|
||||||
if (offset === 0) {
|
|
||||||
setAccumulatedSessions(newSessions);
|
|
||||||
} else {
|
|
||||||
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
|
|
||||||
}
|
|
||||||
} else if (!enabled) {
|
|
||||||
setAccumulatedSessions([]);
|
|
||||||
setTotalCount(null);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[data, offset, enabled],
|
|
||||||
);
|
|
||||||
|
|
||||||
const hasNextPage =
|
const hasNextPage =
|
||||||
totalCount !== null && accumulatedSessions.length < totalCount;
|
totalCount !== null && accumulatedSessions.length < totalCount;
|
||||||
@@ -86,17 +66,17 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
|||||||
}
|
}
|
||||||
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
|
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
|
||||||
|
|
||||||
function fetchNextPage() {
|
const fetchNextPage = () => {
|
||||||
if (hasNextPage && !isFetching) {
|
if (hasNextPage && !isFetching) {
|
||||||
setOffset((prev) => prev + PAGE_SIZE);
|
setOffset((prev) => prev + PAGE_SIZE);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
function reset() {
|
const reset = () => {
|
||||||
setOffset(0);
|
setOffset(0);
|
||||||
setAccumulatedSessions([]);
|
setAccumulatedSessions([]);
|
||||||
setTotalCount(null);
|
setTotalCount(null);
|
||||||
}
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
sessions: accumulatedSessions,
|
sessions: accumulatedSessions,
|
||||||
|
|||||||
@@ -104,76 +104,3 @@ export function mergeCurrentSessionIntoList(
|
|||||||
export function getCurrentSessionId(searchParams: URLSearchParams) {
|
export function getCurrentSessionId(searchParams: URLSearchParams) {
|
||||||
return searchParams.get("sessionId");
|
return searchParams.get("sessionId");
|
||||||
}
|
}
|
||||||
|
|
||||||
export function shouldAutoSelectSession(
|
|
||||||
areAllSessionsLoaded: boolean,
|
|
||||||
hasAutoSelectedSession: boolean,
|
|
||||||
paramSessionId: string | null,
|
|
||||||
visibleSessions: SessionSummaryResponse[],
|
|
||||||
accumulatedSessions: SessionSummaryResponse[],
|
|
||||||
isLoading: boolean,
|
|
||||||
totalCount: number | null,
|
|
||||||
) {
|
|
||||||
if (!areAllSessionsLoaded || hasAutoSelectedSession) {
|
|
||||||
return {
|
|
||||||
shouldSelect: false,
|
|
||||||
sessionIdToSelect: null,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (paramSessionId) {
|
|
||||||
return {
|
|
||||||
shouldSelect: false,
|
|
||||||
sessionIdToSelect: null,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (visibleSessions.length > 0) {
|
|
||||||
return {
|
|
||||||
shouldSelect: true,
|
|
||||||
sessionIdToSelect: visibleSessions[0].id,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
|
|
||||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (totalCount === 0) {
|
|
||||||
return {
|
|
||||||
shouldSelect: false,
|
|
||||||
sessionIdToSelect: null,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
|
|
||||||
}
|
|
||||||
|
|
||||||
export function checkReadyToShowContent(
|
|
||||||
areAllSessionsLoaded: boolean,
|
|
||||||
paramSessionId: string | null,
|
|
||||||
accumulatedSessions: SessionSummaryResponse[],
|
|
||||||
isCurrentSessionLoading: boolean,
|
|
||||||
currentSessionData: SessionDetailResponse | null | undefined,
|
|
||||||
hasAutoSelectedSession: boolean,
|
|
||||||
) {
|
|
||||||
if (!areAllSessionsLoaded) return false;
|
|
||||||
|
|
||||||
if (paramSessionId) {
|
|
||||||
const sessionFound = accumulatedSessions.some(
|
|
||||||
(s) => s.id === paramSessionId,
|
|
||||||
);
|
|
||||||
return (
|
|
||||||
sessionFound ||
|
|
||||||
(!isCurrentSessionLoading &&
|
|
||||||
currentSessionData !== undefined &&
|
|
||||||
currentSessionData !== null)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return hasAutoSelectedSession;
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,26 +1,22 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
|
getGetV2GetSessionQueryKey,
|
||||||
getGetV2ListSessionsQueryKey,
|
getGetV2ListSessionsQueryKey,
|
||||||
useGetV2GetSession,
|
useGetV2GetSession,
|
||||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { okData } from "@/app/api/helpers";
|
import { okData } from "@/app/api/helpers";
|
||||||
|
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
||||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { parseAsString, useQueryState } from "nuqs";
|
|
||||||
import { usePathname, useSearchParams } from "next/navigation";
|
import { usePathname, useSearchParams } from "next/navigation";
|
||||||
import { useEffect, useRef, useState } from "react";
|
import { useRef } from "react";
|
||||||
|
import { useCopilotStore } from "../../copilot-page-store";
|
||||||
|
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
||||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||||
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
import { getCurrentSessionId } from "./helpers";
|
||||||
import {
|
import { useShellSessionList } from "./useShellSessionList";
|
||||||
checkReadyToShowContent,
|
|
||||||
convertSessionDetailToSummary,
|
|
||||||
filterVisibleSessions,
|
|
||||||
getCurrentSessionId,
|
|
||||||
mergeCurrentSessionIntoList,
|
|
||||||
} from "./helpers";
|
|
||||||
|
|
||||||
export function useCopilotShell() {
|
export function useCopilotShell() {
|
||||||
const pathname = usePathname();
|
const pathname = usePathname();
|
||||||
@@ -31,7 +27,7 @@ export function useCopilotShell() {
|
|||||||
const isMobile =
|
const isMobile =
|
||||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||||
|
|
||||||
const [, setUrlSessionId] = useQueryState("sessionId", parseAsString);
|
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
|
||||||
|
|
||||||
const isOnHomepage = pathname === "/copilot";
|
const isOnHomepage = pathname === "/copilot";
|
||||||
const paramSessionId = searchParams.get("sessionId");
|
const paramSessionId = searchParams.get("sessionId");
|
||||||
@@ -45,123 +41,80 @@ export function useCopilotShell() {
|
|||||||
|
|
||||||
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
|
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
|
||||||
|
|
||||||
const {
|
|
||||||
sessions: accumulatedSessions,
|
|
||||||
isLoading: isSessionsLoading,
|
|
||||||
isFetching: isSessionsFetching,
|
|
||||||
hasNextPage,
|
|
||||||
areAllSessionsLoaded,
|
|
||||||
fetchNextPage,
|
|
||||||
reset: resetPagination,
|
|
||||||
} = useSessionsPagination({
|
|
||||||
enabled: paginationEnabled,
|
|
||||||
});
|
|
||||||
|
|
||||||
const currentSessionId = getCurrentSessionId(searchParams);
|
const currentSessionId = getCurrentSessionId(searchParams);
|
||||||
|
|
||||||
const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
|
const { data: currentSessionData } = useGetV2GetSession(
|
||||||
useGetV2GetSession(currentSessionId || "", {
|
currentSessionId || "",
|
||||||
|
{
|
||||||
query: {
|
query: {
|
||||||
enabled: !!currentSessionId,
|
enabled: !!currentSessionId,
|
||||||
select: okData,
|
select: okData,
|
||||||
},
|
},
|
||||||
});
|
},
|
||||||
|
|
||||||
const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
|
|
||||||
const hasAutoSelectedRef = useRef(false);
|
|
||||||
const recentlyCreatedSessionsRef = useRef<
|
|
||||||
Map<string, SessionSummaryResponse>
|
|
||||||
>(new Map());
|
|
||||||
|
|
||||||
// Mark as auto-selected when sessionId is in URL
|
|
||||||
useEffect(() => {
|
|
||||||
if (paramSessionId && !hasAutoSelectedRef.current) {
|
|
||||||
hasAutoSelectedRef.current = true;
|
|
||||||
setHasAutoSelectedSession(true);
|
|
||||||
}
|
|
||||||
}, [paramSessionId]);
|
|
||||||
|
|
||||||
// On homepage without sessionId, mark as ready immediately
|
|
||||||
useEffect(() => {
|
|
||||||
if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
|
|
||||||
hasAutoSelectedRef.current = true;
|
|
||||||
setHasAutoSelectedSession(true);
|
|
||||||
}
|
|
||||||
}, [isOnHomepage, paramSessionId]);
|
|
||||||
|
|
||||||
// Invalidate sessions list when navigating to homepage (to show newly created sessions)
|
|
||||||
useEffect(() => {
|
|
||||||
if (isOnHomepage && !paramSessionId) {
|
|
||||||
queryClient.invalidateQueries({
|
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [isOnHomepage, paramSessionId, queryClient]);
|
|
||||||
|
|
||||||
// Track newly created sessions to ensure they stay visible even when switching away
|
|
||||||
useEffect(() => {
|
|
||||||
if (currentSessionId && currentSessionData) {
|
|
||||||
const isNewSession =
|
|
||||||
currentSessionData.updated_at === currentSessionData.created_at;
|
|
||||||
const isNotInAccumulated = !accumulatedSessions.some(
|
|
||||||
(s) => s.id === currentSessionId,
|
|
||||||
);
|
|
||||||
if (isNewSession || isNotInAccumulated) {
|
|
||||||
const summary = convertSessionDetailToSummary(currentSessionData);
|
|
||||||
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, [currentSessionId, currentSessionData, accumulatedSessions]);
|
|
||||||
|
|
||||||
// Clean up recently created sessions that are now in the accumulated list
|
|
||||||
useEffect(() => {
|
|
||||||
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
|
|
||||||
if (accumulatedSessions.some((s) => s.id === sessionId)) {
|
|
||||||
recentlyCreatedSessionsRef.current.delete(sessionId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, [accumulatedSessions]);
|
|
||||||
|
|
||||||
// Reset pagination when query becomes disabled
|
|
||||||
const prevPaginationEnabledRef = useRef(paginationEnabled);
|
|
||||||
useEffect(() => {
|
|
||||||
if (prevPaginationEnabledRef.current && !paginationEnabled) {
|
|
||||||
resetPagination();
|
|
||||||
resetAutoSelect();
|
|
||||||
}
|
|
||||||
prevPaginationEnabledRef.current = paginationEnabled;
|
|
||||||
}, [paginationEnabled, resetPagination]);
|
|
||||||
|
|
||||||
const sessions = mergeCurrentSessionIntoList(
|
|
||||||
accumulatedSessions,
|
|
||||||
currentSessionId,
|
|
||||||
currentSessionData,
|
|
||||||
recentlyCreatedSessionsRef.current,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const visibleSessions = filterVisibleSessions(sessions);
|
const {
|
||||||
|
sessions,
|
||||||
|
isLoading,
|
||||||
|
isSessionsFetching,
|
||||||
|
hasNextPage,
|
||||||
|
fetchNextPage,
|
||||||
|
resetPagination,
|
||||||
|
recentlyCreatedSessionsRef,
|
||||||
|
} = useShellSessionList({
|
||||||
|
paginationEnabled,
|
||||||
|
currentSessionId,
|
||||||
|
currentSessionData,
|
||||||
|
isOnHomepage,
|
||||||
|
paramSessionId,
|
||||||
|
});
|
||||||
|
|
||||||
const sidebarSelectedSessionId =
|
const stopStream = useChatStore((s) => s.stopStream);
|
||||||
isOnHomepage && !paramSessionId ? null : currentSessionId;
|
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||||
|
const isStreaming = useCopilotStore((s) => s.isStreaming);
|
||||||
|
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
|
||||||
|
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
|
||||||
|
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
|
||||||
|
|
||||||
const isReadyToShowContent = isOnHomepage
|
const pendingActionRef = useRef<(() => void) | null>(null);
|
||||||
? true
|
|
||||||
: checkReadyToShowContent(
|
|
||||||
areAllSessionsLoaded,
|
|
||||||
paramSessionId,
|
|
||||||
accumulatedSessions,
|
|
||||||
isCurrentSessionLoading,
|
|
||||||
currentSessionData,
|
|
||||||
hasAutoSelectedSession,
|
|
||||||
);
|
|
||||||
|
|
||||||
function handleSelectSession(sessionId: string) {
|
async function stopCurrentStream() {
|
||||||
|
if (!currentSessionId) return;
|
||||||
|
|
||||||
|
setIsSwitchingSession(true);
|
||||||
|
await new Promise<void>((resolve) => {
|
||||||
|
const unsubscribe = onStreamComplete((completedId) => {
|
||||||
|
if (completedId === currentSessionId) {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
unsubscribe();
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const timeout = setTimeout(() => {
|
||||||
|
unsubscribe();
|
||||||
|
resolve();
|
||||||
|
}, 3000);
|
||||||
|
stopStream(currentSessionId);
|
||||||
|
});
|
||||||
|
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
|
||||||
|
});
|
||||||
|
setIsSwitchingSession(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
function selectSession(sessionId: string) {
|
||||||
|
if (sessionId === currentSessionId) return;
|
||||||
|
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||||
|
});
|
||||||
|
}
|
||||||
setUrlSessionId(sessionId, { shallow: false });
|
setUrlSessionId(sessionId, { shallow: false });
|
||||||
if (isMobile) handleCloseDrawer();
|
if (isMobile) handleCloseDrawer();
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleNewChat() {
|
function startNewChat() {
|
||||||
resetAutoSelect();
|
|
||||||
resetPagination();
|
resetPagination();
|
||||||
queryClient.invalidateQueries({
|
queryClient.invalidateQueries({
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
@@ -170,12 +123,31 @@ export function useCopilotShell() {
|
|||||||
if (isMobile) handleCloseDrawer();
|
if (isMobile) handleCloseDrawer();
|
||||||
}
|
}
|
||||||
|
|
||||||
function resetAutoSelect() {
|
function handleSessionClick(sessionId: string) {
|
||||||
hasAutoSelectedRef.current = false;
|
if (sessionId === currentSessionId) return;
|
||||||
setHasAutoSelectedSession(false);
|
|
||||||
|
if (isStreaming) {
|
||||||
|
pendingActionRef.current = async () => {
|
||||||
|
await stopCurrentStream();
|
||||||
|
selectSession(sessionId);
|
||||||
|
};
|
||||||
|
openInterruptModal(pendingActionRef.current);
|
||||||
|
} else {
|
||||||
|
selectSession(sessionId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
|
function handleNewChatClick() {
|
||||||
|
if (isStreaming) {
|
||||||
|
pendingActionRef.current = async () => {
|
||||||
|
await stopCurrentStream();
|
||||||
|
startNewChat();
|
||||||
|
};
|
||||||
|
openInterruptModal(pendingActionRef.current);
|
||||||
|
} else {
|
||||||
|
startNewChat();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isMobile,
|
isMobile,
|
||||||
@@ -183,17 +155,17 @@ export function useCopilotShell() {
|
|||||||
isLoggedIn,
|
isLoggedIn,
|
||||||
hasActiveSession:
|
hasActiveSession:
|
||||||
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
|
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
|
||||||
isLoading,
|
isLoading: isLoading || isCreatingSession,
|
||||||
sessions: visibleSessions,
|
isCreatingSession,
|
||||||
currentSessionId: sidebarSelectedSessionId,
|
sessions,
|
||||||
handleSelectSession,
|
currentSessionId: urlSessionId,
|
||||||
handleOpenDrawer,
|
handleOpenDrawer,
|
||||||
handleCloseDrawer,
|
handleCloseDrawer,
|
||||||
handleDrawerOpenChange,
|
handleDrawerOpenChange,
|
||||||
handleNewChat,
|
handleNewChatClick,
|
||||||
|
handleSessionClick,
|
||||||
hasNextPage,
|
hasNextPage,
|
||||||
isFetchingNextPage: isSessionsFetching,
|
isFetchingNextPage: isSessionsFetching,
|
||||||
fetchNextPage,
|
fetchNextPage,
|
||||||
isReadyToShowContent,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,113 @@
|
|||||||
|
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
|
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||||
|
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||||
|
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
||||||
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
|
import { useEffect, useMemo, useRef } from "react";
|
||||||
|
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
||||||
|
import {
|
||||||
|
convertSessionDetailToSummary,
|
||||||
|
filterVisibleSessions,
|
||||||
|
mergeCurrentSessionIntoList,
|
||||||
|
} from "./helpers";
|
||||||
|
|
||||||
|
interface UseShellSessionListArgs {
|
||||||
|
paginationEnabled: boolean;
|
||||||
|
currentSessionId: string | null;
|
||||||
|
currentSessionData: SessionDetailResponse | null | undefined;
|
||||||
|
isOnHomepage: boolean;
|
||||||
|
paramSessionId: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useShellSessionList({
|
||||||
|
paginationEnabled,
|
||||||
|
currentSessionId,
|
||||||
|
currentSessionData,
|
||||||
|
isOnHomepage,
|
||||||
|
paramSessionId,
|
||||||
|
}: UseShellSessionListArgs) {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||||
|
|
||||||
|
const {
|
||||||
|
sessions: accumulatedSessions,
|
||||||
|
isLoading: isSessionsLoading,
|
||||||
|
isFetching: isSessionsFetching,
|
||||||
|
hasNextPage,
|
||||||
|
fetchNextPage,
|
||||||
|
reset: resetPagination,
|
||||||
|
} = useSessionsPagination({
|
||||||
|
enabled: paginationEnabled,
|
||||||
|
});
|
||||||
|
|
||||||
|
const recentlyCreatedSessionsRef = useRef<
|
||||||
|
Map<string, SessionSummaryResponse>
|
||||||
|
>(new Map());
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isOnHomepage && !paramSessionId) {
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [isOnHomepage, paramSessionId, queryClient]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (currentSessionId && currentSessionData) {
|
||||||
|
const isNewSession =
|
||||||
|
currentSessionData.updated_at === currentSessionData.created_at;
|
||||||
|
const isNotInAccumulated = !accumulatedSessions.some(
|
||||||
|
(s) => s.id === currentSessionId,
|
||||||
|
);
|
||||||
|
if (isNewSession || isNotInAccumulated) {
|
||||||
|
const summary = convertSessionDetailToSummary(currentSessionData);
|
||||||
|
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [currentSessionId, currentSessionData, accumulatedSessions]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
|
||||||
|
if (accumulatedSessions.some((s) => s.id === sessionId)) {
|
||||||
|
recentlyCreatedSessionsRef.current.delete(sessionId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [accumulatedSessions]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const unsubscribe = onStreamComplete(() => {
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
|
});
|
||||||
|
});
|
||||||
|
return unsubscribe;
|
||||||
|
}, [onStreamComplete, queryClient]);
|
||||||
|
|
||||||
|
const sessions = useMemo(
|
||||||
|
() =>
|
||||||
|
mergeCurrentSessionIntoList(
|
||||||
|
accumulatedSessions,
|
||||||
|
currentSessionId,
|
||||||
|
currentSessionData,
|
||||||
|
recentlyCreatedSessionsRef.current,
|
||||||
|
),
|
||||||
|
[accumulatedSessions, currentSessionId, currentSessionData],
|
||||||
|
);
|
||||||
|
|
||||||
|
const visibleSessions = useMemo(
|
||||||
|
() => filterVisibleSessions(sessions),
|
||||||
|
[sessions],
|
||||||
|
);
|
||||||
|
|
||||||
|
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
|
||||||
|
|
||||||
|
return {
|
||||||
|
sessions: visibleSessions,
|
||||||
|
isLoading,
|
||||||
|
isSessionsFetching,
|
||||||
|
hasNextPage,
|
||||||
|
fetchNextPage,
|
||||||
|
resetPagination,
|
||||||
|
recentlyCreatedSessionsRef,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -4,51 +4,53 @@ import { create } from "zustand";
|
|||||||
|
|
||||||
interface CopilotStoreState {
|
interface CopilotStoreState {
|
||||||
isStreaming: boolean;
|
isStreaming: boolean;
|
||||||
isNewChatModalOpen: boolean;
|
isSwitchingSession: boolean;
|
||||||
newChatHandler: (() => void) | null;
|
isCreatingSession: boolean;
|
||||||
|
isInterruptModalOpen: boolean;
|
||||||
|
pendingAction: (() => void) | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface CopilotStoreActions {
|
interface CopilotStoreActions {
|
||||||
setIsStreaming: (isStreaming: boolean) => void;
|
setIsStreaming: (isStreaming: boolean) => void;
|
||||||
setNewChatHandler: (handler: (() => void) | null) => void;
|
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
|
||||||
requestNewChat: () => void;
|
setIsCreatingSession: (isCreating: boolean) => void;
|
||||||
confirmNewChat: () => void;
|
openInterruptModal: (onConfirm: () => void) => void;
|
||||||
cancelNewChat: () => void;
|
confirmInterrupt: () => void;
|
||||||
|
cancelInterrupt: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
type CopilotStore = CopilotStoreState & CopilotStoreActions;
|
type CopilotStore = CopilotStoreState & CopilotStoreActions;
|
||||||
|
|
||||||
export const useCopilotStore = create<CopilotStore>((set, get) => ({
|
export const useCopilotStore = create<CopilotStore>((set, get) => ({
|
||||||
isStreaming: false,
|
isStreaming: false,
|
||||||
isNewChatModalOpen: false,
|
isSwitchingSession: false,
|
||||||
newChatHandler: null,
|
isCreatingSession: false,
|
||||||
|
isInterruptModalOpen: false,
|
||||||
|
pendingAction: null,
|
||||||
|
|
||||||
setIsStreaming(isStreaming) {
|
setIsStreaming(isStreaming) {
|
||||||
set({ isStreaming });
|
set({ isStreaming });
|
||||||
},
|
},
|
||||||
|
|
||||||
setNewChatHandler(handler) {
|
setIsSwitchingSession(isSwitchingSession) {
|
||||||
set({ newChatHandler: handler });
|
set({ isSwitchingSession });
|
||||||
},
|
},
|
||||||
|
|
||||||
requestNewChat() {
|
setIsCreatingSession(isCreatingSession) {
|
||||||
const { isStreaming, newChatHandler } = get();
|
set({ isCreatingSession });
|
||||||
if (isStreaming) {
|
|
||||||
set({ isNewChatModalOpen: true });
|
|
||||||
} else if (newChatHandler) {
|
|
||||||
newChatHandler();
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
|
||||||
confirmNewChat() {
|
openInterruptModal(onConfirm) {
|
||||||
const { newChatHandler } = get();
|
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
|
||||||
set({ isNewChatModalOpen: false });
|
|
||||||
if (newChatHandler) {
|
|
||||||
newChatHandler();
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
|
||||||
cancelNewChat() {
|
confirmInterrupt() {
|
||||||
set({ isNewChatModalOpen: false });
|
const { pendingAction } = get();
|
||||||
|
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||||
|
if (pendingAction) pendingAction();
|
||||||
|
},
|
||||||
|
|
||||||
|
cancelInterrupt() {
|
||||||
|
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|||||||
@@ -1,28 +1,5 @@
|
|||||||
import type { User } from "@supabase/supabase-js";
|
import type { User } from "@supabase/supabase-js";
|
||||||
|
|
||||||
export type PageState =
|
|
||||||
| { type: "welcome" }
|
|
||||||
| { type: "newChat" }
|
|
||||||
| { type: "creating"; prompt: string }
|
|
||||||
| { type: "chat"; sessionId: string; initialPrompt?: string };
|
|
||||||
|
|
||||||
export function getInitialPromptFromState(
|
|
||||||
pageState: PageState,
|
|
||||||
storedInitialPrompt: string | undefined,
|
|
||||||
) {
|
|
||||||
if (storedInitialPrompt) return storedInitialPrompt;
|
|
||||||
if (pageState.type === "creating") return pageState.prompt;
|
|
||||||
if (pageState.type === "chat") return pageState.initialPrompt;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function shouldResetToWelcome(pageState: PageState) {
|
|
||||||
return (
|
|
||||||
pageState.type !== "newChat" &&
|
|
||||||
pageState.type !== "creating" &&
|
|
||||||
pageState.type !== "welcome"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getGreetingName(user?: User | null): string {
|
export function getGreetingName(user?: User | null): string {
|
||||||
if (!user) return "there";
|
if (!user) return "there";
|
||||||
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
||||||
|
|||||||
@@ -1,25 +1,25 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
|
|
||||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
|
||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||||
import { useCopilotStore } from "./copilot-page-store";
|
import { useCopilotStore } from "./copilot-page-store";
|
||||||
import { useCopilotPage } from "./useCopilotPage";
|
import { useCopilotPage } from "./useCopilotPage";
|
||||||
|
|
||||||
export default function CopilotPage() {
|
export default function CopilotPage() {
|
||||||
const { state, handlers } = useCopilotPage();
|
const { state, handlers } = useCopilotPage();
|
||||||
const confirmNewChat = useCopilotStore((s) => s.confirmNewChat);
|
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||||
|
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||||
|
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||||
const {
|
const {
|
||||||
greetingName,
|
greetingName,
|
||||||
quickActions,
|
quickActions,
|
||||||
isLoading,
|
isLoading,
|
||||||
pageState,
|
hasSession,
|
||||||
isNewChatModalOpen,
|
initialPrompt,
|
||||||
isReady,
|
isReady,
|
||||||
} = state;
|
} = state;
|
||||||
const {
|
const {
|
||||||
@@ -27,20 +27,16 @@ export default function CopilotPage() {
|
|||||||
startChatWithPrompt,
|
startChatWithPrompt,
|
||||||
handleSessionNotFound,
|
handleSessionNotFound,
|
||||||
handleStreamingChange,
|
handleStreamingChange,
|
||||||
handleCancelNewChat,
|
|
||||||
handleNewChatModalOpen,
|
|
||||||
} = handlers;
|
} = handlers;
|
||||||
|
|
||||||
if (!isReady) return null;
|
if (!isReady) return null;
|
||||||
|
|
||||||
if (pageState.type === "chat") {
|
if (hasSession) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full flex-col">
|
<div className="flex h-full flex-col">
|
||||||
<Chat
|
<Chat
|
||||||
key={pageState.sessionId ?? "welcome"}
|
|
||||||
className="flex-1"
|
className="flex-1"
|
||||||
urlSessionId={pageState.sessionId}
|
initialPrompt={initialPrompt}
|
||||||
initialPrompt={pageState.initialPrompt}
|
|
||||||
onSessionNotFound={handleSessionNotFound}
|
onSessionNotFound={handleSessionNotFound}
|
||||||
onStreamingChange={handleStreamingChange}
|
onStreamingChange={handleStreamingChange}
|
||||||
/>
|
/>
|
||||||
@@ -48,31 +44,33 @@ export default function CopilotPage() {
|
|||||||
title="Interrupt current chat?"
|
title="Interrupt current chat?"
|
||||||
styling={{ maxWidth: 300, width: "100%" }}
|
styling={{ maxWidth: 300, width: "100%" }}
|
||||||
controlled={{
|
controlled={{
|
||||||
isOpen: isNewChatModalOpen,
|
isOpen: isInterruptModalOpen,
|
||||||
set: handleNewChatModalOpen,
|
set: (open) => {
|
||||||
|
if (!open) cancelInterrupt();
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
onClose={handleCancelNewChat}
|
onClose={cancelInterrupt}
|
||||||
>
|
>
|
||||||
<Dialog.Content>
|
<Dialog.Content>
|
||||||
<div className="flex flex-col gap-4">
|
<div className="flex flex-col gap-4">
|
||||||
<Text variant="body">
|
<Text variant="body">
|
||||||
The current chat response will be interrupted. Are you sure you
|
The current chat response will be interrupted. Are you sure you
|
||||||
want to start a new chat?
|
want to continue?
|
||||||
</Text>
|
</Text>
|
||||||
<Dialog.Footer>
|
<Dialog.Footer>
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
variant="outline"
|
variant="outline"
|
||||||
onClick={handleCancelNewChat}
|
onClick={cancelInterrupt}
|
||||||
>
|
>
|
||||||
Cancel
|
Cancel
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
variant="primary"
|
variant="primary"
|
||||||
onClick={confirmNewChat}
|
onClick={confirmInterrupt}
|
||||||
>
|
>
|
||||||
Start new chat
|
Continue
|
||||||
</Button>
|
</Button>
|
||||||
</Dialog.Footer>
|
</Dialog.Footer>
|
||||||
</div>
|
</div>
|
||||||
@@ -82,19 +80,6 @@ export default function CopilotPage() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pageState.type === "newChat" || pageState.type === "creating") {
|
|
||||||
return (
|
|
||||||
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
|
|
||||||
<div className="flex flex-col items-center gap-4">
|
|
||||||
<ChatLoader />
|
|
||||||
<Text variant="body" className="text-zinc-500">
|
|
||||||
Loading your chats...
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
||||||
<div className="w-full text-center">
|
<div className="w-full text-center">
|
||||||
|
|||||||
@@ -10,64 +10,15 @@ import {
|
|||||||
type FlagValues,
|
type FlagValues,
|
||||||
useGetFlag,
|
useGetFlag,
|
||||||
} from "@/services/feature-flags/use-get-flag";
|
} from "@/services/feature-flags/use-get-flag";
|
||||||
|
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
||||||
import * as Sentry from "@sentry/nextjs";
|
import * as Sentry from "@sentry/nextjs";
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { useFlags } from "launchdarkly-react-client-sdk";
|
import { useFlags } from "launchdarkly-react-client-sdk";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { useEffect, useReducer } from "react";
|
import { useEffect } from "react";
|
||||||
import { useCopilotStore } from "./copilot-page-store";
|
import { useCopilotStore } from "./copilot-page-store";
|
||||||
import { getGreetingName, getQuickActions, type PageState } from "./helpers";
|
import { getGreetingName, getQuickActions } from "./helpers";
|
||||||
import { useCopilotURLState } from "./useCopilotURLState";
|
import { useCopilotSessionId } from "./useCopilotSessionId";
|
||||||
|
|
||||||
type CopilotState = {
|
|
||||||
pageState: PageState;
|
|
||||||
initialPrompts: Record<string, string>;
|
|
||||||
previousSessionId: string | null;
|
|
||||||
};
|
|
||||||
|
|
||||||
type CopilotAction =
|
|
||||||
| { type: "setPageState"; pageState: PageState }
|
|
||||||
| { type: "setInitialPrompt"; sessionId: string; prompt: string }
|
|
||||||
| { type: "setPreviousSessionId"; sessionId: string | null };
|
|
||||||
|
|
||||||
function isSamePageState(next: PageState, current: PageState) {
|
|
||||||
if (next.type !== current.type) return false;
|
|
||||||
if (next.type === "creating" && current.type === "creating") {
|
|
||||||
return next.prompt === current.prompt;
|
|
||||||
}
|
|
||||||
if (next.type === "chat" && current.type === "chat") {
|
|
||||||
return (
|
|
||||||
next.sessionId === current.sessionId &&
|
|
||||||
next.initialPrompt === current.initialPrompt
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
function copilotReducer(
|
|
||||||
state: CopilotState,
|
|
||||||
action: CopilotAction,
|
|
||||||
): CopilotState {
|
|
||||||
if (action.type === "setPageState") {
|
|
||||||
if (isSamePageState(action.pageState, state.pageState)) return state;
|
|
||||||
return { ...state, pageState: action.pageState };
|
|
||||||
}
|
|
||||||
if (action.type === "setInitialPrompt") {
|
|
||||||
if (state.initialPrompts[action.sessionId] === action.prompt) return state;
|
|
||||||
return {
|
|
||||||
...state,
|
|
||||||
initialPrompts: {
|
|
||||||
...state.initialPrompts,
|
|
||||||
[action.sessionId]: action.prompt,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
if (action.type === "setPreviousSessionId") {
|
|
||||||
if (state.previousSessionId === action.sessionId) return state;
|
|
||||||
return { ...state, previousSessionId: action.sessionId };
|
|
||||||
}
|
|
||||||
return state;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useCopilotPage() {
|
export function useCopilotPage() {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -75,9 +26,10 @@ export function useCopilotPage() {
|
|||||||
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
||||||
const { toast } = useToast();
|
const { toast } = useToast();
|
||||||
|
|
||||||
const isNewChatModalOpen = useCopilotStore((s) => s.isNewChatModalOpen);
|
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
|
||||||
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
|
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
|
||||||
const cancelNewChat = useCopilotStore((s) => s.cancelNewChat);
|
const isCreating = useCopilotStore((s) => s.isCreatingSession);
|
||||||
|
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
|
||||||
|
|
||||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
const flags = useFlags<FlagValues>();
|
const flags = useFlags<FlagValues>();
|
||||||
@@ -88,72 +40,27 @@ export function useCopilotPage() {
|
|||||||
const isFlagReady =
|
const isFlagReady =
|
||||||
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
||||||
|
|
||||||
const [state, dispatch] = useReducer(copilotReducer, {
|
|
||||||
pageState: { type: "welcome" },
|
|
||||||
initialPrompts: {},
|
|
||||||
previousSessionId: null,
|
|
||||||
});
|
|
||||||
|
|
||||||
const greetingName = getGreetingName(user);
|
const greetingName = getGreetingName(user);
|
||||||
const quickActions = getQuickActions();
|
const quickActions = getQuickActions();
|
||||||
|
|
||||||
function setPageState(pageState: PageState) {
|
const hasSession = Boolean(urlSessionId);
|
||||||
dispatch({ type: "setPageState", pageState });
|
const initialPrompt = urlSessionId
|
||||||
}
|
? getInitialPrompt(urlSessionId)
|
||||||
|
: undefined;
|
||||||
|
|
||||||
function setInitialPrompt(sessionId: string, prompt: string) {
|
useEffect(() => {
|
||||||
dispatch({ type: "setInitialPrompt", sessionId, prompt });
|
if (!isFlagReady) return;
|
||||||
}
|
if (isChatEnabled === false) {
|
||||||
|
router.replace(homepageRoute);
|
||||||
function setPreviousSessionId(sessionId: string | null) {
|
}
|
||||||
dispatch({ type: "setPreviousSessionId", sessionId });
|
}, [homepageRoute, isChatEnabled, isFlagReady, router]);
|
||||||
}
|
|
||||||
|
|
||||||
const { setUrlSessionId } = useCopilotURLState({
|
|
||||||
pageState: state.pageState,
|
|
||||||
initialPrompts: state.initialPrompts,
|
|
||||||
previousSessionId: state.previousSessionId,
|
|
||||||
setPageState,
|
|
||||||
setInitialPrompt,
|
|
||||||
setPreviousSessionId,
|
|
||||||
});
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function transitionNewChatToWelcome() {
|
|
||||||
if (state.pageState.type === "newChat") {
|
|
||||||
function setWelcomeState() {
|
|
||||||
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
|
|
||||||
}
|
|
||||||
|
|
||||||
const timer = setTimeout(setWelcomeState, 300);
|
|
||||||
|
|
||||||
return function cleanup() {
|
|
||||||
clearTimeout(timer);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[state.pageState.type],
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function ensureAccess() {
|
|
||||||
if (!isFlagReady) return;
|
|
||||||
if (isChatEnabled === false) {
|
|
||||||
router.replace(homepageRoute);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[homepageRoute, isChatEnabled, isFlagReady, router],
|
|
||||||
);
|
|
||||||
|
|
||||||
async function startChatWithPrompt(prompt: string) {
|
async function startChatWithPrompt(prompt: string) {
|
||||||
if (!prompt?.trim()) return;
|
if (!prompt?.trim()) return;
|
||||||
if (state.pageState.type === "creating") return;
|
if (isCreating) return;
|
||||||
|
|
||||||
const trimmedPrompt = prompt.trim();
|
const trimmedPrompt = prompt.trim();
|
||||||
dispatch({
|
setIsCreating(true);
|
||||||
type: "setPageState",
|
|
||||||
pageState: { type: "creating", prompt: trimmedPrompt },
|
|
||||||
});
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const sessionResponse = await postV2CreateSession({
|
const sessionResponse = await postV2CreateSession({
|
||||||
@@ -165,27 +72,19 @@ export function useCopilotPage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const sessionId = sessionResponse.data.id;
|
const sessionId = sessionResponse.data.id;
|
||||||
|
setInitialPrompt(sessionId, trimmedPrompt);
|
||||||
dispatch({
|
|
||||||
type: "setInitialPrompt",
|
|
||||||
sessionId,
|
|
||||||
prompt: trimmedPrompt,
|
|
||||||
});
|
|
||||||
|
|
||||||
await queryClient.invalidateQueries({
|
await queryClient.invalidateQueries({
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
});
|
});
|
||||||
|
|
||||||
await setUrlSessionId(sessionId, { shallow: false });
|
await setUrlSessionId(sessionId, { shallow: true });
|
||||||
dispatch({
|
|
||||||
type: "setPageState",
|
|
||||||
pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt },
|
|
||||||
});
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("[CopilotPage] Failed to start chat:", error);
|
console.error("[CopilotPage] Failed to start chat:", error);
|
||||||
toast({ title: "Failed to start chat", variant: "destructive" });
|
toast({ title: "Failed to start chat", variant: "destructive" });
|
||||||
Sentry.captureException(error);
|
Sentry.captureException(error);
|
||||||
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
|
} finally {
|
||||||
|
setIsCreating(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,21 +100,13 @@ export function useCopilotPage() {
|
|||||||
setIsStreaming(isStreamingValue);
|
setIsStreaming(isStreamingValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleCancelNewChat() {
|
|
||||||
cancelNewChat();
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleNewChatModalOpen(isOpen: boolean) {
|
|
||||||
if (!isOpen) cancelNewChat();
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
state: {
|
state: {
|
||||||
greetingName,
|
greetingName,
|
||||||
quickActions,
|
quickActions,
|
||||||
isLoading: isUserLoading,
|
isLoading: isUserLoading,
|
||||||
pageState: state.pageState,
|
hasSession,
|
||||||
isNewChatModalOpen,
|
initialPrompt,
|
||||||
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
|
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
|
||||||
},
|
},
|
||||||
handlers: {
|
handlers: {
|
||||||
@@ -223,8 +114,32 @@ export function useCopilotPage() {
|
|||||||
startChatWithPrompt,
|
startChatWithPrompt,
|
||||||
handleSessionNotFound,
|
handleSessionNotFound,
|
||||||
handleStreamingChange,
|
handleStreamingChange,
|
||||||
handleCancelNewChat,
|
|
||||||
handleNewChatModalOpen,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getInitialPrompt(sessionId: string): string | undefined {
|
||||||
|
try {
|
||||||
|
const prompts = JSON.parse(
|
||||||
|
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
|
||||||
|
);
|
||||||
|
return prompts[sessionId];
|
||||||
|
} catch {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function setInitialPrompt(sessionId: string, prompt: string): void {
|
||||||
|
try {
|
||||||
|
const prompts = JSON.parse(
|
||||||
|
sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
|
||||||
|
);
|
||||||
|
prompts[sessionId] = prompt;
|
||||||
|
sessionStorage.set(
|
||||||
|
SessionKey.CHAT_INITIAL_PROMPTS,
|
||||||
|
JSON.stringify(prompts),
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
// Ignore storage errors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
import { parseAsString, useQueryState } from "nuqs";
|
||||||
|
|
||||||
|
export function useCopilotSessionId() {
|
||||||
|
const [urlSessionId, setUrlSessionId] = useQueryState(
|
||||||
|
"sessionId",
|
||||||
|
parseAsString,
|
||||||
|
);
|
||||||
|
|
||||||
|
return { urlSessionId, setUrlSessionId };
|
||||||
|
}
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
import { parseAsString, useQueryState } from "nuqs";
|
|
||||||
import { useLayoutEffect } from "react";
|
|
||||||
import {
|
|
||||||
getInitialPromptFromState,
|
|
||||||
type PageState,
|
|
||||||
shouldResetToWelcome,
|
|
||||||
} from "./helpers";
|
|
||||||
|
|
||||||
interface UseCopilotUrlStateArgs {
|
|
||||||
pageState: PageState;
|
|
||||||
initialPrompts: Record<string, string>;
|
|
||||||
previousSessionId: string | null;
|
|
||||||
setPageState: (pageState: PageState) => void;
|
|
||||||
setInitialPrompt: (sessionId: string, prompt: string) => void;
|
|
||||||
setPreviousSessionId: (sessionId: string | null) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useCopilotURLState({
|
|
||||||
pageState,
|
|
||||||
initialPrompts,
|
|
||||||
previousSessionId,
|
|
||||||
setPageState,
|
|
||||||
setInitialPrompt,
|
|
||||||
setPreviousSessionId,
|
|
||||||
}: UseCopilotUrlStateArgs) {
|
|
||||||
const [urlSessionId, setUrlSessionId] = useQueryState(
|
|
||||||
"sessionId",
|
|
||||||
parseAsString,
|
|
||||||
);
|
|
||||||
|
|
||||||
function syncSessionFromUrl() {
|
|
||||||
if (urlSessionId) {
|
|
||||||
if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
|
|
||||||
setPreviousSessionId(urlSessionId);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const storedInitialPrompt = initialPrompts[urlSessionId];
|
|
||||||
const currentInitialPrompt = getInitialPromptFromState(
|
|
||||||
pageState,
|
|
||||||
storedInitialPrompt,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (currentInitialPrompt) {
|
|
||||||
setInitialPrompt(urlSessionId, currentInitialPrompt);
|
|
||||||
}
|
|
||||||
|
|
||||||
setPageState({
|
|
||||||
type: "chat",
|
|
||||||
sessionId: urlSessionId,
|
|
||||||
initialPrompt: currentInitialPrompt,
|
|
||||||
});
|
|
||||||
setPreviousSessionId(urlSessionId);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const wasInChat = previousSessionId !== null && pageState.type === "chat";
|
|
||||||
setPreviousSessionId(null);
|
|
||||||
if (wasInChat) {
|
|
||||||
setPageState({ type: "newChat" });
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shouldResetToWelcome(pageState)) {
|
|
||||||
setPageState({ type: "welcome" });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
useLayoutEffect(syncSessionFromUrl, [
|
|
||||||
urlSessionId,
|
|
||||||
pageState.type,
|
|
||||||
previousSessionId,
|
|
||||||
initialPrompts,
|
|
||||||
]);
|
|
||||||
|
|
||||||
return {
|
|
||||||
urlSessionId,
|
|
||||||
setUrlSessionId,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,16 +1,17 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
|
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
|
||||||
|
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
|
||||||
|
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { useEffect, useRef } from "react";
|
import { useEffect, useRef } from "react";
|
||||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||||
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
|
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
|
||||||
import { ChatLoader } from "./components/ChatLoader/ChatLoader";
|
|
||||||
import { useChat } from "./useChat";
|
import { useChat } from "./useChat";
|
||||||
|
|
||||||
export interface ChatProps {
|
export interface ChatProps {
|
||||||
className?: string;
|
className?: string;
|
||||||
urlSessionId?: string | null;
|
|
||||||
initialPrompt?: string;
|
initialPrompt?: string;
|
||||||
onSessionNotFound?: () => void;
|
onSessionNotFound?: () => void;
|
||||||
onStreamingChange?: (isStreaming: boolean) => void;
|
onStreamingChange?: (isStreaming: boolean) => void;
|
||||||
@@ -18,12 +19,13 @@ export interface ChatProps {
|
|||||||
|
|
||||||
export function Chat({
|
export function Chat({
|
||||||
className,
|
className,
|
||||||
urlSessionId,
|
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
onSessionNotFound,
|
onSessionNotFound,
|
||||||
onStreamingChange,
|
onStreamingChange,
|
||||||
}: ChatProps) {
|
}: ChatProps) {
|
||||||
|
const { urlSessionId } = useCopilotSessionId();
|
||||||
const hasHandledNotFoundRef = useRef(false);
|
const hasHandledNotFoundRef = useRef(false);
|
||||||
|
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
|
||||||
const {
|
const {
|
||||||
messages,
|
messages,
|
||||||
isLoading,
|
isLoading,
|
||||||
@@ -33,49 +35,59 @@ export function Chat({
|
|||||||
sessionId,
|
sessionId,
|
||||||
createSession,
|
createSession,
|
||||||
showLoader,
|
showLoader,
|
||||||
|
startPollingForOperation,
|
||||||
} = useChat({ urlSessionId });
|
} = useChat({ urlSessionId });
|
||||||
|
|
||||||
useEffect(
|
useEffect(() => {
|
||||||
function handleMissingSession() {
|
if (!onSessionNotFound) return;
|
||||||
if (!onSessionNotFound) return;
|
if (!urlSessionId) return;
|
||||||
if (!urlSessionId) return;
|
if (!isSessionNotFound || isLoading || isCreating) return;
|
||||||
if (!isSessionNotFound || isLoading || isCreating) return;
|
if (hasHandledNotFoundRef.current) return;
|
||||||
if (hasHandledNotFoundRef.current) return;
|
hasHandledNotFoundRef.current = true;
|
||||||
hasHandledNotFoundRef.current = true;
|
onSessionNotFound();
|
||||||
onSessionNotFound();
|
}, [
|
||||||
},
|
onSessionNotFound,
|
||||||
[onSessionNotFound, urlSessionId, isSessionNotFound, isLoading, isCreating],
|
urlSessionId,
|
||||||
);
|
isSessionNotFound,
|
||||||
|
isLoading,
|
||||||
|
isCreating,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const shouldShowLoader =
|
||||||
|
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={cn("flex h-full flex-col", className)}>
|
<div className={cn("flex h-full flex-col", className)}>
|
||||||
{/* Main Content */}
|
{/* Main Content */}
|
||||||
<main className="flex min-h-0 w-full flex-1 flex-col overflow-hidden bg-[#f8f8f9]">
|
<main className="flex min-h-0 w-full flex-1 flex-col overflow-hidden bg-[#f8f8f9]">
|
||||||
{/* Loading State */}
|
{/* Loading State */}
|
||||||
{showLoader && (isLoading || isCreating) && (
|
{shouldShowLoader && (
|
||||||
<div className="flex flex-1 items-center justify-center">
|
<div className="flex flex-1 items-center justify-center">
|
||||||
<div className="flex flex-col items-center gap-4">
|
<div className="flex flex-col items-center gap-3">
|
||||||
<ChatLoader />
|
<LoadingSpinner size="large" className="text-neutral-400" />
|
||||||
<Text variant="body" className="text-zinc-500">
|
<Text variant="body" className="text-zinc-500">
|
||||||
Loading your chats...
|
{isSwitchingSession
|
||||||
|
? "Switching chat..."
|
||||||
|
: "Loading your chat..."}
|
||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Error State */}
|
{/* Error State */}
|
||||||
{error && !isLoading && (
|
{error && !isLoading && !isSwitchingSession && (
|
||||||
<ChatErrorState error={error} onRetry={createSession} />
|
<ChatErrorState error={error} onRetry={createSession} />
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Session Content */}
|
{/* Session Content */}
|
||||||
{sessionId && !isLoading && !error && (
|
{sessionId && !isLoading && !error && !isSwitchingSession && (
|
||||||
<ChatContainer
|
<ChatContainer
|
||||||
sessionId={sessionId}
|
sessionId={sessionId}
|
||||||
initialMessages={messages}
|
initialMessages={messages}
|
||||||
initialPrompt={initialPrompt}
|
initialPrompt={initialPrompt}
|
||||||
className="flex-1"
|
className="flex-1"
|
||||||
onStreamingChange={onStreamingChange}
|
onStreamingChange={onStreamingChange}
|
||||||
|
onOperationStarted={startPollingForOperation}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</main>
|
</main>
|
||||||
|
|||||||
@@ -58,39 +58,17 @@ function notifyStreamComplete(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function cleanupCompletedStreams(completedStreams: Map<string, StreamResult>) {
|
function cleanupExpiredStreams(
|
||||||
|
completedStreams: Map<string, StreamResult>,
|
||||||
|
): Map<string, StreamResult> {
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
for (const [sessionId, result] of completedStreams) {
|
const cleaned = new Map(completedStreams);
|
||||||
|
for (const [sessionId, result] of cleaned) {
|
||||||
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
|
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
|
||||||
completedStreams.delete(sessionId);
|
cleaned.delete(sessionId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
return cleaned;
|
||||||
|
|
||||||
function moveToCompleted(
|
|
||||||
activeStreams: Map<string, ActiveStream>,
|
|
||||||
completedStreams: Map<string, StreamResult>,
|
|
||||||
streamCompleteCallbacks: Set<StreamCompleteCallback>,
|
|
||||||
sessionId: string,
|
|
||||||
) {
|
|
||||||
const stream = activeStreams.get(sessionId);
|
|
||||||
if (!stream) return;
|
|
||||||
|
|
||||||
const result: StreamResult = {
|
|
||||||
sessionId,
|
|
||||||
status: stream.status,
|
|
||||||
chunks: stream.chunks,
|
|
||||||
completedAt: Date.now(),
|
|
||||||
error: stream.error,
|
|
||||||
};
|
|
||||||
|
|
||||||
completedStreams.set(sessionId, result);
|
|
||||||
activeStreams.delete(sessionId);
|
|
||||||
cleanupCompletedStreams(completedStreams);
|
|
||||||
|
|
||||||
if (stream.status === "completed" || stream.status === "error") {
|
|
||||||
notifyStreamComplete(streamCompleteCallbacks, sessionId);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useChatStore = create<ChatStore>((set, get) => ({
|
export const useChatStore = create<ChatStore>((set, get) => ({
|
||||||
@@ -106,17 +84,31 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
context,
|
context,
|
||||||
onChunk,
|
onChunk,
|
||||||
) {
|
) {
|
||||||
const { activeStreams, completedStreams, streamCompleteCallbacks } = get();
|
const state = get();
|
||||||
|
const newActiveStreams = new Map(state.activeStreams);
|
||||||
|
let newCompletedStreams = new Map(state.completedStreams);
|
||||||
|
const callbacks = state.streamCompleteCallbacks;
|
||||||
|
|
||||||
const existingStream = activeStreams.get(sessionId);
|
const existingStream = newActiveStreams.get(sessionId);
|
||||||
if (existingStream) {
|
if (existingStream) {
|
||||||
existingStream.abortController.abort();
|
existingStream.abortController.abort();
|
||||||
moveToCompleted(
|
const normalizedStatus =
|
||||||
activeStreams,
|
existingStream.status === "streaming"
|
||||||
completedStreams,
|
? "completed"
|
||||||
streamCompleteCallbacks,
|
: existingStream.status;
|
||||||
|
const result: StreamResult = {
|
||||||
sessionId,
|
sessionId,
|
||||||
);
|
status: normalizedStatus,
|
||||||
|
chunks: existingStream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: existingStream.error,
|
||||||
|
};
|
||||||
|
newCompletedStreams.set(sessionId, result);
|
||||||
|
newActiveStreams.delete(sessionId);
|
||||||
|
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||||
|
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||||
|
notifyStreamComplete(callbacks, sessionId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
@@ -132,36 +124,76 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
onChunkCallbacks: initialCallbacks,
|
onChunkCallbacks: initialCallbacks,
|
||||||
};
|
};
|
||||||
|
|
||||||
activeStreams.set(sessionId, stream);
|
newActiveStreams.set(sessionId, stream);
|
||||||
|
set({
|
||||||
|
activeStreams: newActiveStreams,
|
||||||
|
completedStreams: newCompletedStreams,
|
||||||
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await executeStream(stream, message, isUserMessage, context);
|
await executeStream(stream, message, isUserMessage, context);
|
||||||
} finally {
|
} finally {
|
||||||
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||||
if (stream.status !== "streaming") {
|
if (stream.status !== "streaming") {
|
||||||
moveToCompleted(
|
const currentState = get();
|
||||||
activeStreams,
|
const finalActiveStreams = new Map(currentState.activeStreams);
|
||||||
completedStreams,
|
let finalCompletedStreams = new Map(currentState.completedStreams);
|
||||||
streamCompleteCallbacks,
|
|
||||||
sessionId,
|
const storedStream = finalActiveStreams.get(sessionId);
|
||||||
);
|
if (storedStream === stream) {
|
||||||
|
const result: StreamResult = {
|
||||||
|
sessionId,
|
||||||
|
status: stream.status,
|
||||||
|
chunks: stream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: stream.error,
|
||||||
|
};
|
||||||
|
finalCompletedStreams.set(sessionId, result);
|
||||||
|
finalActiveStreams.delete(sessionId);
|
||||||
|
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
||||||
|
set({
|
||||||
|
activeStreams: finalActiveStreams,
|
||||||
|
completedStreams: finalCompletedStreams,
|
||||||
|
});
|
||||||
|
if (stream.status === "completed" || stream.status === "error") {
|
||||||
|
notifyStreamComplete(
|
||||||
|
currentState.streamCompleteCallbacks,
|
||||||
|
sessionId,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
stopStream: function stopStream(sessionId) {
|
stopStream: function stopStream(sessionId) {
|
||||||
const { activeStreams, completedStreams, streamCompleteCallbacks } = get();
|
const state = get();
|
||||||
const stream = activeStreams.get(sessionId);
|
const stream = state.activeStreams.get(sessionId);
|
||||||
if (stream) {
|
if (!stream) return;
|
||||||
stream.abortController.abort();
|
|
||||||
stream.status = "completed";
|
stream.abortController.abort();
|
||||||
moveToCompleted(
|
stream.status = "completed";
|
||||||
activeStreams,
|
|
||||||
completedStreams,
|
const newActiveStreams = new Map(state.activeStreams);
|
||||||
streamCompleteCallbacks,
|
let newCompletedStreams = new Map(state.completedStreams);
|
||||||
sessionId,
|
|
||||||
);
|
const result: StreamResult = {
|
||||||
}
|
sessionId,
|
||||||
|
status: stream.status,
|
||||||
|
chunks: stream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: stream.error,
|
||||||
|
};
|
||||||
|
newCompletedStreams.set(sessionId, result);
|
||||||
|
newActiveStreams.delete(sessionId);
|
||||||
|
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||||
|
|
||||||
|
set({
|
||||||
|
activeStreams: newActiveStreams,
|
||||||
|
completedStreams: newCompletedStreams,
|
||||||
|
});
|
||||||
|
|
||||||
|
notifyStreamComplete(state.streamCompleteCallbacks, sessionId);
|
||||||
},
|
},
|
||||||
|
|
||||||
subscribeToStream: function subscribeToStream(
|
subscribeToStream: function subscribeToStream(
|
||||||
@@ -169,16 +201,18 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
onChunk,
|
onChunk,
|
||||||
skipReplay = false,
|
skipReplay = false,
|
||||||
) {
|
) {
|
||||||
const { activeStreams } = get();
|
const state = get();
|
||||||
|
const stream = state.activeStreams.get(sessionId);
|
||||||
|
|
||||||
const stream = activeStreams.get(sessionId);
|
|
||||||
if (stream) {
|
if (stream) {
|
||||||
if (!skipReplay) {
|
if (!skipReplay) {
|
||||||
for (const chunk of stream.chunks) {
|
for (const chunk of stream.chunks) {
|
||||||
onChunk(chunk);
|
onChunk(chunk);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stream.onChunkCallbacks.add(onChunk);
|
stream.onChunkCallbacks.add(onChunk);
|
||||||
|
|
||||||
return function unsubscribe() {
|
return function unsubscribe() {
|
||||||
stream.onChunkCallbacks.delete(onChunk);
|
stream.onChunkCallbacks.delete(onChunk);
|
||||||
};
|
};
|
||||||
@@ -204,7 +238,12 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
},
|
},
|
||||||
|
|
||||||
clearCompletedStream: function clearCompletedStream(sessionId) {
|
clearCompletedStream: function clearCompletedStream(sessionId) {
|
||||||
get().completedStreams.delete(sessionId);
|
const state = get();
|
||||||
|
if (!state.completedStreams.has(sessionId)) return;
|
||||||
|
|
||||||
|
const newCompletedStreams = new Map(state.completedStreams);
|
||||||
|
newCompletedStreams.delete(sessionId);
|
||||||
|
set({ completedStreams: newCompletedStreams });
|
||||||
},
|
},
|
||||||
|
|
||||||
isStreaming: function isStreaming(sessionId) {
|
isStreaming: function isStreaming(sessionId) {
|
||||||
@@ -213,11 +252,21 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
},
|
},
|
||||||
|
|
||||||
registerActiveSession: function registerActiveSession(sessionId) {
|
registerActiveSession: function registerActiveSession(sessionId) {
|
||||||
get().activeSessions.add(sessionId);
|
const state = get();
|
||||||
|
if (state.activeSessions.has(sessionId)) return;
|
||||||
|
|
||||||
|
const newActiveSessions = new Set(state.activeSessions);
|
||||||
|
newActiveSessions.add(sessionId);
|
||||||
|
set({ activeSessions: newActiveSessions });
|
||||||
},
|
},
|
||||||
|
|
||||||
unregisterActiveSession: function unregisterActiveSession(sessionId) {
|
unregisterActiveSession: function unregisterActiveSession(sessionId) {
|
||||||
get().activeSessions.delete(sessionId);
|
const state = get();
|
||||||
|
if (!state.activeSessions.has(sessionId)) return;
|
||||||
|
|
||||||
|
const newActiveSessions = new Set(state.activeSessions);
|
||||||
|
newActiveSessions.delete(sessionId);
|
||||||
|
set({ activeSessions: newActiveSessions });
|
||||||
},
|
},
|
||||||
|
|
||||||
isSessionActive: function isSessionActive(sessionId) {
|
isSessionActive: function isSessionActive(sessionId) {
|
||||||
@@ -225,10 +274,16 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
},
|
},
|
||||||
|
|
||||||
onStreamComplete: function onStreamComplete(callback) {
|
onStreamComplete: function onStreamComplete(callback) {
|
||||||
const { streamCompleteCallbacks } = get();
|
const state = get();
|
||||||
streamCompleteCallbacks.add(callback);
|
const newCallbacks = new Set(state.streamCompleteCallbacks);
|
||||||
|
newCallbacks.add(callback);
|
||||||
|
set({ streamCompleteCallbacks: newCallbacks });
|
||||||
|
|
||||||
return function unsubscribe() {
|
return function unsubscribe() {
|
||||||
streamCompleteCallbacks.delete(callback);
|
const currentState = get();
|
||||||
|
const cleanedCallbacks = new Set(currentState.streamCompleteCallbacks);
|
||||||
|
cleanedCallbacks.delete(callback);
|
||||||
|
set({ streamCompleteCallbacks: cleanedCallbacks });
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ export interface ChatContainerProps {
|
|||||||
initialPrompt?: string;
|
initialPrompt?: string;
|
||||||
className?: string;
|
className?: string;
|
||||||
onStreamingChange?: (isStreaming: boolean) => void;
|
onStreamingChange?: (isStreaming: boolean) => void;
|
||||||
|
onOperationStarted?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatContainer({
|
export function ChatContainer({
|
||||||
@@ -24,6 +25,7 @@ export function ChatContainer({
|
|||||||
initialPrompt,
|
initialPrompt,
|
||||||
className,
|
className,
|
||||||
onStreamingChange,
|
onStreamingChange,
|
||||||
|
onOperationStarted,
|
||||||
}: ChatContainerProps) {
|
}: ChatContainerProps) {
|
||||||
const {
|
const {
|
||||||
messages,
|
messages,
|
||||||
@@ -38,6 +40,7 @@ export function ChatContainer({
|
|||||||
sessionId,
|
sessionId,
|
||||||
initialMessages,
|
initialMessages,
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
|
onOperationStarted,
|
||||||
});
|
});
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ export interface HandlerDependencies {
|
|||||||
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
||||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
||||||
sessionId: string;
|
sessionId: string;
|
||||||
|
onOperationStarted?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
||||||
@@ -48,6 +49,15 @@ export function handleTextEnded(
|
|||||||
const completedText = deps.streamingChunksRef.current.join("");
|
const completedText = deps.streamingChunksRef.current.join("");
|
||||||
if (completedText.trim()) {
|
if (completedText.trim()) {
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
|
// Check if this exact message already exists to prevent duplicates
|
||||||
|
const exists = prev.some(
|
||||||
|
(msg) =>
|
||||||
|
msg.type === "message" &&
|
||||||
|
msg.role === "assistant" &&
|
||||||
|
msg.content === completedText,
|
||||||
|
);
|
||||||
|
if (exists) return prev;
|
||||||
|
|
||||||
const assistantMessage: ChatMessageData = {
|
const assistantMessage: ChatMessageData = {
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
@@ -154,6 +164,11 @@ export function handleToolResponse(
|
|||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// Trigger polling when operation_started is received
|
||||||
|
if (responseMessage.type === "operation_started") {
|
||||||
|
deps.onOperationStarted?.();
|
||||||
|
}
|
||||||
|
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
const toolCallIndex = prev.findIndex(
|
const toolCallIndex = prev.findIndex(
|
||||||
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
||||||
@@ -203,13 +218,24 @@ export function handleStreamEnd(
|
|||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
if (completedContent.trim()) {
|
if (completedContent.trim()) {
|
||||||
const assistantMessage: ChatMessageData = {
|
deps.setMessages((prev) => {
|
||||||
type: "message",
|
// Check if this exact message already exists to prevent duplicates
|
||||||
role: "assistant",
|
const exists = prev.some(
|
||||||
content: completedContent,
|
(msg) =>
|
||||||
timestamp: new Date(),
|
msg.type === "message" &&
|
||||||
};
|
msg.role === "assistant" &&
|
||||||
deps.setMessages((prev) => [...prev, assistantMessage]);
|
msg.content === completedContent,
|
||||||
|
);
|
||||||
|
if (exists) return prev;
|
||||||
|
|
||||||
|
const assistantMessage: ChatMessageData = {
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
content: completedContent,
|
||||||
|
timestamp: new Date(),
|
||||||
|
};
|
||||||
|
return [...prev, assistantMessage];
|
||||||
|
});
|
||||||
}
|
}
|
||||||
deps.setStreamingChunks([]);
|
deps.setStreamingChunks([]);
|
||||||
deps.streamingChunksRef.current = [];
|
deps.streamingChunksRef.current = [];
|
||||||
|
|||||||
@@ -304,6 +304,7 @@ export function parseToolResponse(
|
|||||||
if (isAgentArray(agentsData)) {
|
if (isAgentArray(agentsData)) {
|
||||||
return {
|
return {
|
||||||
type: "agent_carousel",
|
type: "agent_carousel",
|
||||||
|
toolId,
|
||||||
toolName: "agent_carousel",
|
toolName: "agent_carousel",
|
||||||
agents: agentsData,
|
agents: agentsData,
|
||||||
totalCount: parsedResult.total_count as number | undefined,
|
totalCount: parsedResult.total_count as number | undefined,
|
||||||
@@ -316,6 +317,7 @@ export function parseToolResponse(
|
|||||||
if (responseType === "execution_started") {
|
if (responseType === "execution_started") {
|
||||||
return {
|
return {
|
||||||
type: "execution_started",
|
type: "execution_started",
|
||||||
|
toolId,
|
||||||
toolName: "execution_started",
|
toolName: "execution_started",
|
||||||
executionId: (parsedResult.execution_id as string) || "",
|
executionId: (parsedResult.execution_id as string) || "",
|
||||||
agentName: (parsedResult.graph_name as string) || undefined,
|
agentName: (parsedResult.graph_name as string) || undefined,
|
||||||
@@ -341,6 +343,41 @@ export function parseToolResponse(
|
|||||||
timestamp: timestamp || new Date(),
|
timestamp: timestamp || new Date(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
if (responseType === "operation_started") {
|
||||||
|
return {
|
||||||
|
type: "operation_started",
|
||||||
|
toolName: (parsedResult.tool_name as string) || toolName,
|
||||||
|
toolId,
|
||||||
|
operationId: (parsedResult.operation_id as string) || "",
|
||||||
|
message:
|
||||||
|
(parsedResult.message as string) ||
|
||||||
|
"Operation started. You can close this tab.",
|
||||||
|
timestamp: timestamp || new Date(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if (responseType === "operation_pending") {
|
||||||
|
return {
|
||||||
|
type: "operation_pending",
|
||||||
|
toolName: (parsedResult.tool_name as string) || toolName,
|
||||||
|
toolId,
|
||||||
|
operationId: (parsedResult.operation_id as string) || "",
|
||||||
|
message:
|
||||||
|
(parsedResult.message as string) ||
|
||||||
|
"Operation in progress. Please wait...",
|
||||||
|
timestamp: timestamp || new Date(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if (responseType === "operation_in_progress") {
|
||||||
|
return {
|
||||||
|
type: "operation_in_progress",
|
||||||
|
toolName: (parsedResult.tool_name as string) || toolName,
|
||||||
|
toolCallId: (parsedResult.tool_call_id as string) || toolId,
|
||||||
|
message:
|
||||||
|
(parsedResult.message as string) ||
|
||||||
|
"Operation already in progress. Please wait...",
|
||||||
|
timestamp: timestamp || new Date(),
|
||||||
|
};
|
||||||
|
}
|
||||||
if (responseType === "need_login") {
|
if (responseType === "need_login") {
|
||||||
return {
|
return {
|
||||||
type: "login_needed",
|
type: "login_needed",
|
||||||
|
|||||||
@@ -14,16 +14,40 @@ import {
|
|||||||
processInitialMessages,
|
processInitialMessages,
|
||||||
} from "./helpers";
|
} from "./helpers";
|
||||||
|
|
||||||
|
// Helper to generate deduplication key for a message
|
||||||
|
function getMessageKey(msg: ChatMessageData): string {
|
||||||
|
if (msg.type === "message") {
|
||||||
|
// Don't include timestamp - dedupe by role + content only
|
||||||
|
// This handles the case where local and server timestamps differ
|
||||||
|
// Server messages are authoritative, so duplicates from local state are filtered
|
||||||
|
return `msg:${msg.role}:${msg.content}`;
|
||||||
|
} else if (msg.type === "tool_call") {
|
||||||
|
return `toolcall:${msg.toolId}`;
|
||||||
|
} else if (msg.type === "tool_response") {
|
||||||
|
return `toolresponse:${(msg as any).toolId}`;
|
||||||
|
} else if (
|
||||||
|
msg.type === "operation_started" ||
|
||||||
|
msg.type === "operation_pending" ||
|
||||||
|
msg.type === "operation_in_progress"
|
||||||
|
) {
|
||||||
|
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
|
||||||
|
} else {
|
||||||
|
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
interface Args {
|
interface Args {
|
||||||
sessionId: string | null;
|
sessionId: string | null;
|
||||||
initialMessages: SessionDetailResponse["messages"];
|
initialMessages: SessionDetailResponse["messages"];
|
||||||
initialPrompt?: string;
|
initialPrompt?: string;
|
||||||
|
onOperationStarted?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useChatContainer({
|
export function useChatContainer({
|
||||||
sessionId,
|
sessionId,
|
||||||
initialMessages,
|
initialMessages,
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
|
onOperationStarted,
|
||||||
}: Args) {
|
}: Args) {
|
||||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||||
@@ -73,20 +97,102 @@ export function useChatContainer({
|
|||||||
setIsRegionBlockedModalOpen,
|
setIsRegionBlockedModalOpen,
|
||||||
sessionId,
|
sessionId,
|
||||||
setIsStreamingInitiated,
|
setIsStreamingInitiated,
|
||||||
|
onOperationStarted,
|
||||||
});
|
});
|
||||||
|
|
||||||
setIsStreamingInitiated(true);
|
setIsStreamingInitiated(true);
|
||||||
const skipReplay = initialMessages.length > 0;
|
const skipReplay = initialMessages.length > 0;
|
||||||
return subscribeToStream(sessionId, dispatcher, skipReplay);
|
return subscribeToStream(sessionId, dispatcher, skipReplay);
|
||||||
},
|
},
|
||||||
[sessionId, stopStreaming, activeStreams, subscribeToStream],
|
[
|
||||||
|
sessionId,
|
||||||
|
stopStreaming,
|
||||||
|
activeStreams,
|
||||||
|
subscribeToStream,
|
||||||
|
onOperationStarted,
|
||||||
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
const allMessages = useMemo(
|
// Collect toolIds from completed tool results in initialMessages
|
||||||
() => [...processInitialMessages(initialMessages), ...messages],
|
// Used to filter out operation messages when their results arrive
|
||||||
[initialMessages, messages],
|
const completedToolIds = useMemo(() => {
|
||||||
|
const processedInitial = processInitialMessages(initialMessages);
|
||||||
|
const ids = new Set<string>();
|
||||||
|
for (const msg of processedInitial) {
|
||||||
|
if (
|
||||||
|
msg.type === "tool_response" ||
|
||||||
|
msg.type === "agent_carousel" ||
|
||||||
|
msg.type === "execution_started"
|
||||||
|
) {
|
||||||
|
const toolId = (msg as any).toolId;
|
||||||
|
if (toolId) {
|
||||||
|
ids.add(toolId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids;
|
||||||
|
}, [initialMessages]);
|
||||||
|
|
||||||
|
// Clean up local operation messages when their completed results arrive from polling
|
||||||
|
// This effect runs when completedToolIds changes (i.e., when polling brings new results)
|
||||||
|
useEffect(
|
||||||
|
function cleanupCompletedOperations() {
|
||||||
|
if (completedToolIds.size === 0) return;
|
||||||
|
|
||||||
|
setMessages((prev) => {
|
||||||
|
const filtered = prev.filter((msg) => {
|
||||||
|
if (
|
||||||
|
msg.type === "operation_started" ||
|
||||||
|
msg.type === "operation_pending" ||
|
||||||
|
msg.type === "operation_in_progress"
|
||||||
|
) {
|
||||||
|
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||||
|
if (toolId && completedToolIds.has(toolId)) {
|
||||||
|
return false; // Remove - operation completed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
// Only update state if something was actually filtered
|
||||||
|
return filtered.length === prev.length ? prev : filtered;
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[completedToolIds],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Combine initial messages from backend with local streaming messages,
|
||||||
|
// Server messages maintain correct order; only append truly new local messages
|
||||||
|
const allMessages = useMemo(() => {
|
||||||
|
const processedInitial = processInitialMessages(initialMessages);
|
||||||
|
|
||||||
|
// Build a set of keys from server messages for deduplication
|
||||||
|
const serverKeys = new Set<string>();
|
||||||
|
for (const msg of processedInitial) {
|
||||||
|
serverKeys.add(getMessageKey(msg));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter local messages: remove duplicates and completed operation messages
|
||||||
|
const newLocalMessages = messages.filter((msg) => {
|
||||||
|
// Remove operation messages for completed tools
|
||||||
|
if (
|
||||||
|
msg.type === "operation_started" ||
|
||||||
|
msg.type === "operation_pending" ||
|
||||||
|
msg.type === "operation_in_progress"
|
||||||
|
) {
|
||||||
|
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||||
|
if (toolId && completedToolIds.has(toolId)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Remove messages that already exist in server data
|
||||||
|
const key = getMessageKey(msg);
|
||||||
|
return !serverKeys.has(key);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Server messages first (correct order), then new local messages
|
||||||
|
return [...processedInitial, ...newLocalMessages];
|
||||||
|
}, [initialMessages, messages, completedToolIds]);
|
||||||
|
|
||||||
async function sendMessage(
|
async function sendMessage(
|
||||||
content: string,
|
content: string,
|
||||||
isUserMessage: boolean = true,
|
isUserMessage: boolean = true,
|
||||||
@@ -118,6 +224,7 @@ export function useChatContainer({
|
|||||||
setIsRegionBlockedModalOpen,
|
setIsRegionBlockedModalOpen,
|
||||||
sessionId,
|
sessionId,
|
||||||
setIsStreamingInitiated,
|
setIsStreamingInitiated,
|
||||||
|
onOperationStarted,
|
||||||
});
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
|
|||||||
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
|
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
|
||||||
import { ClarificationQuestionsWidget } from "../ClarificationQuestionsWidget/ClarificationQuestionsWidget";
|
import { ClarificationQuestionsWidget } from "../ClarificationQuestionsWidget/ClarificationQuestionsWidget";
|
||||||
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
|
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
|
||||||
|
import { PendingOperationWidget } from "../PendingOperationWidget/PendingOperationWidget";
|
||||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||||
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
|
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
|
||||||
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
|
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
|
||||||
@@ -71,6 +72,9 @@ export function ChatMessage({
|
|||||||
isLoginNeeded,
|
isLoginNeeded,
|
||||||
isCredentialsNeeded,
|
isCredentialsNeeded,
|
||||||
isClarificationNeeded,
|
isClarificationNeeded,
|
||||||
|
isOperationStarted,
|
||||||
|
isOperationPending,
|
||||||
|
isOperationInProgress,
|
||||||
} = useChatMessage(message);
|
} = useChatMessage(message);
|
||||||
const displayContent = getDisplayContent(message, isUser);
|
const displayContent = getDisplayContent(message, isUser);
|
||||||
|
|
||||||
@@ -126,10 +130,6 @@ export function ChatMessage({
|
|||||||
[displayContent, message],
|
[displayContent, message],
|
||||||
);
|
);
|
||||||
|
|
||||||
function isLongResponse(content: string): boolean {
|
|
||||||
return content.split("\n").length > 5;
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleTryAgain = useCallback(() => {
|
const handleTryAgain = useCallback(() => {
|
||||||
if (message.type !== "message" || !onSendMessage) return;
|
if (message.type !== "message" || !onSendMessage) return;
|
||||||
onSendMessage(message.content, message.role === "user");
|
onSendMessage(message.content, message.role === "user");
|
||||||
@@ -294,6 +294,42 @@ export function ChatMessage({
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Render operation_started messages (long-running background operations)
|
||||||
|
if (isOperationStarted && message.type === "operation_started") {
|
||||||
|
return (
|
||||||
|
<PendingOperationWidget
|
||||||
|
status="started"
|
||||||
|
message={message.message}
|
||||||
|
toolName={message.toolName}
|
||||||
|
className={className}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render operation_pending messages (operations in progress when refreshing)
|
||||||
|
if (isOperationPending && message.type === "operation_pending") {
|
||||||
|
return (
|
||||||
|
<PendingOperationWidget
|
||||||
|
status="pending"
|
||||||
|
message={message.message}
|
||||||
|
toolName={message.toolName}
|
||||||
|
className={className}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render operation_in_progress messages (duplicate request while operation running)
|
||||||
|
if (isOperationInProgress && message.type === "operation_in_progress") {
|
||||||
|
return (
|
||||||
|
<PendingOperationWidget
|
||||||
|
status="in_progress"
|
||||||
|
message={message.message}
|
||||||
|
toolName={message.toolName}
|
||||||
|
className={className}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
|
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
|
||||||
if (isToolResponse && message.type === "tool_response") {
|
if (isToolResponse && message.type === "tool_response") {
|
||||||
return (
|
return (
|
||||||
@@ -358,7 +394,7 @@ export function ChatMessage({
|
|||||||
<ArrowsClockwiseIcon className="size-4 text-zinc-600" />
|
<ArrowsClockwiseIcon className="size-4 text-zinc-600" />
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
{!isUser && isFinalMessage && isLongResponse(displayContent) && (
|
{!isUser && isFinalMessage && !isStreaming && (
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ export type ChatMessageData =
|
|||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
type: "agent_carousel";
|
type: "agent_carousel";
|
||||||
|
toolId: string;
|
||||||
toolName: string;
|
toolName: string;
|
||||||
agents: Array<{
|
agents: Array<{
|
||||||
id: string;
|
id: string;
|
||||||
@@ -74,6 +75,7 @@ export type ChatMessageData =
|
|||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
type: "execution_started";
|
type: "execution_started";
|
||||||
|
toolId: string;
|
||||||
toolName: string;
|
toolName: string;
|
||||||
executionId: string;
|
executionId: string;
|
||||||
agentName?: string;
|
agentName?: string;
|
||||||
@@ -103,6 +105,29 @@ export type ChatMessageData =
|
|||||||
message: string;
|
message: string;
|
||||||
sessionId: string;
|
sessionId: string;
|
||||||
timestamp?: string | Date;
|
timestamp?: string | Date;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: "operation_started";
|
||||||
|
toolName: string;
|
||||||
|
toolId: string;
|
||||||
|
operationId: string;
|
||||||
|
message: string;
|
||||||
|
timestamp?: string | Date;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: "operation_pending";
|
||||||
|
toolName: string;
|
||||||
|
toolId: string;
|
||||||
|
operationId: string;
|
||||||
|
message: string;
|
||||||
|
timestamp?: string | Date;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: "operation_in_progress";
|
||||||
|
toolName: string;
|
||||||
|
toolCallId: string;
|
||||||
|
message: string;
|
||||||
|
timestamp?: string | Date;
|
||||||
};
|
};
|
||||||
|
|
||||||
export function useChatMessage(message: ChatMessageData) {
|
export function useChatMessage(message: ChatMessageData) {
|
||||||
@@ -124,5 +149,8 @@ export function useChatMessage(message: ChatMessageData) {
|
|||||||
isExecutionStarted: message.type === "execution_started",
|
isExecutionStarted: message.type === "execution_started",
|
||||||
isInputsNeeded: message.type === "inputs_needed",
|
isInputsNeeded: message.type === "inputs_needed",
|
||||||
isClarificationNeeded: message.type === "clarification_needed",
|
isClarificationNeeded: message.type === "clarification_needed",
|
||||||
|
isOperationStarted: message.type === "operation_started",
|
||||||
|
isOperationPending: message.type === "operation_pending",
|
||||||
|
isOperationInProgress: message.type === "operation_in_progress",
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ export function ClarificationQuestionsWidget({
|
|||||||
className,
|
className,
|
||||||
}: Props) {
|
}: Props) {
|
||||||
const [answers, setAnswers] = useState<Record<string, string>>({});
|
const [answers, setAnswers] = useState<Record<string, string>>({});
|
||||||
|
const [isSubmitted, setIsSubmitted] = useState(false);
|
||||||
|
|
||||||
function handleAnswerChange(keyword: string, value: string) {
|
function handleAnswerChange(keyword: string, value: string) {
|
||||||
setAnswers((prev) => ({ ...prev, [keyword]: value }));
|
setAnswers((prev) => ({ ...prev, [keyword]: value }));
|
||||||
@@ -41,11 +42,42 @@ export function ClarificationQuestionsWidget({
|
|||||||
if (!allAnswered) {
|
if (!allAnswered) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
setIsSubmitted(true);
|
||||||
onSubmitAnswers(answers);
|
onSubmitAnswers(answers);
|
||||||
}
|
}
|
||||||
|
|
||||||
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
|
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
|
||||||
|
|
||||||
|
// Show submitted state after answers are submitted
|
||||||
|
if (isSubmitted) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"group relative flex w-full justify-start gap-3 px-4 py-3",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="flex w-full max-w-3xl gap-3">
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-green-500">
|
||||||
|
<CheckCircleIcon className="h-4 w-4 text-white" weight="bold" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex min-w-0 flex-1 flex-col">
|
||||||
|
<Card className="p-4">
|
||||||
|
<Text variant="h4" className="mb-1 text-slate-900">
|
||||||
|
Answers submitted
|
||||||
|
</Text>
|
||||||
|
<Text variant="small" className="text-slate-600">
|
||||||
|
Processing your responses...
|
||||||
|
</Text>
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
|
|||||||
@@ -0,0 +1,109 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { Card } from "@/components/atoms/Card/Card";
|
||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { CircleNotch, CheckCircle, XCircle } from "@phosphor-icons/react";
|
||||||
|
|
||||||
|
type OperationStatus =
|
||||||
|
| "pending"
|
||||||
|
| "started"
|
||||||
|
| "in_progress"
|
||||||
|
| "completed"
|
||||||
|
| "error";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
status: OperationStatus;
|
||||||
|
message: string;
|
||||||
|
toolName?: string;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getOperationTitle(toolName?: string): string {
|
||||||
|
if (!toolName) return "Operation";
|
||||||
|
// Convert tool name to human-readable format
|
||||||
|
// e.g., "create_agent" -> "Creating Agent", "edit_agent" -> "Editing Agent"
|
||||||
|
if (toolName === "create_agent") return "Creating Agent";
|
||||||
|
if (toolName === "edit_agent") return "Editing Agent";
|
||||||
|
// Default: capitalize and format tool name
|
||||||
|
return toolName
|
||||||
|
.split("_")
|
||||||
|
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
||||||
|
.join(" ");
|
||||||
|
}
|
||||||
|
|
||||||
|
export function PendingOperationWidget({
|
||||||
|
status,
|
||||||
|
message,
|
||||||
|
toolName,
|
||||||
|
className,
|
||||||
|
}: Props) {
|
||||||
|
const isPending =
|
||||||
|
status === "pending" || status === "started" || status === "in_progress";
|
||||||
|
const isCompleted = status === "completed";
|
||||||
|
const isError = status === "error";
|
||||||
|
|
||||||
|
const operationTitle = getOperationTitle(toolName);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"group relative flex w-full justify-start gap-3 px-4 py-3",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="flex w-full max-w-3xl gap-3">
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex h-7 w-7 items-center justify-center rounded-lg",
|
||||||
|
isPending && "bg-blue-500",
|
||||||
|
isCompleted && "bg-green-500",
|
||||||
|
isError && "bg-red-500",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isPending && (
|
||||||
|
<CircleNotch
|
||||||
|
className="h-4 w-4 animate-spin text-white"
|
||||||
|
weight="bold"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{isCompleted && (
|
||||||
|
<CheckCircle className="h-4 w-4 text-white" weight="bold" />
|
||||||
|
)}
|
||||||
|
{isError && (
|
||||||
|
<XCircle className="h-4 w-4 text-white" weight="bold" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex min-w-0 flex-1 flex-col">
|
||||||
|
<Card className="space-y-2 p-4">
|
||||||
|
<div>
|
||||||
|
<Text variant="h4" className="mb-1 text-slate-900">
|
||||||
|
{isPending && operationTitle}
|
||||||
|
{isCompleted && `${operationTitle} Complete`}
|
||||||
|
{isError && `${operationTitle} Failed`}
|
||||||
|
</Text>
|
||||||
|
<Text variant="small" className="text-slate-600">
|
||||||
|
{message}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{isPending && (
|
||||||
|
<Text variant="small" className="italic text-slate-500">
|
||||||
|
Check your library in a few minutes.
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{toolName && (
|
||||||
|
<Text variant="small" className="text-slate-400">
|
||||||
|
Tool: {toolName}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ export function UserChatBubble({ children, className }: UserChatBubbleProps) {
|
|||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
"group relative min-w-20 overflow-hidden rounded-xl bg-purple-100 px-3 text-right text-[1rem] leading-relaxed transition-all duration-500 ease-in-out",
|
"group relative min-w-20 overflow-hidden rounded-xl bg-purple-100 px-3 text-left text-[1rem] leading-relaxed transition-all duration-500 ease-in-out",
|
||||||
className,
|
className,
|
||||||
)}
|
)}
|
||||||
style={{
|
style={{
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
|||||||
claimSession,
|
claimSession,
|
||||||
clearSession: clearSessionBase,
|
clearSession: clearSessionBase,
|
||||||
loadSession,
|
loadSession,
|
||||||
|
startPollingForOperation,
|
||||||
} = useChatSession({
|
} = useChatSession({
|
||||||
urlSessionId,
|
urlSessionId,
|
||||||
autoCreate: false,
|
autoCreate: false,
|
||||||
@@ -94,5 +95,6 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
|||||||
loadSession,
|
loadSession,
|
||||||
sessionId: sessionIdFromHook,
|
sessionId: sessionIdFromHook,
|
||||||
showLoader,
|
showLoader,
|
||||||
|
startPollingForOperation,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ export function useChatSession({
|
|||||||
query: {
|
query: {
|
||||||
enabled: !!sessionId,
|
enabled: !!sessionId,
|
||||||
select: okData,
|
select: okData,
|
||||||
|
staleTime: 0,
|
||||||
retry: shouldRetrySessionLoad,
|
retry: shouldRetrySessionLoad,
|
||||||
retryDelay: getSessionRetryDelay,
|
retryDelay: getSessionRetryDelay,
|
||||||
},
|
},
|
||||||
@@ -102,15 +103,123 @@ export function useChatSession({
|
|||||||
}
|
}
|
||||||
}, [createError, loadError]);
|
}, [createError, loadError]);
|
||||||
|
|
||||||
|
// Track if we should be polling (set by external callers when they receive operation_started via SSE)
|
||||||
|
const [forcePolling, setForcePolling] = useState(false);
|
||||||
|
// Track if we've seen server acknowledge the pending operation (to avoid clearing forcePolling prematurely)
|
||||||
|
const hasSeenServerPendingRef = useRef(false);
|
||||||
|
|
||||||
|
// Check if there are any pending operations in the messages
|
||||||
|
// Must check all operation types: operation_pending, operation_started, operation_in_progress
|
||||||
|
const hasPendingOperationsFromServer = useMemo(() => {
|
||||||
|
if (!messages || messages.length === 0) return false;
|
||||||
|
const pendingTypes = new Set([
|
||||||
|
"operation_pending",
|
||||||
|
"operation_in_progress",
|
||||||
|
"operation_started",
|
||||||
|
]);
|
||||||
|
return messages.some((msg) => {
|
||||||
|
if (msg.role !== "tool" || !msg.content) return false;
|
||||||
|
try {
|
||||||
|
const content =
|
||||||
|
typeof msg.content === "string"
|
||||||
|
? JSON.parse(msg.content)
|
||||||
|
: msg.content;
|
||||||
|
return pendingTypes.has(content?.type);
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}, [messages]);
|
||||||
|
|
||||||
|
// Track when server has acknowledged the pending operation
|
||||||
|
useEffect(() => {
|
||||||
|
if (hasPendingOperationsFromServer) {
|
||||||
|
hasSeenServerPendingRef.current = true;
|
||||||
|
}
|
||||||
|
}, [hasPendingOperationsFromServer]);
|
||||||
|
|
||||||
|
// Combined: poll if server has pending ops OR if we received operation_started via SSE
|
||||||
|
const hasPendingOperations = hasPendingOperationsFromServer || forcePolling;
|
||||||
|
|
||||||
|
// Clear forcePolling only after server has acknowledged AND completed the operation
|
||||||
|
useEffect(() => {
|
||||||
|
if (
|
||||||
|
forcePolling &&
|
||||||
|
!hasPendingOperationsFromServer &&
|
||||||
|
hasSeenServerPendingRef.current
|
||||||
|
) {
|
||||||
|
// Server acknowledged the operation and it's now complete
|
||||||
|
setForcePolling(false);
|
||||||
|
hasSeenServerPendingRef.current = false;
|
||||||
|
}
|
||||||
|
}, [forcePolling, hasPendingOperationsFromServer]);
|
||||||
|
|
||||||
|
// Function to trigger polling (called when operation_started is received via SSE)
|
||||||
|
function startPollingForOperation() {
|
||||||
|
setForcePolling(true);
|
||||||
|
hasSeenServerPendingRef.current = false; // Reset for new operation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh sessions list when a pending operation completes
|
||||||
|
// (hasPendingOperations transitions from true to false)
|
||||||
|
const prevHasPendingOperationsRef = useRef(hasPendingOperations);
|
||||||
useEffect(
|
useEffect(
|
||||||
function refreshSessionsListOnLoad() {
|
function refreshSessionsListOnOperationComplete() {
|
||||||
if (sessionId && sessionData && !isLoadingSession) {
|
const wasHasPending = prevHasPendingOperationsRef.current;
|
||||||
|
prevHasPendingOperationsRef.current = hasPendingOperations;
|
||||||
|
|
||||||
|
// Only invalidate when transitioning from pending to not pending
|
||||||
|
if (wasHasPending && !hasPendingOperations && sessionId) {
|
||||||
queryClient.invalidateQueries({
|
queryClient.invalidateQueries({
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[sessionId, sessionData, isLoadingSession, queryClient],
|
[hasPendingOperations, sessionId, queryClient],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Poll for updates when there are pending operations
|
||||||
|
// Backoff: 2s, 4s, 6s, 8s, 10s, ... up to 30s max
|
||||||
|
const pollAttemptRef = useRef(0);
|
||||||
|
const hasPendingOperationsRef = useRef(hasPendingOperations);
|
||||||
|
hasPendingOperationsRef.current = hasPendingOperations;
|
||||||
|
|
||||||
|
useEffect(
|
||||||
|
function pollForPendingOperations() {
|
||||||
|
if (!sessionId || !hasPendingOperations) {
|
||||||
|
pollAttemptRef.current = 0;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cancelled = false;
|
||||||
|
let timeoutId: ReturnType<typeof setTimeout> | null = null;
|
||||||
|
|
||||||
|
function schedule() {
|
||||||
|
// 2s, 4s, 6s, 8s, 10s, ... 30s (max)
|
||||||
|
const delay = Math.min((pollAttemptRef.current + 1) * 2000, 30000);
|
||||||
|
timeoutId = setTimeout(async () => {
|
||||||
|
if (cancelled) return;
|
||||||
|
pollAttemptRef.current += 1;
|
||||||
|
try {
|
||||||
|
await refetch();
|
||||||
|
} catch (err) {
|
||||||
|
console.error("[useChatSession] Poll failed:", err);
|
||||||
|
} finally {
|
||||||
|
if (!cancelled && hasPendingOperationsRef.current) {
|
||||||
|
schedule();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, delay);
|
||||||
|
}
|
||||||
|
|
||||||
|
schedule();
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
cancelled = true;
|
||||||
|
if (timeoutId) clearTimeout(timeoutId);
|
||||||
|
};
|
||||||
|
},
|
||||||
|
[sessionId, hasPendingOperations, refetch],
|
||||||
);
|
);
|
||||||
|
|
||||||
async function createSession() {
|
async function createSession() {
|
||||||
@@ -239,11 +348,13 @@ export function useChatSession({
|
|||||||
isCreating,
|
isCreating,
|
||||||
error,
|
error,
|
||||||
isSessionNotFound: isNotFoundError(loadError),
|
isSessionNotFound: isNotFoundError(loadError),
|
||||||
|
hasPendingOperations,
|
||||||
createSession,
|
createSession,
|
||||||
loadSession,
|
loadSession,
|
||||||
refreshSession,
|
refreshSession,
|
||||||
claimSession,
|
claimSession,
|
||||||
clearSession,
|
clearSession,
|
||||||
|
startPollingForOperation,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import { environment } from "../environment";
|
|||||||
|
|
||||||
export enum SessionKey {
|
export enum SessionKey {
|
||||||
CHAT_SENT_INITIAL_PROMPTS = "chat_sent_initial_prompts",
|
CHAT_SENT_INITIAL_PROMPTS = "chat_sent_initial_prompts",
|
||||||
|
CHAT_INITIAL_PROMPTS = "chat_initial_prompts",
|
||||||
}
|
}
|
||||||
|
|
||||||
function get(key: SessionKey) {
|
function get(key: SessionKey) {
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
# Video editing blocks
|
|
||||||
Reference in New Issue
Block a user