mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-23 14:08:02 -05:00
Compare commits
47 Commits
swiftyos/r
...
feat/sensi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e95a9273e6 | ||
|
|
abd4a2b0b3 | ||
|
|
3a96455db5 | ||
|
|
4c68b87e00 | ||
|
|
f4296f8764 | ||
|
|
c0547bb1ed | ||
|
|
a396155d63 | ||
|
|
4c47b4df76 | ||
|
|
c46e563e6a | ||
|
|
a80f452ffe | ||
|
|
fd970c800c | ||
|
|
5fc1ec0ece | ||
|
|
9be3ec58ae | ||
|
|
e6ca904326 | ||
|
|
dbb56fa7aa | ||
|
|
0111820f61 | ||
|
|
1a1b1aa26d | ||
|
|
614ed8cf82 | ||
|
|
edd4c96aa6 | ||
|
|
cd231e2d69 | ||
|
|
399c472623 | ||
|
|
554e2beddf | ||
|
|
29fdda3fa8 | ||
|
|
67e6a8841c | ||
|
|
aea97db485 | ||
|
|
71a6969bbd | ||
|
|
e4c3f9995b | ||
|
|
3b58684abc | ||
|
|
e8d44a62fd | ||
|
|
be024da2a8 | ||
|
|
0df917e243 | ||
|
|
8688805a8c | ||
|
|
9bdda7dab0 | ||
|
|
7d377aabaa | ||
|
|
dfd7c64068 | ||
|
|
02089bc047 | ||
|
|
bed7b356bb | ||
|
|
4efc0ff502 | ||
|
|
4ad0528257 | ||
|
|
2f440ee80a | ||
|
|
2a55923ec0 | ||
|
|
ad50f57a2b | ||
|
|
aebd961ef5 | ||
|
|
bcccaa16cc | ||
|
|
d5ddc41b18 | ||
|
|
95eab5b7eb | ||
|
|
832d6e1696 |
38
.github/workflows/platform-frontend-ci.yml
vendored
38
.github/workflows/platform-frontend-ci.yml
vendored
@@ -128,7 +128,7 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
test:
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
@@ -258,39 +258,3 @@ jobs:
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Generate API client
|
||||
run: pnpm generate:api
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: pnpm test:unit
|
||||
|
||||
26
AGENTS.md
26
AGENTS.md
@@ -16,32 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
|
||||
- Format Python code with `poetry run format`.
|
||||
- Format frontend code using `pnpm format`.
|
||||
|
||||
|
||||
## Frontend guidelines:
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Add `usePageName.ts` hook for logic
|
||||
- Put sub-components in local `components/` folder
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||
- Use function declarations for components, arrow functions only for callbacks
|
||||
- No barrel files or `index.ts` re-exports
|
||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
||||
- Avoid comments at all times unless the code is very complex
|
||||
|
||||
## Testing
|
||||
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
|
||||
@@ -201,7 +201,7 @@ If you get any pushback or hit complex block conditions check the new_blocks gui
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
### Frontend guidelines:
|
||||
**Frontend feature development:**
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
@@ -217,14 +217,6 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||
- Use function declarations for components, arrow functions only for callbacks
|
||||
- No barrel files or `index.ts` re-exports
|
||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
||||
- Avoid comments at all times unless the code is very complex
|
||||
|
||||
### Security Implementation
|
||||
|
||||
|
||||
@@ -290,11 +290,6 @@ async def _cache_session(session: ChatSession) -> None:
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def cache_chat_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session without persisting to the database."""
|
||||
await _cache_session(session)
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
|
||||
@@ -172,12 +172,12 @@ async def get_session(
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
@@ -222,8 +222,6 @@ async def stream_chat_post(
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
@@ -232,26 +230,7 @@ async def stream_chat_post(
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@@ -296,8 +275,6 @@ async def stream_chat_get(
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
@@ -305,26 +282,7 @@ async def stream_chat_get(
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@@ -1,20 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from asyncio import CancelledError
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
import orjson
|
||||
from langfuse import get_client
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
APIStatusError,
|
||||
PermissionDeniedError,
|
||||
RateLimitError,
|
||||
)
|
||||
from langfuse import get_client, propagate_attributes
|
||||
from langfuse.openai import openai # type: ignore
|
||||
from openai import APIConnectionError, APIError, APIStatusError, RateLimitError
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
from backend.data.understanding import (
|
||||
@@ -29,7 +21,6 @@ from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
cache_chat_session,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
@@ -276,301 +267,257 @@ async def stream_chat_completion(
|
||||
# Build system prompt with business understanding
|
||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||
|
||||
# Initialize variables for streaming
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
)
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_saved_assistant_message = False
|
||||
has_appended_streaming_message = False
|
||||
last_cache_time = 0.0
|
||||
last_cache_content_len = 0
|
||||
# Create Langfuse trace for this LLM call (each call gets its own trace, grouped by session_id)
|
||||
# Using v3 SDK: start_observation creates a root span, update_trace sets trace-level attributes
|
||||
input = message
|
||||
if not message and tool_call_response:
|
||||
input = tool_call_response
|
||||
|
||||
has_yielded_end = False
|
||||
has_yielded_error = False
|
||||
has_done_tool_call = False
|
||||
has_received_text = False
|
||||
text_streaming_ended = False
|
||||
tool_response_messages: list[ChatMessage] = []
|
||||
should_retry = False
|
||||
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
import uuid as uuid_module
|
||||
|
||||
message_id = str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
# Yield message start
|
||||
yield StreamStart(messageId=message_id)
|
||||
|
||||
try:
|
||||
async for chunk in _stream_chat_chunks(
|
||||
session=session,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
text_block_id=text_block_id,
|
||||
langfuse = get_client()
|
||||
with langfuse.start_as_current_observation(
|
||||
as_type="span",
|
||||
name="user-copilot-request",
|
||||
input=input,
|
||||
) as span:
|
||||
with propagate_attributes(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tags=["copilot"],
|
||||
metadata={
|
||||
"users_information": format_understanding_for_prompt(understanding)[
|
||||
:200
|
||||
] # langfuse only accepts upto to 200 chars
|
||||
},
|
||||
):
|
||||
|
||||
if isinstance(chunk, StreamTextStart):
|
||||
# Emit text-start before first text delta
|
||||
if not has_received_text:
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextDelta):
|
||||
delta = chunk.delta or ""
|
||||
assert assistant_response.content is not None
|
||||
assistant_response.content += delta
|
||||
has_received_text = True
|
||||
if not has_appended_streaming_message:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_streaming_message = True
|
||||
current_time = time.monotonic()
|
||||
content_len = len(assistant_response.content)
|
||||
if (
|
||||
current_time - last_cache_time >= 1.0
|
||||
and content_len > last_cache_content_len
|
||||
):
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to cache partial session {session.session_id}: {e}"
|
||||
)
|
||||
last_cache_time = current_time
|
||||
last_cache_content_len = content_len
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextEnd):
|
||||
# Emit text-end after text completes
|
||||
if has_received_text and not text_streaming_ended:
|
||||
text_streaming_ended = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolInputStart):
|
||||
# Emit text-end before first tool call, but only if we've received text
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolInputAvailable):
|
||||
# Accumulate tool calls in OpenAI format
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": chunk.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.toolName,
|
||||
"arguments": orjson.dumps(chunk.input).decode("utf-8"),
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(chunk, StreamToolOutputAvailable):
|
||||
result_content = (
|
||||
chunk.output
|
||||
if isinstance(chunk.output, str)
|
||||
else orjson.dumps(chunk.output).decode("utf-8")
|
||||
)
|
||||
tool_response_messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=chunk.toolCallId,
|
||||
)
|
||||
)
|
||||
has_done_tool_call = True
|
||||
# Track if any tool execution failed
|
||||
if not chunk.success:
|
||||
logger.warning(
|
||||
f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed"
|
||||
)
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
if not has_done_tool_call:
|
||||
# Emit text-end before finish if we received text but haven't closed it
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
|
||||
# Save assistant message before yielding finish to ensure it's persisted
|
||||
# even if client disconnects immediately after receiving StreamFinish
|
||||
if not has_saved_assistant_message:
|
||||
messages_to_save_early: list[ChatMessage] = []
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if not has_appended_streaming_message and (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
):
|
||||
messages_to_save_early.append(assistant_response)
|
||||
messages_to_save_early.extend(tool_response_messages)
|
||||
|
||||
if messages_to_save_early:
|
||||
session.messages.extend(messages_to_save_early)
|
||||
logger.info(
|
||||
f"Saving assistant message before StreamFinish: "
|
||||
f"content_len={len(assistant_response.content or '')}, "
|
||||
f"tool_calls={len(assistant_response.tool_calls or [])}, "
|
||||
f"tool_responses={len(tool_response_messages)}"
|
||||
)
|
||||
if messages_to_save_early or has_appended_streaming_message:
|
||||
await upsert_chat_session(session)
|
||||
has_saved_assistant_message = True
|
||||
|
||||
has_yielded_end = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamError):
|
||||
has_yielded_error = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamUsage):
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=chunk.promptTokens,
|
||||
completion_tokens=chunk.completionTokens,
|
||||
total_tokens=chunk.totalTokens,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||
|
||||
except CancelledError:
|
||||
if not has_saved_assistant_message:
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if assistant_response.content:
|
||||
assistant_response.content = (
|
||||
f"{assistant_response.content}\n\n[interrupted]"
|
||||
)
|
||||
else:
|
||||
assistant_response.content = "[interrupted]"
|
||||
if not has_appended_streaming_message:
|
||||
session.messages.append(assistant_response)
|
||||
if tool_response_messages:
|
||||
session.messages.extend(tool_response_messages)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to save interrupted session {session.session_id}: {e}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
||||
|
||||
# Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
|
||||
is_retryable = isinstance(e, (orjson.JSONDecodeError, KeyError, TypeError))
|
||||
|
||||
if is_retryable and retry_count < config.max_retries:
|
||||
logger.info(
|
||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||
# Initialize variables that will be used in finally block (must be defined before try)
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
)
|
||||
should_retry = True
|
||||
else:
|
||||
# Non-retryable error or max retries exceeded
|
||||
# Save any partial progress before reporting error
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
|
||||
has_yielded_end = False
|
||||
has_yielded_error = False
|
||||
has_done_tool_call = False
|
||||
has_received_text = False
|
||||
text_streaming_ended = False
|
||||
tool_response_messages: list[ChatMessage] = []
|
||||
should_retry = False
|
||||
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
import uuid as uuid_module
|
||||
|
||||
message_id = str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
# Yield message start
|
||||
yield StreamStart(messageId=message_id)
|
||||
|
||||
try:
|
||||
async for chunk in _stream_chat_chunks(
|
||||
session=session,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
text_block_id=text_block_id,
|
||||
):
|
||||
|
||||
if isinstance(chunk, StreamTextStart):
|
||||
# Emit text-start before first text delta
|
||||
if not has_received_text:
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextDelta):
|
||||
delta = chunk.delta or ""
|
||||
assert assistant_response.content is not None
|
||||
assistant_response.content += delta
|
||||
has_received_text = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextEnd):
|
||||
# Emit text-end after text completes
|
||||
if has_received_text and not text_streaming_ended:
|
||||
text_streaming_ended = True
|
||||
if assistant_response.content:
|
||||
logger.warn(
|
||||
f"StreamTextEnd: Attempting to set output {assistant_response.content}"
|
||||
)
|
||||
span.update_trace(output=assistant_response.content)
|
||||
span.update(output=assistant_response.content)
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolInputStart):
|
||||
# Emit text-end before first tool call, but only if we've received text
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolInputAvailable):
|
||||
# Accumulate tool calls in OpenAI format
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": chunk.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.toolName,
|
||||
"arguments": orjson.dumps(chunk.input).decode(
|
||||
"utf-8"
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(chunk, StreamToolOutputAvailable):
|
||||
result_content = (
|
||||
chunk.output
|
||||
if isinstance(chunk.output, str)
|
||||
else orjson.dumps(chunk.output).decode("utf-8")
|
||||
)
|
||||
tool_response_messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=chunk.toolCallId,
|
||||
)
|
||||
)
|
||||
has_done_tool_call = True
|
||||
# Track if any tool execution failed
|
||||
if not chunk.success:
|
||||
logger.warning(
|
||||
f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed"
|
||||
)
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
if not has_done_tool_call:
|
||||
# Emit text-end before finish if we received text but haven't closed it
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
has_yielded_end = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamError):
|
||||
has_yielded_error = True
|
||||
elif isinstance(chunk, StreamUsage):
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=chunk.promptTokens,
|
||||
completion_tokens=chunk.completionTokens,
|
||||
total_tokens=chunk.totalTokens,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown chunk type: {type(chunk)}", exc_info=True
|
||||
)
|
||||
if assistant_response.content:
|
||||
langfuse.update_current_trace(output=assistant_response.content)
|
||||
langfuse.update_current_span(output=assistant_response.content)
|
||||
elif tool_response_messages:
|
||||
langfuse.update_current_trace(output=str(tool_response_messages))
|
||||
langfuse.update_current_span(output=str(tool_response_messages))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
||||
|
||||
# Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
|
||||
is_retryable = isinstance(
|
||||
e, (orjson.JSONDecodeError, KeyError, TypeError)
|
||||
)
|
||||
|
||||
if is_retryable and retry_count < config.max_retries:
|
||||
logger.info(
|
||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||
)
|
||||
should_retry = True
|
||||
else:
|
||||
# Non-retryable error or max retries exceeded
|
||||
# Save any partial progress before reporting error
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message if it has content or tool calls
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
|
||||
session.messages.extend(messages_to_save)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
if not has_yielded_error:
|
||||
error_message = str(e)
|
||||
if not is_retryable:
|
||||
error_message = f"Non-retryable error: {error_message}"
|
||||
elif retry_count >= config.max_retries:
|
||||
error_message = f"Max retries ({config.max_retries}) exceeded: {error_message}"
|
||||
|
||||
error_response = StreamError(errorText=error_message)
|
||||
yield error_response
|
||||
if not has_yielded_end:
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Handle retry outside of exception handler to avoid nesting
|
||||
if should_retry and retry_count < config.max_retries:
|
||||
logger.info(
|
||||
f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
|
||||
)
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
retry_count=retry_count + 1,
|
||||
session=session,
|
||||
context=context,
|
||||
):
|
||||
yield chunk
|
||||
return # Exit after retry to avoid double-saving in finally block
|
||||
|
||||
# Normal completion path - save session and handle tool call continuation
|
||||
logger.info(
|
||||
f"Normal completion path: session={session.session_id}, "
|
||||
f"current message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message if it has content or tool calls
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if not has_appended_streaming_message and (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
):
|
||||
logger.info(
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
|
||||
if not has_saved_assistant_message:
|
||||
if messages_to_save:
|
||||
session.messages.extend(messages_to_save)
|
||||
if messages_to_save or has_appended_streaming_message:
|
||||
await upsert_chat_session(session)
|
||||
|
||||
if not has_yielded_error:
|
||||
error_message = str(e)
|
||||
if not is_retryable:
|
||||
error_message = f"Non-retryable error: {error_message}"
|
||||
elif retry_count >= config.max_retries:
|
||||
error_message = (
|
||||
f"Max retries ({config.max_retries}) exceeded: {error_message}"
|
||||
)
|
||||
|
||||
error_response = StreamError(errorText=error_message)
|
||||
yield error_response
|
||||
if not has_yielded_end:
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Handle retry outside of exception handler to avoid nesting
|
||||
if should_retry and retry_count < config.max_retries:
|
||||
logger.info(
|
||||
f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
|
||||
)
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
retry_count=retry_count + 1,
|
||||
session=session,
|
||||
context=context,
|
||||
):
|
||||
yield chunk
|
||||
return # Exit after retry to avoid double-saving in finally block
|
||||
|
||||
# Normal completion path - save session and handle tool call continuation
|
||||
# Only save if we haven't already saved when StreamFinish was received
|
||||
if not has_saved_assistant_message:
|
||||
logger.info(
|
||||
f"Normal completion path: session={session.session_id}, "
|
||||
f"current message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
logger.info(
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if not has_appended_streaming_message and (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
):
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||
f"total_to_save={len(messages_to_save)}"
|
||||
)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
logger.info(
|
||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||
f"total_to_save={len(messages_to_save)}"
|
||||
)
|
||||
|
||||
if messages_to_save:
|
||||
session.messages.extend(messages_to_save)
|
||||
logger.info(
|
||||
f"Extended session messages, new message_count={len(session.messages)}"
|
||||
)
|
||||
if messages_to_save or has_appended_streaming_message:
|
||||
await upsert_chat_session(session)
|
||||
else:
|
||||
logger.info(
|
||||
"Assistant message already saved when StreamFinish was received, "
|
||||
"skipping duplicate save"
|
||||
)
|
||||
|
||||
# If we did a tool call, stream the chat completion again to get the next response
|
||||
if has_done_tool_call:
|
||||
logger.info(
|
||||
"Tool call executed, streaming chat completion again to get assistant response"
|
||||
)
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
session=session, # Pass session object to avoid Redis refetch
|
||||
context=context,
|
||||
tool_call_response=str(tool_response_messages),
|
||||
):
|
||||
yield chunk
|
||||
# If we did a tool call, stream the chat completion again to get the next response
|
||||
if has_done_tool_call:
|
||||
logger.info(
|
||||
"Tool call executed, streaming chat completion again to get assistant response"
|
||||
)
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
session=session, # Pass session object to avoid Redis refetch
|
||||
context=context,
|
||||
tool_call_response=str(tool_response_messages),
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
# Retry configuration for OpenAI API calls
|
||||
@@ -598,12 +545,6 @@ def _is_retryable_error(error: Exception) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_region_blocked_error(error: Exception) -> bool:
|
||||
if isinstance(error, PermissionDeniedError):
|
||||
return "not available in your region" in str(error).lower()
|
||||
return "not available in your region" in str(error).lower()
|
||||
|
||||
|
||||
async def _stream_chat_chunks(
|
||||
session: ChatSession,
|
||||
tools: list[ChatCompletionToolParam],
|
||||
@@ -796,18 +737,7 @@ async def _stream_chat_chunks(
|
||||
f"Error in stream (not retrying): {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
error_code = None
|
||||
error_text = str(e)
|
||||
if _is_region_blocked_error(e):
|
||||
error_code = "MODEL_NOT_AVAILABLE_REGION"
|
||||
error_text = (
|
||||
"This model is not available in your region. "
|
||||
"Please connect via VPN and try again."
|
||||
)
|
||||
error_response = StreamError(
|
||||
errorText=error_text,
|
||||
code=error_code,
|
||||
)
|
||||
error_response = StreamError(errorText=str(e))
|
||||
yield error_response
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
@@ -59,6 +61,7 @@ and automations for the user's specific needs."""
|
||||
"""Requires authentication to store user-specific data."""
|
||||
return True
|
||||
|
||||
@observe(as_type="tool", name="add_understanding")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -5,6 +5,7 @@ import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
@@ -328,6 +329,7 @@ class AgentOutputTool(BaseTool):
|
||||
total_executions=len(available_executions) if available_executions else 1,
|
||||
)
|
||||
|
||||
@observe(as_type="tool", name="view_agent_output")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
@@ -78,6 +80,7 @@ class CreateAgentTool(BaseTool):
|
||||
"required": ["description"],
|
||||
}
|
||||
|
||||
@observe(as_type="tool", name="create_agent")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
@@ -85,6 +87,7 @@ class EditAgentTool(BaseTool):
|
||||
"required": ["agent_id", "changes"],
|
||||
}
|
||||
|
||||
@observe(as_type="tool", name="edit_agent")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
@@ -35,6 +37,7 @@ class FindAgentTool(BaseTool):
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@observe(as_type="tool", name="find_agent")
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
@@ -55,6 +56,7 @@ class FindBlockTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@observe(as_type="tool", name="find_block")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
@@ -41,6 +43,7 @@ class FindLibraryAgentTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@observe(as_type="tool", name="find_library_agent")
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
|
||||
@@ -4,6 +4,8 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
@@ -71,6 +73,7 @@ class GetDocPageTool(BaseTool):
|
||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||
return f"{DOCS_BASE_URL}/{url_path}"
|
||||
|
||||
@observe(as_type="tool", name="get_doc_page")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
@@ -154,6 +155,7 @@ class RunAgentTool(BaseTool):
|
||||
"""All operations require authentication."""
|
||||
return True
|
||||
|
||||
@observe(as_type="tool", name="run_agent")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -4,6 +4,8 @@ import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
@@ -128,6 +130,7 @@ class RunBlockTool(BaseTool):
|
||||
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
@observe(as_type="tool", name="run_block")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
@@ -87,6 +88,7 @@ class SearchDocsTool(BaseTool):
|
||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||
return f"{DOCS_BASE_URL}/{url_path}"
|
||||
|
||||
@observe(as_type="tool", name="search_docs")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -107,6 +107,13 @@ class ReviewItem(BaseModel):
|
||||
reviewed_data: SafeJsonData | None = Field(
|
||||
None, description="Optional edited data (ignored if approved=False)"
|
||||
)
|
||||
auto_approve_future: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true and this review is approved, future executions of this same "
|
||||
"block (node) will be automatically approved. This only affects approved reviews."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("reviewed_data")
|
||||
@classmethod
|
||||
@@ -174,6 +181,9 @@ class ReviewRequest(BaseModel):
|
||||
This request must include ALL pending reviews for a graph execution.
|
||||
Each review will be either approved (with optional data modifications)
|
||||
or rejected (data ignored). The execution will resume only after ALL reviews are processed.
|
||||
|
||||
Each review item can individually specify whether to auto-approve future executions
|
||||
of the same block via the `auto_approve_future` field on ReviewItem.
|
||||
"""
|
||||
|
||||
reviews: List[ReviewItem] = Field(
|
||||
|
||||
@@ -8,6 +8,12 @@ from prisma.enums import ReviewStatus
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.rest_api import handle_internal_http_error
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
NodeExecutionResult,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
|
||||
from .model import PendingHumanReviewModel
|
||||
from .routes import router
|
||||
@@ -15,20 +21,24 @@ from .routes import router
|
||||
# Using a fixed timestamp for reproducible tests
|
||||
FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router, prefix="/api/review")
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create FastAPI app for testing"""
|
||||
test_app = fastapi.FastAPI()
|
||||
test_app.include_router(router, prefix="/api/review")
|
||||
test_app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
return test_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
@pytest.fixture
|
||||
def client(app, mock_jwt_user):
|
||||
"""Create test client with auth overrides"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
with fastapi.testclient.TestClient(app) as test_client:
|
||||
yield test_client
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@@ -55,6 +65,7 @@ def sample_pending_review(test_user_id: str) -> PendingHumanReviewModel:
|
||||
|
||||
|
||||
def test_get_pending_reviews_empty(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
@@ -73,6 +84,7 @@ def test_get_pending_reviews_empty(
|
||||
|
||||
|
||||
def test_get_pending_reviews_with_data(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
@@ -95,6 +107,7 @@ def test_get_pending_reviews_with_data(
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_success(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
@@ -123,6 +136,7 @@ def test_get_pending_reviews_for_execution_success(
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_not_available(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test access denied when user doesn't own the execution"""
|
||||
@@ -138,6 +152,7 @@ def test_get_pending_reviews_for_execution_not_available(
|
||||
|
||||
|
||||
def test_process_review_action_approve_success(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
@@ -145,6 +160,12 @@ def test_process_review_action_approve_success(
|
||||
"""Test successful review approval"""
|
||||
# Mock the route functions
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
@@ -173,6 +194,14 @@ def test_process_review_action_approve_success(
|
||||
)
|
||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
@@ -202,6 +231,7 @@ def test_process_review_action_approve_success(
|
||||
|
||||
|
||||
def test_process_review_action_reject_success(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
@@ -209,6 +239,20 @@ def test_process_review_action_reject_success(
|
||||
"""Test successful review rejection"""
|
||||
# Mock the route functions
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
@@ -262,6 +306,7 @@ def test_process_review_action_reject_success(
|
||||
|
||||
|
||||
def test_process_review_action_mixed_success(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
@@ -288,6 +333,12 @@ def test_process_review_action_mixed_success(
|
||||
|
||||
# Mock the route functions
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
@@ -337,6 +388,14 @@ def test_process_review_action_mixed_success(
|
||||
"test_node_456": rejected_review,
|
||||
}
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
@@ -369,6 +428,7 @@ def test_process_review_action_mixed_success(
|
||||
|
||||
|
||||
def test_process_review_action_empty_request(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -386,10 +446,45 @@ def test_process_review_action_empty_request(
|
||||
|
||||
|
||||
def test_process_review_action_review_not_found(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error when review is not found"""
|
||||
# Create a review with the nonexistent_node ID so the route can find the graph_exec_id
|
||||
nonexistent_review = PendingHumanReviewModel(
|
||||
node_exec_id="nonexistent_node",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "test"},
|
||||
instructions="Review",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=None,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=None,
|
||||
reviewed_at=None,
|
||||
)
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = nonexistent_review
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock the functions that extract graph execution ID from the request
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
@@ -422,11 +517,26 @@ def test_process_review_action_review_not_found(
|
||||
|
||||
|
||||
def test_process_review_action_partial_failure(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test handling of partial failures in review processing"""
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
@@ -456,16 +566,50 @@ def test_process_review_action_partial_failure(
|
||||
|
||||
|
||||
def test_process_review_action_invalid_node_exec_id(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test failure when trying to process review with invalid node execution ID"""
|
||||
# Create a review with the invalid-node-format ID so the route can find the graph_exec_id
|
||||
invalid_review = PendingHumanReviewModel(
|
||||
node_exec_id="invalid-node-format",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "test"},
|
||||
instructions="Review",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=None,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=None,
|
||||
reviewed_at=None,
|
||||
)
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = invalid_review
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
mock_get_reviews_for_execution.return_value = [invalid_review]
|
||||
|
||||
# Mock validation failure - this should return 400, not 500
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
@@ -490,3 +634,595 @@ def test_process_review_action_invalid_node_exec_id(
|
||||
# Should be a 400 Bad Request, not 500 Internal Server Error
|
||||
assert response.status_code == 400
|
||||
assert "Invalid node execution ID format" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_process_review_action_auto_approve_creates_auto_approval_records(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test that auto_approve_future_actions flag creates auto-approval records"""
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
# Mock process_all_reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
approved_review = PendingHumanReviewModel(
|
||||
node_exec_id="test_node_123",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "test payload"},
|
||||
instructions="Please review",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message="Approved",
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
)
|
||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||
|
||||
# Mock get_node_execution to return node_id
|
||||
mock_get_node_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_execution"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_id = "test_node_def_456"
|
||||
mock_get_node_execution.return_value = mock_node_exec
|
||||
|
||||
# Mock create_auto_approval_record
|
||||
mock_create_auto_approval = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.create_auto_approval_record"
|
||||
)
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock has_pending_reviews_for_graph_exec
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
# Mock get_graph_settings to return custom settings
|
||||
mock_get_settings = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_settings"
|
||||
)
|
||||
mock_get_settings.return_value = GraphSettings(
|
||||
human_in_the_loop_safe_mode=True,
|
||||
sensitive_action_safe_mode=True,
|
||||
)
|
||||
|
||||
# Mock get_user_by_id to prevent database access
|
||||
mock_get_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_user_by_id"
|
||||
)
|
||||
mock_user = mocker.Mock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock add_graph_execution
|
||||
mock_add_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.add_graph_execution"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"reviews": [
|
||||
{
|
||||
"node_exec_id": "test_node_123",
|
||||
"approved": True,
|
||||
"message": "Approved",
|
||||
"auto_approve_future": True,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/api/review/action", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify process_all_reviews_for_execution was called (without auto_approve param)
|
||||
mock_process_all_reviews.assert_called_once()
|
||||
|
||||
# Verify create_auto_approval_record was called for the approved review
|
||||
mock_create_auto_approval.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="test_node_def_456",
|
||||
payload={"data": "test payload"},
|
||||
)
|
||||
|
||||
# Verify get_graph_settings was called with correct parameters
|
||||
mock_get_settings.assert_called_once_with(
|
||||
user_id=test_user_id, graph_id="test_graph_789"
|
||||
)
|
||||
|
||||
# Verify add_graph_execution was called with proper ExecutionContext
|
||||
mock_add_execution.assert_called_once()
|
||||
call_kwargs = mock_add_execution.call_args.kwargs
|
||||
execution_context = call_kwargs["execution_context"]
|
||||
|
||||
assert isinstance(execution_context, ExecutionContext)
|
||||
assert execution_context.human_in_the_loop_safe_mode is True
|
||||
assert execution_context.sensitive_action_safe_mode is True
|
||||
|
||||
|
||||
def test_process_review_action_without_auto_approve_still_loads_settings(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test that execution context is created with settings even without auto-approve"""
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
# Mock process_all_reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
approved_review = PendingHumanReviewModel(
|
||||
node_exec_id="test_node_123",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "test payload"},
|
||||
instructions="Please review",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message="Approved",
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
)
|
||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||
|
||||
# Mock create_auto_approval_record - should NOT be called when auto_approve is False
|
||||
mock_create_auto_approval = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.create_auto_approval_record"
|
||||
)
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock has_pending_reviews_for_graph_exec
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
# Mock get_graph_settings with sensitive_action_safe_mode enabled
|
||||
mock_get_settings = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_settings"
|
||||
)
|
||||
mock_get_settings.return_value = GraphSettings(
|
||||
human_in_the_loop_safe_mode=False,
|
||||
sensitive_action_safe_mode=True,
|
||||
)
|
||||
|
||||
# Mock get_user_by_id to prevent database access
|
||||
mock_get_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_user_by_id"
|
||||
)
|
||||
mock_user = mocker.Mock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock add_graph_execution
|
||||
mock_add_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.add_graph_execution"
|
||||
)
|
||||
|
||||
# Request WITHOUT auto_approve_future (defaults to False)
|
||||
request_data = {
|
||||
"reviews": [
|
||||
{
|
||||
"node_exec_id": "test_node_123",
|
||||
"approved": True,
|
||||
"message": "Approved",
|
||||
# auto_approve_future defaults to False
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/api/review/action", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify process_all_reviews_for_execution was called
|
||||
mock_process_all_reviews.assert_called_once()
|
||||
|
||||
# Verify create_auto_approval_record was NOT called (auto_approve_future=False)
|
||||
mock_create_auto_approval.assert_not_called()
|
||||
|
||||
# Verify settings were loaded
|
||||
mock_get_settings.assert_called_once()
|
||||
|
||||
# Verify ExecutionContext has proper settings
|
||||
mock_add_execution.assert_called_once()
|
||||
call_kwargs = mock_add_execution.call_args.kwargs
|
||||
execution_context = call_kwargs["execution_context"]
|
||||
|
||||
assert isinstance(execution_context, ExecutionContext)
|
||||
assert execution_context.human_in_the_loop_safe_mode is False
|
||||
assert execution_context.sensitive_action_safe_mode is True
|
||||
|
||||
|
||||
def test_process_review_action_auto_approve_only_applies_to_approved_reviews(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test that auto_approve record is created only for approved reviews"""
|
||||
# Create two reviews - one approved, one rejected
|
||||
approved_review = PendingHumanReviewModel(
|
||||
node_exec_id="node_exec_approved",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "approved"},
|
||||
instructions="Review",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
)
|
||||
rejected_review = PendingHumanReviewModel(
|
||||
node_exec_id="node_exec_rejected",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "rejected"},
|
||||
instructions="Review",
|
||||
editable=True,
|
||||
status=ReviewStatus.REJECTED,
|
||||
review_message="Rejected",
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
)
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = approved_review
|
||||
|
||||
# Mock process_all_reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.return_value = {
|
||||
"node_exec_approved": approved_review,
|
||||
"node_exec_rejected": rejected_review,
|
||||
}
|
||||
|
||||
# Mock get_node_execution to return node_id (only called for approved review)
|
||||
mock_get_node_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_execution"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_id = "test_node_def_approved"
|
||||
mock_get_node_execution.return_value = mock_node_exec
|
||||
|
||||
# Mock create_auto_approval_record
|
||||
mock_create_auto_approval = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.create_auto_approval_record"
|
||||
)
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock has_pending_reviews_for_graph_exec
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
# Mock get_graph_settings
|
||||
mock_get_settings = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_settings"
|
||||
)
|
||||
mock_get_settings.return_value = GraphSettings()
|
||||
|
||||
# Mock get_user_by_id to prevent database access
|
||||
mock_get_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_user_by_id"
|
||||
)
|
||||
mock_user = mocker.Mock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock add_graph_execution
|
||||
mock_add_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.add_graph_execution"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"reviews": [
|
||||
{
|
||||
"node_exec_id": "node_exec_approved",
|
||||
"approved": True,
|
||||
"auto_approve_future": True,
|
||||
},
|
||||
{
|
||||
"node_exec_id": "node_exec_rejected",
|
||||
"approved": False,
|
||||
"auto_approve_future": True, # Should be ignored since rejected
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/api/review/action", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify process_all_reviews_for_execution was called
|
||||
mock_process_all_reviews.assert_called_once()
|
||||
|
||||
# Verify create_auto_approval_record was called ONLY for the approved review
|
||||
# (not for the rejected one)
|
||||
mock_create_auto_approval.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="test_node_def_approved",
|
||||
payload={"data": "approved"},
|
||||
)
|
||||
|
||||
# Verify get_node_execution was called only for approved review
|
||||
mock_get_node_execution.assert_called_once_with("node_exec_approved")
|
||||
|
||||
# Verify ExecutionContext was created (auto-approval is now DB-based)
|
||||
call_kwargs = mock_add_execution.call_args.kwargs
|
||||
execution_context = call_kwargs["execution_context"]
|
||||
assert isinstance(execution_context, ExecutionContext)
|
||||
|
||||
|
||||
def test_process_review_action_per_review_auto_approve_granularity(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test that auto-approval can be set per-review (granular control)"""
|
||||
# Mock get_pending_review_by_node_exec_id - return different reviews based on node_exec_id
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
|
||||
# Create a mapping of node_exec_id to review
|
||||
review_map = {
|
||||
"node_1_auto": PendingHumanReviewModel(
|
||||
node_exec_id="node_1_auto",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node1"},
|
||||
instructions="Review 1",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
),
|
||||
"node_2_manual": PendingHumanReviewModel(
|
||||
node_exec_id="node_2_manual",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node2"},
|
||||
instructions="Review 2",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
),
|
||||
"node_3_auto": PendingHumanReviewModel(
|
||||
node_exec_id="node_3_auto",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node3"},
|
||||
instructions="Review 3",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
),
|
||||
}
|
||||
|
||||
# Use side_effect to return different reviews based on node_exec_id parameter
|
||||
def mock_get_review_by_id(node_exec_id: str, _user_id: str):
|
||||
return review_map.get(node_exec_id)
|
||||
|
||||
mock_get_reviews_for_user.side_effect = mock_get_review_by_id
|
||||
|
||||
# Mock process_all_reviews - return 3 approved reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.return_value = {
|
||||
"node_1_auto": PendingHumanReviewModel(
|
||||
node_exec_id="node_1_auto",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node1"},
|
||||
instructions="Review 1",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
),
|
||||
"node_2_manual": PendingHumanReviewModel(
|
||||
node_exec_id="node_2_manual",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node2"},
|
||||
instructions="Review 2",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
),
|
||||
"node_3_auto": PendingHumanReviewModel(
|
||||
node_exec_id="node_3_auto",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node3"},
|
||||
instructions="Review 3",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
),
|
||||
}
|
||||
|
||||
# Mock get_node_execution
|
||||
mock_get_node_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_execution"
|
||||
)
|
||||
|
||||
def mock_get_node(node_exec_id: str):
|
||||
mock_node = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node.node_id = f"node_def_{node_exec_id}"
|
||||
return mock_node
|
||||
|
||||
mock_get_node_execution.side_effect = mock_get_node
|
||||
|
||||
# Mock create_auto_approval_record
|
||||
mock_create_auto_approval = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.create_auto_approval_record"
|
||||
)
|
||||
|
||||
# Mock get_graph_execution_meta
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock has_pending_reviews_for_graph_exec
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
# Mock settings and execution
|
||||
mock_get_settings = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_settings"
|
||||
)
|
||||
mock_get_settings.return_value = GraphSettings(
|
||||
human_in_the_loop_safe_mode=False, sensitive_action_safe_mode=False
|
||||
)
|
||||
|
||||
mocker.patch("backend.api.features.executions.review.routes.add_graph_execution")
|
||||
mocker.patch("backend.api.features.executions.review.routes.get_user_by_id")
|
||||
|
||||
# Request with granular auto-approval:
|
||||
# - node_1_auto: auto_approve_future=True
|
||||
# - node_2_manual: auto_approve_future=False (explicit)
|
||||
# - node_3_auto: auto_approve_future=True
|
||||
request_data = {
|
||||
"reviews": [
|
||||
{
|
||||
"node_exec_id": "node_1_auto",
|
||||
"approved": True,
|
||||
"auto_approve_future": True,
|
||||
},
|
||||
{
|
||||
"node_exec_id": "node_2_manual",
|
||||
"approved": True,
|
||||
"auto_approve_future": False, # Don't auto-approve this one
|
||||
},
|
||||
{
|
||||
"node_exec_id": "node_3_auto",
|
||||
"approved": True,
|
||||
"auto_approve_future": True,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/api/review/action", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify create_auto_approval_record was called ONLY for reviews with auto_approve_future=True
|
||||
assert mock_create_auto_approval.call_count == 2
|
||||
|
||||
# Check that it was called for node_1 and node_3, but NOT node_2
|
||||
call_args_list = [call.kwargs for call in mock_create_auto_approval.call_args_list]
|
||||
node_ids_with_auto_approval = [args["node_id"] for args in call_args_list]
|
||||
|
||||
assert "node_def_node_1_auto" in node_ids_with_auto_approval
|
||||
assert "node_def_node_3_auto" in node_ids_with_auto_approval
|
||||
assert "node_def_node_2_manual" not in node_ids_with_auto_approval
|
||||
|
||||
@@ -5,13 +5,23 @@ import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.data.execution import get_graph_execution_meta
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
get_graph_execution_meta,
|
||||
get_node_execution,
|
||||
)
|
||||
from backend.data.graph import get_graph_settings
|
||||
from backend.data.human_review import (
|
||||
create_auto_approval_record,
|
||||
get_pending_review_by_node_exec_id,
|
||||
get_pending_reviews_for_execution,
|
||||
get_pending_reviews_for_user,
|
||||
has_pending_reviews_for_graph_exec,
|
||||
process_all_reviews_for_execution,
|
||||
)
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
@@ -127,17 +137,80 @@ async def process_review_action(
|
||||
detail="At least one review must be provided",
|
||||
)
|
||||
|
||||
# Build review decisions map
|
||||
# Get graph execution ID by looking up all requested reviews
|
||||
# Use direct lookup to avoid pagination issues (can't miss reviews beyond first page)
|
||||
# Also validate that all reviews belong to the same execution
|
||||
matching_review = None
|
||||
graph_exec_ids: set[str] = set()
|
||||
|
||||
for node_exec_id in all_request_node_ids:
|
||||
review = await get_pending_review_by_node_exec_id(node_exec_id, user_id)
|
||||
if not review:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No pending review found for node execution {node_exec_id}",
|
||||
)
|
||||
if matching_review is None:
|
||||
matching_review = review
|
||||
graph_exec_ids.add(review.graph_exec_id)
|
||||
|
||||
# Ensure all reviews belong to the same execution
|
||||
if len(graph_exec_ids) > 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="All reviews in a single request must belong to the same execution.",
|
||||
)
|
||||
|
||||
# Safety check (matching_review should never be None here due to validation above)
|
||||
if matching_review is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal error: No matching review found despite validation",
|
||||
)
|
||||
|
||||
graph_exec_id = matching_review.graph_exec_id
|
||||
|
||||
# Validate execution status before processing reviews
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
# Only allow processing reviews if execution is paused for review
|
||||
# or incomplete (partial execution with some reviews already processed)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||
f"Current status: {graph_exec_meta.status}",
|
||||
)
|
||||
|
||||
# Build review decisions map and track which reviews requested auto-approval
|
||||
# Auto-approved reviews use original data (no modifications allowed)
|
||||
review_decisions = {}
|
||||
auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag
|
||||
|
||||
for review in request.reviews:
|
||||
review_status = (
|
||||
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
|
||||
)
|
||||
# If this review requested auto-approval, don't allow data modifications
|
||||
reviewed_data = None if review.auto_approve_future else review.reviewed_data
|
||||
review_decisions[review.node_exec_id] = (
|
||||
review_status,
|
||||
review.reviewed_data,
|
||||
reviewed_data,
|
||||
review.message,
|
||||
)
|
||||
auto_approve_requests[review.node_exec_id] = review.auto_approve_future
|
||||
|
||||
# Process all reviews
|
||||
updated_reviews = await process_all_reviews_for_execution(
|
||||
@@ -145,6 +218,32 @@ async def process_review_action(
|
||||
review_decisions=review_decisions,
|
||||
)
|
||||
|
||||
# Create auto-approval records for approved reviews that requested it
|
||||
# Note: Processing sequentially to avoid event loop issues in tests
|
||||
for node_exec_id, review_result in updated_reviews.items():
|
||||
# Only create auto-approval if:
|
||||
# 1. This review was approved
|
||||
# 2. The review requested auto-approval
|
||||
if review_result.status == ReviewStatus.APPROVED and auto_approve_requests.get(
|
||||
node_exec_id, False
|
||||
):
|
||||
try:
|
||||
node_exec = await get_node_execution(node_exec_id)
|
||||
if node_exec:
|
||||
await create_auto_approval_record(
|
||||
user_id=user_id,
|
||||
graph_exec_id=review_result.graph_exec_id,
|
||||
graph_id=review_result.graph_id,
|
||||
graph_version=review_result.graph_version,
|
||||
node_id=node_exec.node_id,
|
||||
payload=review_result.payload,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for {node_exec_id}",
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
# Count results
|
||||
approved_count = sum(
|
||||
1
|
||||
@@ -157,22 +256,37 @@ async def process_review_action(
|
||||
if review.status == ReviewStatus.REJECTED
|
||||
)
|
||||
|
||||
# Resume execution if we processed some reviews
|
||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
||||
if updated_reviews:
|
||||
# Get graph execution ID from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
graph_exec_id = first_review.graph_exec_id
|
||||
|
||||
# Check if any pending reviews remain for this execution
|
||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||
|
||||
if not still_has_pending:
|
||||
# Resume execution
|
||||
# Get the graph_id from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
|
||||
try:
|
||||
# Fetch user and settings to build complete execution context
|
||||
user = await get_user_by_id(user_id)
|
||||
settings = await get_graph_settings(
|
||||
user_id=user_id, graph_id=first_review.graph_id
|
||||
)
|
||||
|
||||
# Preserve user's timezone preference when resuming execution
|
||||
user_timezone = (
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
user_timezone=user_timezone,
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
graph_id=first_review.graph_id,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Resumed execution {graph_exec_id}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -6,6 +6,7 @@ Handles generation and storage of OpenAI embeddings for all content types
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
@@ -21,6 +22,11 @@ from backend.util.json import dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable to track errors logged in the current task/operation
|
||||
# This prevents spamming the same error multiple times when processing batches
|
||||
_logged_errors: contextvars.ContextVar[set[str]] = contextvars.ContextVar(
|
||||
"_logged_errors"
|
||||
)
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
@@ -31,6 +37,42 @@ EMBEDDING_DIM = 1536
|
||||
EMBEDDING_MAX_TOKENS = 8191
|
||||
|
||||
|
||||
def log_once_per_task(error_key: str, log_fn, message: str, **kwargs) -> bool:
|
||||
"""
|
||||
Log an error/warning only once per task/operation to avoid log spam.
|
||||
|
||||
Uses contextvars to track what has been logged in the current async context.
|
||||
Useful when processing batches where the same error might occur for many items.
|
||||
|
||||
Args:
|
||||
error_key: Unique identifier for this error type
|
||||
log_fn: Logger function to call (e.g., logger.error, logger.warning)
|
||||
message: Message to log
|
||||
**kwargs: Additional arguments to pass to log_fn
|
||||
|
||||
Returns:
|
||||
True if the message was logged, False if it was suppressed (already logged)
|
||||
|
||||
Example:
|
||||
log_once_per_task("missing_api_key", logger.error, "API key not set")
|
||||
"""
|
||||
# Get current logged errors, or create a new set if this is the first call in this context
|
||||
logged = _logged_errors.get(None)
|
||||
if logged is None:
|
||||
logged = set()
|
||||
_logged_errors.set(logged)
|
||||
|
||||
if error_key in logged:
|
||||
return False
|
||||
|
||||
# Log the message with a note that it will only appear once
|
||||
log_fn(f"{message} (This message will only be shown once per task.)", **kwargs)
|
||||
|
||||
# Mark as logged
|
||||
logged.add(error_key)
|
||||
return True
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
@@ -73,7 +115,11 @@ async def generate_embedding(text: str) -> list[float] | None:
|
||||
try:
|
||||
client = get_openai_client()
|
||||
if not client:
|
||||
logger.error("openai_internal_api_key not set, cannot generate embedding")
|
||||
log_once_per_task(
|
||||
"openai_api_key_missing",
|
||||
logger.error,
|
||||
"openai_internal_api_key not set, cannot generate embeddings",
|
||||
)
|
||||
return None
|
||||
|
||||
# Truncate text to token limit using tiktoken
|
||||
@@ -290,7 +336,12 @@ async def ensure_embedding(
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
log_once_per_task(
|
||||
"embedding_generation_failed",
|
||||
logger.warning,
|
||||
"Could not generate embeddings (missing API key or service unavailable). "
|
||||
"Embedding generation is disabled for this task.",
|
||||
)
|
||||
return False
|
||||
|
||||
# Store the embedding with metadata using new function
|
||||
@@ -609,8 +660,11 @@ async def ensure_content_embedding(
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for {content_type}:{content_id}"
|
||||
log_once_per_task(
|
||||
"embedding_generation_failed",
|
||||
logger.warning,
|
||||
"Could not generate embeddings (missing API key or service unavailable). "
|
||||
"Embedding generation is disabled for this task.",
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@@ -116,6 +116,7 @@ class PrintToConsoleBlock(Block):
|
||||
input_schema=PrintToConsoleBlock.Input,
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
is_sensitive_action=True,
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
|
||||
@@ -1,659 +0,0 @@
|
||||
import json
|
||||
import shlex
|
||||
import uuid
|
||||
from typing import Literal, Optional
|
||||
|
||||
from e2b import AsyncSandbox as BaseAsyncSandbox
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class ClaudeCodeExecutionError(Exception):
|
||||
"""Exception raised when Claude Code execution fails.
|
||||
|
||||
Carries the sandbox_id so it can be returned to the user for cleanup
|
||||
when dispose_sandbox=False.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, sandbox_id: str = ""):
|
||||
super().__init__(message)
|
||||
self.sandbox_id = sandbox_id
|
||||
|
||||
|
||||
# Test credentials for E2B
|
||||
TEST_E2B_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="e2b",
|
||||
api_key=SecretStr("mock-e2b-api-key"),
|
||||
title="Mock E2B API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_E2B_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_E2B_CREDENTIALS.provider,
|
||||
"id": TEST_E2B_CREDENTIALS.id,
|
||||
"type": TEST_E2B_CREDENTIALS.type,
|
||||
"title": TEST_E2B_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
# Test credentials for Anthropic
|
||||
TEST_ANTHROPIC_CREDENTIALS = APIKeyCredentials(
|
||||
id="2e568a2b-b2ea-475a-8564-9a676bf31c56",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr("mock-anthropic-api-key"),
|
||||
title="Mock Anthropic API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_ANTHROPIC_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_ANTHROPIC_CREDENTIALS.provider,
|
||||
"id": TEST_ANTHROPIC_CREDENTIALS.id,
|
||||
"type": TEST_ANTHROPIC_CREDENTIALS.type,
|
||||
"title": TEST_ANTHROPIC_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class ClaudeCodeBlock(Block):
|
||||
"""
|
||||
Execute tasks using Claude Code (Anthropic's AI coding assistant) in an E2B sandbox.
|
||||
|
||||
Claude Code can create files, install tools, run commands, and perform complex
|
||||
coding tasks autonomously within a secure sandbox environment.
|
||||
"""
|
||||
|
||||
# Use base template - we'll install Claude Code ourselves for latest version
|
||||
DEFAULT_TEMPLATE = "base"
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
e2b_credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description=(
|
||||
"API key for the E2B platform to create the sandbox. "
|
||||
"Get one on the [e2b website](https://e2b.dev/docs)"
|
||||
),
|
||||
)
|
||||
|
||||
anthropic_credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.ANTHROPIC], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description=(
|
||||
"API key for Anthropic to power Claude Code. "
|
||||
"Get one at [Anthropic's website](https://console.anthropic.com)"
|
||||
),
|
||||
)
|
||||
|
||||
prompt: str = SchemaField(
|
||||
description=(
|
||||
"The task or instruction for Claude Code to execute. "
|
||||
"Claude Code can create files, install packages, run commands, "
|
||||
"and perform complex coding tasks."
|
||||
),
|
||||
placeholder="Create a hello world index.html file",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
timeout: int = SchemaField(
|
||||
description=(
|
||||
"Sandbox timeout in seconds. Claude Code tasks can take "
|
||||
"a while, so set this appropriately for your task complexity. "
|
||||
"Note: This only applies when creating a new sandbox. "
|
||||
"When reconnecting to an existing sandbox via sandbox_id, "
|
||||
"the original timeout is retained."
|
||||
),
|
||||
default=300, # 5 minutes default
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
setup_commands: list[str] = SchemaField(
|
||||
description=(
|
||||
"Optional shell commands to run before executing Claude Code. "
|
||||
"Useful for installing dependencies or setting up the environment."
|
||||
),
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
working_directory: str = SchemaField(
|
||||
description="Working directory for Claude Code to operate in.",
|
||||
default="/home/user",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Session/continuation support
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID to resume a previous conversation. "
|
||||
"Leave empty for a new conversation. "
|
||||
"Use the session_id from a previous run to continue that conversation."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
sandbox_id: str = SchemaField(
|
||||
description=(
|
||||
"Sandbox ID to reconnect to an existing sandbox. "
|
||||
"Required when resuming a session (along with session_id). "
|
||||
"Use the sandbox_id from a previous run where dispose_sandbox was False."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Previous conversation history to continue from. "
|
||||
"Use this to restore context on a fresh sandbox if the previous one timed out. "
|
||||
"Pass the conversation_history output from a previous run."
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
dispose_sandbox: bool = SchemaField(
|
||||
description=(
|
||||
"Whether to dispose of the sandbox immediately after execution. "
|
||||
"Set to False if you want to continue the conversation later "
|
||||
"(you'll need both sandbox_id and session_id from the output)."
|
||||
),
|
||||
default=True,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class FileOutput(BaseModel):
|
||||
"""A file extracted from the sandbox."""
|
||||
|
||||
path: str
|
||||
relative_path: str # Path relative to working directory (for GitHub, etc.)
|
||||
name: str
|
||||
content: str
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
response: str = SchemaField(
|
||||
description="The output/response from Claude Code execution"
|
||||
)
|
||||
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
|
||||
description=(
|
||||
"List of text files created/modified by Claude Code during this execution. "
|
||||
"Each file has 'path', 'relative_path', 'name', and 'content' fields."
|
||||
)
|
||||
)
|
||||
conversation_history: str = SchemaField(
|
||||
description=(
|
||||
"Full conversation history including this turn. "
|
||||
"Pass this to conversation_history input to continue on a fresh sandbox "
|
||||
"if the previous sandbox timed out."
|
||||
)
|
||||
)
|
||||
session_id: str = SchemaField(
|
||||
description=(
|
||||
"Session ID for this conversation. "
|
||||
"Pass this back along with sandbox_id to continue the conversation."
|
||||
)
|
||||
)
|
||||
sandbox_id: Optional[str] = SchemaField(
|
||||
description=(
|
||||
"ID of the sandbox instance. "
|
||||
"Pass this back along with session_id to continue the conversation. "
|
||||
"This is None if dispose_sandbox was True (sandbox was disposed)."
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
error: str = SchemaField(description="Error message if execution failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4e34f4a5-9b89-4326-ba77-2dd6750b7194",
|
||||
description=(
|
||||
"Execute tasks using Claude Code in an E2B sandbox. "
|
||||
"Claude Code can create files, install tools, run commands, "
|
||||
"and perform complex coding tasks autonomously."
|
||||
),
|
||||
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.AI},
|
||||
input_schema=ClaudeCodeBlock.Input,
|
||||
output_schema=ClaudeCodeBlock.Output,
|
||||
test_credentials={
|
||||
"e2b_credentials": TEST_E2B_CREDENTIALS,
|
||||
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS,
|
||||
},
|
||||
test_input={
|
||||
"e2b_credentials": TEST_E2B_CREDENTIALS_INPUT,
|
||||
"anthropic_credentials": TEST_ANTHROPIC_CREDENTIALS_INPUT,
|
||||
"prompt": "Create a hello world HTML file",
|
||||
"timeout": 300,
|
||||
"setup_commands": [],
|
||||
"working_directory": "/home/user",
|
||||
"session_id": "",
|
||||
"sandbox_id": "",
|
||||
"conversation_history": "",
|
||||
"dispose_sandbox": True,
|
||||
},
|
||||
test_output=[
|
||||
("response", "Created index.html with hello world content"),
|
||||
(
|
||||
"files",
|
||||
[
|
||||
{
|
||||
"path": "/home/user/index.html",
|
||||
"relative_path": "index.html",
|
||||
"name": "index.html",
|
||||
"content": "<html>Hello World</html>",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"conversation_history",
|
||||
"User: Create a hello world HTML file\n"
|
||||
"Claude: Created index.html with hello world content",
|
||||
),
|
||||
("session_id", str),
|
||||
("sandbox_id", None), # None because dispose_sandbox=True in test_input
|
||||
],
|
||||
test_mock={
|
||||
"execute_claude_code": lambda *args, **kwargs: (
|
||||
"Created index.html with hello world content", # response
|
||||
[
|
||||
ClaudeCodeBlock.FileOutput(
|
||||
path="/home/user/index.html",
|
||||
relative_path="index.html",
|
||||
name="index.html",
|
||||
content="<html>Hello World</html>",
|
||||
)
|
||||
], # files
|
||||
"User: Create a hello world HTML file\n"
|
||||
"Claude: Created index.html with hello world content", # conversation_history
|
||||
"test-session-id", # session_id
|
||||
"sandbox_id", # sandbox_id
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_claude_code(
|
||||
self,
|
||||
e2b_api_key: str,
|
||||
anthropic_api_key: str,
|
||||
prompt: str,
|
||||
timeout: int,
|
||||
setup_commands: list[str],
|
||||
working_directory: str,
|
||||
session_id: str,
|
||||
existing_sandbox_id: str,
|
||||
conversation_history: str,
|
||||
dispose_sandbox: bool,
|
||||
) -> tuple[str, list["ClaudeCodeBlock.FileOutput"], str, str, str]:
|
||||
"""
|
||||
Execute Claude Code in an E2B sandbox.
|
||||
|
||||
Returns:
|
||||
Tuple of (response, files, conversation_history, session_id, sandbox_id)
|
||||
"""
|
||||
|
||||
# Validate that sandbox_id is provided when resuming a session
|
||||
if session_id and not existing_sandbox_id:
|
||||
raise ValueError(
|
||||
"sandbox_id is required when resuming a session with session_id. "
|
||||
"The session state is stored in the original sandbox. "
|
||||
"If the sandbox has timed out, use conversation_history instead "
|
||||
"to restore context on a fresh sandbox."
|
||||
)
|
||||
|
||||
sandbox = None
|
||||
sandbox_id = ""
|
||||
|
||||
try:
|
||||
# Either reconnect to existing sandbox or create a new one
|
||||
if existing_sandbox_id:
|
||||
# Reconnect to existing sandbox for conversation continuation
|
||||
sandbox = await BaseAsyncSandbox.connect(
|
||||
sandbox_id=existing_sandbox_id,
|
||||
api_key=e2b_api_key,
|
||||
)
|
||||
else:
|
||||
# Create new sandbox
|
||||
sandbox = await BaseAsyncSandbox.create(
|
||||
template=self.DEFAULT_TEMPLATE,
|
||||
api_key=e2b_api_key,
|
||||
timeout=timeout,
|
||||
envs={"ANTHROPIC_API_KEY": anthropic_api_key},
|
||||
)
|
||||
|
||||
# Install Claude Code from npm (ensures we get the latest version)
|
||||
install_result = await sandbox.commands.run(
|
||||
"npm install -g @anthropic-ai/claude-code@latest",
|
||||
timeout=120, # 2 min timeout for install
|
||||
)
|
||||
if install_result.exit_code != 0:
|
||||
raise Exception(
|
||||
f"Failed to install Claude Code: {install_result.stderr}"
|
||||
)
|
||||
|
||||
# Run any user-provided setup commands
|
||||
for cmd in setup_commands:
|
||||
setup_result = await sandbox.commands.run(cmd)
|
||||
if setup_result.exit_code != 0:
|
||||
raise Exception(
|
||||
f"Setup command failed: {cmd}\n"
|
||||
f"Exit code: {setup_result.exit_code}\n"
|
||||
f"Stdout: {setup_result.stdout}\n"
|
||||
f"Stderr: {setup_result.stderr}"
|
||||
)
|
||||
|
||||
# Capture sandbox_id immediately after creation/connection
|
||||
# so it's available for error recovery if dispose_sandbox=False
|
||||
sandbox_id = sandbox.sandbox_id
|
||||
|
||||
# Generate or use provided session ID
|
||||
current_session_id = session_id if session_id else str(uuid.uuid4())
|
||||
|
||||
# Build base Claude flags
|
||||
base_flags = "-p --dangerously-skip-permissions --output-format json"
|
||||
|
||||
# Add conversation history context if provided (for fresh sandbox continuation)
|
||||
history_flag = ""
|
||||
if conversation_history and not session_id:
|
||||
# Inject previous conversation as context via system prompt
|
||||
# Use consistent escaping via _escape_prompt helper
|
||||
escaped_history = self._escape_prompt(
|
||||
f"Previous conversation context: {conversation_history}"
|
||||
)
|
||||
history_flag = f" --append-system-prompt {escaped_history}"
|
||||
|
||||
# Build Claude command based on whether we're resuming or starting new
|
||||
# Use shlex.quote for working_directory and session IDs to prevent injection
|
||||
safe_working_dir = shlex.quote(working_directory)
|
||||
if session_id:
|
||||
# Resuming existing session (sandbox still alive)
|
||||
safe_session_id = shlex.quote(session_id)
|
||||
claude_command = (
|
||||
f"cd {safe_working_dir} && "
|
||||
f"echo {self._escape_prompt(prompt)} | "
|
||||
f"claude --resume {safe_session_id} {base_flags}"
|
||||
)
|
||||
else:
|
||||
# New session with specific ID
|
||||
safe_current_session_id = shlex.quote(current_session_id)
|
||||
claude_command = (
|
||||
f"cd {safe_working_dir} && "
|
||||
f"echo {self._escape_prompt(prompt)} | "
|
||||
f"claude --session-id {safe_current_session_id} {base_flags}{history_flag}"
|
||||
)
|
||||
|
||||
# Capture timestamp before running Claude Code to filter files later
|
||||
# Capture timestamp 1 second in the past to avoid race condition with file creation
|
||||
timestamp_result = await sandbox.commands.run(
|
||||
"date -u -d '1 second ago' +%Y-%m-%dT%H:%M:%S"
|
||||
)
|
||||
if timestamp_result.exit_code != 0:
|
||||
raise RuntimeError(
|
||||
f"Failed to capture timestamp: {timestamp_result.stderr}"
|
||||
)
|
||||
start_timestamp = (
|
||||
timestamp_result.stdout.strip() if timestamp_result.stdout else None
|
||||
)
|
||||
|
||||
result = await sandbox.commands.run(
|
||||
claude_command,
|
||||
timeout=0, # No command timeout - let sandbox timeout handle it
|
||||
)
|
||||
|
||||
# Check for command failure
|
||||
if result.exit_code != 0:
|
||||
error_msg = result.stderr or result.stdout or "Unknown error"
|
||||
raise Exception(
|
||||
f"Claude Code command failed with exit code {result.exit_code}:\n"
|
||||
f"{error_msg}"
|
||||
)
|
||||
|
||||
raw_output = result.stdout or ""
|
||||
|
||||
# Parse JSON output to extract response and build conversation history
|
||||
response = ""
|
||||
new_conversation_history = conversation_history or ""
|
||||
|
||||
try:
|
||||
# The JSON output contains the result
|
||||
output_data = json.loads(raw_output)
|
||||
response = output_data.get("result", raw_output)
|
||||
|
||||
# Build conversation history entry
|
||||
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||
if new_conversation_history:
|
||||
new_conversation_history = (
|
||||
f"{new_conversation_history}\n\n{turn_entry}"
|
||||
)
|
||||
else:
|
||||
new_conversation_history = turn_entry
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# If not valid JSON, use raw output
|
||||
response = raw_output
|
||||
turn_entry = f"User: {prompt}\nClaude: {response}"
|
||||
if new_conversation_history:
|
||||
new_conversation_history = (
|
||||
f"{new_conversation_history}\n\n{turn_entry}"
|
||||
)
|
||||
else:
|
||||
new_conversation_history = turn_entry
|
||||
|
||||
# Extract files created/modified during this run
|
||||
files = await self._extract_files(
|
||||
sandbox, working_directory, start_timestamp
|
||||
)
|
||||
|
||||
return (
|
||||
response,
|
||||
files,
|
||||
new_conversation_history,
|
||||
current_session_id,
|
||||
sandbox_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Wrap exception with sandbox_id so caller can access/cleanup
|
||||
# the preserved sandbox when dispose_sandbox=False
|
||||
raise ClaudeCodeExecutionError(str(e), sandbox_id) from e
|
||||
|
||||
finally:
|
||||
if dispose_sandbox and sandbox:
|
||||
await sandbox.kill()
|
||||
|
||||
async def _extract_files(
|
||||
self,
|
||||
sandbox: BaseAsyncSandbox,
|
||||
working_directory: str,
|
||||
since_timestamp: str | None = None,
|
||||
) -> list["ClaudeCodeBlock.FileOutput"]:
|
||||
"""
|
||||
Extract text files created/modified during this Claude Code execution.
|
||||
|
||||
Args:
|
||||
sandbox: The E2B sandbox instance
|
||||
working_directory: Directory to search for files
|
||||
since_timestamp: ISO timestamp - only return files modified after this time
|
||||
|
||||
Returns:
|
||||
List of FileOutput objects with path, relative_path, name, and content
|
||||
"""
|
||||
files: list[ClaudeCodeBlock.FileOutput] = []
|
||||
|
||||
# Text file extensions we can safely read as text
|
||||
text_extensions = {
|
||||
".txt",
|
||||
".md",
|
||||
".html",
|
||||
".htm",
|
||||
".css",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".json",
|
||||
".xml",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
".conf",
|
||||
".py",
|
||||
".rb",
|
||||
".php",
|
||||
".java",
|
||||
".c",
|
||||
".cpp",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rs",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".sql",
|
||||
".graphql",
|
||||
".env",
|
||||
".gitignore",
|
||||
".dockerfile",
|
||||
"Dockerfile",
|
||||
".vue",
|
||||
".svelte",
|
||||
".astro",
|
||||
".mdx",
|
||||
".rst",
|
||||
".tex",
|
||||
".csv",
|
||||
".log",
|
||||
}
|
||||
|
||||
try:
|
||||
# List files recursively using find command
|
||||
# Exclude node_modules and .git directories, but allow hidden files
|
||||
# like .env and .gitignore (they're filtered by text_extensions later)
|
||||
# Filter by timestamp to only get files created/modified during this run
|
||||
safe_working_dir = shlex.quote(working_directory)
|
||||
timestamp_filter = ""
|
||||
if since_timestamp:
|
||||
timestamp_filter = f"-newermt {shlex.quote(since_timestamp)} "
|
||||
find_result = await sandbox.commands.run(
|
||||
f"find {safe_working_dir} -type f "
|
||||
f"{timestamp_filter}"
|
||||
f"-not -path '*/node_modules/*' "
|
||||
f"-not -path '*/.git/*' "
|
||||
f"2>/dev/null"
|
||||
)
|
||||
|
||||
if find_result.stdout:
|
||||
for file_path in find_result.stdout.strip().split("\n"):
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Check if it's a text file we can read
|
||||
is_text = any(
|
||||
file_path.endswith(ext) for ext in text_extensions
|
||||
) or file_path.endswith("Dockerfile")
|
||||
|
||||
if is_text:
|
||||
try:
|
||||
content = await sandbox.files.read(file_path)
|
||||
# Handle bytes or string
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
|
||||
# Extract filename from path
|
||||
file_name = file_path.split("/")[-1]
|
||||
|
||||
# Calculate relative path by stripping working directory
|
||||
relative_path = file_path
|
||||
if file_path.startswith(working_directory):
|
||||
relative_path = file_path[len(working_directory) :]
|
||||
# Remove leading slash if present
|
||||
if relative_path.startswith("/"):
|
||||
relative_path = relative_path[1:]
|
||||
|
||||
files.append(
|
||||
ClaudeCodeBlock.FileOutput(
|
||||
path=file_path,
|
||||
relative_path=relative_path,
|
||||
name=file_name,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Skip files that can't be read
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
# If file extraction fails, return empty results
|
||||
pass
|
||||
|
||||
return files
|
||||
|
||||
def _escape_prompt(self, prompt: str) -> str:
|
||||
"""Escape the prompt for safe shell execution."""
|
||||
# Use single quotes and escape any single quotes in the prompt
|
||||
escaped = prompt.replace("'", "'\"'\"'")
|
||||
return f"'{escaped}'"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
e2b_credentials: APIKeyCredentials,
|
||||
anthropic_credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
(
|
||||
response,
|
||||
files,
|
||||
conversation_history,
|
||||
session_id,
|
||||
sandbox_id,
|
||||
) = await self.execute_claude_code(
|
||||
e2b_api_key=e2b_credentials.api_key.get_secret_value(),
|
||||
anthropic_api_key=anthropic_credentials.api_key.get_secret_value(),
|
||||
prompt=input_data.prompt,
|
||||
timeout=input_data.timeout,
|
||||
setup_commands=input_data.setup_commands,
|
||||
working_directory=input_data.working_directory,
|
||||
session_id=input_data.session_id,
|
||||
existing_sandbox_id=input_data.sandbox_id,
|
||||
conversation_history=input_data.conversation_history,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
# Always yield files (empty list if none) to match Output schema
|
||||
yield "files", [f.model_dump() for f in files]
|
||||
# Always yield conversation_history so user can restore context on fresh sandbox
|
||||
yield "conversation_history", conversation_history
|
||||
# Always yield session_id so user can continue conversation
|
||||
yield "session_id", session_id
|
||||
# Always yield sandbox_id (None if disposed) to match Output schema
|
||||
yield "sandbox_id", sandbox_id if not input_data.dispose_sandbox else None
|
||||
|
||||
except ClaudeCodeExecutionError as e:
|
||||
yield "error", str(e)
|
||||
# If sandbox was preserved (dispose_sandbox=False), yield sandbox_id
|
||||
# so user can reconnect to or clean up the orphaned sandbox
|
||||
if not input_data.dispose_sandbox and e.sandbox_id:
|
||||
yield "sandbox_id", e.sandbox_id
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -9,7 +9,7 @@ from typing import Any, Optional
|
||||
from prisma.enums import ReviewStatus
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.human_review import ReviewResult
|
||||
from backend.executor.manager import async_update_node_execution_status
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
@@ -28,6 +28,11 @@ class ReviewDecision(BaseModel):
|
||||
class HITLReviewHelper:
|
||||
"""Helper class for Human-In-The-Loop review operations."""
|
||||
|
||||
@staticmethod
|
||||
async def check_approval(**kwargs) -> Optional[ReviewResult]:
|
||||
"""Check if there's an existing approval for this node execution."""
|
||||
return await get_database_manager_async_client().check_approval(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
|
||||
"""Create or retrieve a human review from the database."""
|
||||
@@ -55,11 +60,11 @@ class HITLReviewHelper:
|
||||
async def _handle_review_request(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewResult]:
|
||||
@@ -69,11 +74,11 @@ class HITLReviewHelper:
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_id: ID of the node in the graph definition
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
@@ -83,15 +88,41 @@ class HITLReviewHelper:
|
||||
Raises:
|
||||
Exception: If review creation or status update fails
|
||||
"""
|
||||
# Skip review if safe mode is disabled - return auto-approved result
|
||||
if not execution_context.human_in_the_loop_safe_mode:
|
||||
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
|
||||
# are handled by the caller:
|
||||
# - HITL blocks check human_in_the_loop_safe_mode in their run() method
|
||||
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
|
||||
# This function only handles checking for existing approvals.
|
||||
|
||||
# Check if this node has already been approved (normal or auto-approval)
|
||||
if approval_result := await HITLReviewHelper.check_approval(
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
node_id=node_id,
|
||||
user_id=user_id,
|
||||
input_data=input_data,
|
||||
):
|
||||
logger.info(
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - "
|
||||
f"found existing approval"
|
||||
)
|
||||
# Return a new ReviewResult with the current node_exec_id but approved status
|
||||
# For auto-approvals, always use current input_data
|
||||
# For normal approvals, use approval_result.data unless it's None
|
||||
is_auto_approval = approval_result.node_exec_id != node_exec_id
|
||||
approved_data = (
|
||||
input_data
|
||||
if is_auto_approval
|
||||
else (
|
||||
approval_result.data
|
||||
if approval_result.data is not None
|
||||
else input_data
|
||||
)
|
||||
)
|
||||
return ReviewResult(
|
||||
data=input_data,
|
||||
data=approved_data,
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="Auto-approved (safe mode disabled)",
|
||||
message=approval_result.message,
|
||||
processed=True,
|
||||
node_exec_id=node_exec_id,
|
||||
)
|
||||
@@ -129,11 +160,11 @@ class HITLReviewHelper:
|
||||
async def handle_review_decision(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewDecision]:
|
||||
@@ -143,11 +174,11 @@ class HITLReviewHelper:
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_id: ID of the node in the graph definition
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
@@ -158,11 +189,11 @@ class HITLReviewHelper:
|
||||
review_result = await HITLReviewHelper._handle_review_request(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
)
|
||||
|
||||
@@ -97,6 +97,7 @@ class HumanInTheLoopBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -115,11 +116,11 @@ class HumanInTheLoopBlock(Block):
|
||||
decision = await self.handle_review_decision(
|
||||
input_data=input_data.data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=input_data.editable,
|
||||
)
|
||||
|
||||
@@ -441,6 +441,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
static_output: bool = False,
|
||||
block_type: BlockType = BlockType.STANDARD,
|
||||
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
|
||||
is_sensitive_action: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the block with the given schema.
|
||||
@@ -473,8 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.static_output = static_output
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.is_sensitive_action: bool = False
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -622,6 +623,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
input_data: BlockInput,
|
||||
*,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -648,11 +650,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
decision = await HITLReviewHelper.handle_review_decision(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from backend.api.features.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
SafeJsonData,
|
||||
)
|
||||
from backend.data.execution import get_graph_execution_meta
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -32,6 +33,125 @@ class ReviewResult(BaseModel):
|
||||
node_exec_id: str
|
||||
|
||||
|
||||
def get_auto_approve_key(graph_exec_id: str, node_id: str) -> str:
|
||||
"""Generate the special nodeExecId key for auto-approval records."""
|
||||
return f"auto_approve_{graph_exec_id}_{node_id}"
|
||||
|
||||
|
||||
async def check_approval(
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
input_data: SafeJsonData | None = None,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Check if there's an existing approval for this node execution.
|
||||
|
||||
Checks both:
|
||||
1. Normal approval by node_exec_id (previous run of the same node execution)
|
||||
2. Auto-approval by special key pattern "auto_approve_{graph_exec_id}_{node_id}"
|
||||
|
||||
Args:
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
node_id: ID of the node definition (not execution)
|
||||
user_id: ID of the user (for data isolation)
|
||||
input_data: Current input data (used for auto-approvals to avoid stale data)
|
||||
|
||||
Returns:
|
||||
ReviewResult if approval found (either normal or auto), None otherwise
|
||||
"""
|
||||
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
|
||||
|
||||
# Check for either normal approval or auto-approval in a single query
|
||||
existing_review = await PendingHumanReview.prisma().find_first(
|
||||
where={
|
||||
"OR": [
|
||||
{"nodeExecId": node_exec_id},
|
||||
{"nodeExecId": auto_approve_key},
|
||||
],
|
||||
"status": ReviewStatus.APPROVED,
|
||||
"userId": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
if existing_review:
|
||||
is_auto_approval = existing_review.nodeExecId == auto_approve_key
|
||||
logger.info(
|
||||
f"Found {'auto-' if is_auto_approval else ''}approval for node {node_id} "
|
||||
f"(exec: {node_exec_id}) in execution {graph_exec_id}"
|
||||
)
|
||||
# For auto-approvals, use current input_data to avoid replaying stale payload
|
||||
# For normal approvals, use the stored payload (which may have been edited)
|
||||
return ReviewResult(
|
||||
data=(
|
||||
input_data
|
||||
if is_auto_approval and input_data is not None
|
||||
else existing_review.payload
|
||||
),
|
||||
status=ReviewStatus.APPROVED,
|
||||
message=(
|
||||
"Auto-approved (user approved all future actions for this node)"
|
||||
if is_auto_approval
|
||||
else existing_review.reviewMessage or ""
|
||||
),
|
||||
processed=True,
|
||||
node_exec_id=existing_review.nodeExecId,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def create_auto_approval_record(
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_id: str,
|
||||
payload: SafeJsonData,
|
||||
) -> None:
|
||||
"""
|
||||
Create an auto-approval record for a node in this execution.
|
||||
|
||||
This is stored as a PendingHumanReview with a special nodeExecId pattern
|
||||
and status=APPROVED, so future executions of the same node can skip review.
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph execution doesn't belong to the user
|
||||
"""
|
||||
# Validate that the graph execution belongs to this user (defense in depth)
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise ValueError(
|
||||
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
|
||||
)
|
||||
|
||||
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
|
||||
|
||||
await PendingHumanReview.prisma().upsert(
|
||||
where={"nodeExecId": auto_approve_key},
|
||||
data={
|
||||
"create": {
|
||||
"nodeExecId": auto_approve_key,
|
||||
"userId": user_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"payload": SafeJson(payload),
|
||||
"instructions": "Auto-approval record",
|
||||
"editable": False,
|
||||
"status": ReviewStatus.APPROVED,
|
||||
"processed": True,
|
||||
"reviewedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
"update": {}, # Already exists, no update needed
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_or_create_human_review(
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
@@ -108,6 +228,33 @@ async def get_or_create_human_review(
|
||||
)
|
||||
|
||||
|
||||
async def get_pending_review_by_node_exec_id(
|
||||
node_exec_id: str, user_id: str
|
||||
) -> Optional["PendingHumanReviewModel"]:
|
||||
"""
|
||||
Get a pending review by its node execution ID.
|
||||
|
||||
Args:
|
||||
node_exec_id: The node execution ID to look up
|
||||
user_id: User ID for authorization (only returns if review belongs to this user)
|
||||
|
||||
Returns:
|
||||
The pending review if found and belongs to user, None otherwise
|
||||
"""
|
||||
review = await PendingHumanReview.prisma().find_first(
|
||||
where={
|
||||
"nodeExecId": node_exec_id,
|
||||
"userId": user_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
}
|
||||
)
|
||||
|
||||
if not review:
|
||||
return None
|
||||
|
||||
return PendingHumanReviewModel.from_db(review)
|
||||
|
||||
|
||||
async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
|
||||
"""
|
||||
Check if a graph execution has any pending reviews.
|
||||
@@ -256,3 +403,44 @@ async def update_review_processed_status(node_exec_id: str, processed: bool) ->
|
||||
await PendingHumanReview.prisma().update(
|
||||
where={"nodeExecId": node_exec_id}, data={"processed": processed}
|
||||
)
|
||||
|
||||
|
||||
async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str) -> int:
|
||||
"""
|
||||
Cancel all pending reviews for a graph execution (e.g., when execution is stopped).
|
||||
|
||||
Marks all WAITING reviews as REJECTED with a message indicating the execution was stopped.
|
||||
|
||||
Args:
|
||||
graph_exec_id: The graph execution ID
|
||||
user_id: User ID who owns the execution (for security validation)
|
||||
|
||||
Returns:
|
||||
Number of reviews cancelled
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph execution doesn't belong to the user
|
||||
"""
|
||||
# Validate user ownership before cancelling reviews
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise ValueError(
|
||||
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
|
||||
)
|
||||
|
||||
result = await PendingHumanReview.prisma().update_many(
|
||||
where={
|
||||
"graphExecId": graph_exec_id,
|
||||
"userId": user_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
},
|
||||
data={
|
||||
"status": ReviewStatus.REJECTED,
|
||||
"reviewMessage": "Execution was stopped by user",
|
||||
"processed": True,
|
||||
"reviewedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -46,8 +46,8 @@ async def test_get_or_create_human_review_new(
|
||||
sample_db_review.status = ReviewStatus.WAITING
|
||||
sample_db_review.processed = False
|
||||
|
||||
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
|
||||
result = await get_or_create_human_review(
|
||||
user_id="test-user-123",
|
||||
@@ -75,8 +75,8 @@ async def test_get_or_create_human_review_approved(
|
||||
sample_db_review.processed = False
|
||||
sample_db_review.reviewMessage = "Looks good"
|
||||
|
||||
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
|
||||
result = await get_or_create_human_review(
|
||||
user_id="test-user-123",
|
||||
|
||||
@@ -50,6 +50,8 @@ from backend.data.graph import (
|
||||
validate_graph_execution_permissions,
|
||||
)
|
||||
from backend.data.human_review import (
|
||||
cancel_pending_reviews_for_execution,
|
||||
check_approval,
|
||||
get_or_create_human_review,
|
||||
has_pending_reviews_for_graph_exec,
|
||||
update_review_processed_status,
|
||||
@@ -190,6 +192,8 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
|
||||
# Human In The Loop
|
||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||
check_approval = _(check_approval)
|
||||
get_or_create_human_review = _(get_or_create_human_review)
|
||||
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
||||
update_review_processed_status = _(update_review_processed_status)
|
||||
@@ -313,6 +317,8 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
|
||||
# Human In The Loop
|
||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||
check_approval = d.check_approval
|
||||
get_or_create_human_review = d.get_or_create_human_review
|
||||
update_review_processed_status = d.update_review_processed_status
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import human_review as human_review_db
|
||||
from backend.data import onboarding as onboarding_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.block import (
|
||||
@@ -749,9 +750,27 @@ async def stop_graph_execution(
|
||||
if graph_exec.status in [
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.REVIEW,
|
||||
]:
|
||||
# If the graph is still on the queue, we can prevent them from being executed
|
||||
# by setting the status to TERMINATED.
|
||||
# If the graph is queued/incomplete/paused for review, terminate immediately
|
||||
# No need to wait for executor since it's not actively running
|
||||
|
||||
# If graph is in REVIEW status, clean up pending reviews before terminating
|
||||
if graph_exec.status == ExecutionStatus.REVIEW:
|
||||
# Use human_review_db if Prisma connected, else database manager
|
||||
review_db = (
|
||||
human_review_db
|
||||
if prisma.is_connected()
|
||||
else get_database_manager_async_client()
|
||||
)
|
||||
# Mark all pending reviews as rejected/cancelled
|
||||
cancelled_count = await review_db.cancel_pending_reviews_for_execution(
|
||||
graph_exec_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"Cancelled {cancelled_count} pending review(s) for stopped execution {graph_exec_id}"
|
||||
)
|
||||
|
||||
graph_exec.status = ExecutionStatus.TERMINATED
|
||||
|
||||
await asyncio.gather(
|
||||
@@ -887,9 +906,28 @@ async def add_graph_execution(
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||
logger.info(f"Queueing execution {graph_exec.id}")
|
||||
|
||||
# Update execution status to QUEUED BEFORE publishing to prevent race condition
|
||||
# where two concurrent requests could both publish the same execution
|
||||
updated_exec = await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.QUEUED,
|
||||
)
|
||||
|
||||
# Verify the status update succeeded (prevents duplicate queueing in race conditions)
|
||||
# If another request already updated the status, this execution will not be QUEUED
|
||||
if not updated_exec or updated_exec.status != ExecutionStatus.QUEUED:
|
||||
logger.warning(
|
||||
f"Skipping queue publish for execution {graph_exec.id} - "
|
||||
f"status update failed or execution already queued by another request"
|
||||
)
|
||||
return graph_exec
|
||||
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
|
||||
# Publish to execution queue for executor to pick up
|
||||
# This happens AFTER status update to ensure only one request publishes
|
||||
exec_queue = await get_async_execution_queue()
|
||||
await exec_queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
@@ -897,13 +935,6 @@ async def add_graph_execution(
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
# Update execution status to QUEUED
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=graph_exec.status,
|
||||
)
|
||||
except BaseException as e:
|
||||
err = str(e) or type(e).__name__
|
||||
if not graph_exec:
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.util.mock import MockObject
|
||||
|
||||
|
||||
@@ -346,6 +347,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Mock the queue and event bus
|
||||
@@ -611,6 +613,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = []
|
||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||
|
||||
# Track what's passed to to_graph_execution_entry
|
||||
captured_kwargs = {}
|
||||
@@ -670,3 +673,232 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_in_review_status_cancels_pending_reviews(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Test that stopping an execution in REVIEW status cancels pending reviews."""
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
from backend.executor.utils import stop_graph_execution
|
||||
|
||||
user_id = "test-user"
|
||||
graph_exec_id = "test-exec-123"
|
||||
|
||||
# Mock graph execution in REVIEW status
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_graph_exec.id = graph_exec_id
|
||||
mock_graph_exec.status = ExecutionStatus.REVIEW
|
||||
|
||||
# Mock dependencies
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_queue_client = mocker.AsyncMock()
|
||||
mock_get_queue.return_value = mock_queue_client
|
||||
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_prisma.is_connected.return_value = True
|
||||
|
||||
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||
return_value=2 # 2 reviews cancelled
|
||||
)
|
||||
|
||||
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
|
||||
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
mock_get_child_executions = mocker.patch(
|
||||
"backend.executor.utils._get_child_executions"
|
||||
)
|
||||
mock_get_child_executions.return_value = [] # No children
|
||||
|
||||
# Call stop_graph_execution with timeout to allow status check
|
||||
await stop_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
wait_timeout=1.0, # Wait to allow status check
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
# Verify pending reviews were cancelled
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||
graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Verify execution status was updated to TERMINATED
|
||||
mock_execution_db.update_graph_execution_stats.assert_called_once()
|
||||
call_kwargs = mock_execution_db.update_graph_execution_stats.call_args[1]
|
||||
assert call_kwargs["graph_exec_id"] == graph_exec_id
|
||||
assert call_kwargs["status"] == ExecutionStatus.TERMINATED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_with_database_manager_when_prisma_disconnected(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Test that stop uses database manager when Prisma is not connected."""
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
from backend.executor.utils import stop_graph_execution
|
||||
|
||||
user_id = "test-user"
|
||||
graph_exec_id = "test-exec-456"
|
||||
|
||||
# Mock graph execution in REVIEW status
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_graph_exec.id = graph_exec_id
|
||||
mock_graph_exec.status = ExecutionStatus.REVIEW
|
||||
|
||||
# Mock dependencies
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_queue_client = mocker.AsyncMock()
|
||||
mock_get_queue.return_value = mock_queue_client
|
||||
|
||||
# Prisma is NOT connected
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_prisma.is_connected.return_value = False
|
||||
|
||||
# Mock database manager client
|
||||
mock_get_db_manager = mocker.patch(
|
||||
"backend.executor.utils.get_database_manager_async_client"
|
||||
)
|
||||
mock_db_manager = mocker.AsyncMock()
|
||||
mock_db_manager.get_graph_execution_meta = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_db_manager.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||
return_value=3 # 3 reviews cancelled
|
||||
)
|
||||
mock_db_manager.update_graph_execution_stats = mocker.AsyncMock()
|
||||
mock_get_db_manager.return_value = mock_db_manager
|
||||
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
mock_get_child_executions = mocker.patch(
|
||||
"backend.executor.utils._get_child_executions"
|
||||
)
|
||||
mock_get_child_executions.return_value = [] # No children
|
||||
|
||||
# Call stop_graph_execution with timeout
|
||||
await stop_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
wait_timeout=1.0,
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
# Verify database manager was used for cancel_pending_reviews
|
||||
mock_db_manager.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||
graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Verify execution status was updated via database manager
|
||||
mock_db_manager.update_graph_execution_stats.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_cascades_to_child_with_reviews(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Test that stopping parent execution cascades to children and cancels their reviews."""
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
from backend.executor.utils import stop_graph_execution
|
||||
|
||||
user_id = "test-user"
|
||||
parent_exec_id = "parent-exec"
|
||||
child_exec_id = "child-exec"
|
||||
|
||||
# Mock parent execution in RUNNING status
|
||||
mock_parent_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_parent_exec.id = parent_exec_id
|
||||
mock_parent_exec.status = ExecutionStatus.RUNNING
|
||||
|
||||
# Mock child execution in REVIEW status
|
||||
mock_child_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_child_exec.id = child_exec_id
|
||||
mock_child_exec.status = ExecutionStatus.REVIEW
|
||||
|
||||
# Mock dependencies
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_queue_client = mocker.AsyncMock()
|
||||
mock_get_queue.return_value = mock_queue_client
|
||||
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_prisma.is_connected.return_value = True
|
||||
|
||||
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||
return_value=1 # 1 child review cancelled
|
||||
)
|
||||
|
||||
# Mock execution_db to return different status based on which execution is queried
|
||||
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
|
||||
|
||||
# Track call count to simulate status transition
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def get_exec_meta_side_effect(execution_id, user_id):
|
||||
call_count["count"] += 1
|
||||
if execution_id == parent_exec_id:
|
||||
# After a few calls (child processing happens), transition parent to TERMINATED
|
||||
# This simulates the executor service processing the stop request
|
||||
if call_count["count"] > 3:
|
||||
mock_parent_exec.status = ExecutionStatus.TERMINATED
|
||||
return mock_parent_exec
|
||||
elif execution_id == child_exec_id:
|
||||
return mock_child_exec
|
||||
return None
|
||||
|
||||
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
|
||||
side_effect=get_exec_meta_side_effect
|
||||
)
|
||||
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
|
||||
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
# Mock _get_child_executions to return the child
|
||||
mock_get_child_executions = mocker.patch(
|
||||
"backend.executor.utils._get_child_executions"
|
||||
)
|
||||
|
||||
def get_children_side_effect(parent_id):
|
||||
if parent_id == parent_exec_id:
|
||||
return [mock_child_exec]
|
||||
return []
|
||||
|
||||
mock_get_child_executions.side_effect = get_children_side_effect
|
||||
|
||||
# Call stop_graph_execution on parent with cascade=True
|
||||
await stop_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=parent_exec_id,
|
||||
wait_timeout=1.0,
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
# Verify child reviews were cancelled
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||
child_exec_id, user_id
|
||||
)
|
||||
|
||||
# Verify both parent and child status updates
|
||||
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
||||
|
||||
@@ -1,37 +1,12 @@
|
||||
-- CreateExtension
|
||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||
-- Ensures vector extension is in the current schema (from DATABASE_URL ?schema= param)
|
||||
-- If it exists in a different schema (e.g., public), we drop and recreate it in the current schema
|
||||
-- Creates extension in current schema (determined by search_path from DATABASE_URL ?schema= param)
|
||||
-- This ensures vector type is in the same schema as tables, making ::vector work without explicit qualification
|
||||
DO $$
|
||||
DECLARE
|
||||
current_schema_name text;
|
||||
vector_schema text;
|
||||
BEGIN
|
||||
-- Get the current schema from search_path
|
||||
SELECT current_schema() INTO current_schema_name;
|
||||
|
||||
-- Check if vector extension exists and which schema it's in
|
||||
SELECT n.nspname INTO vector_schema
|
||||
FROM pg_extension e
|
||||
JOIN pg_namespace n ON e.extnamespace = n.oid
|
||||
WHERE e.extname = 'vector';
|
||||
|
||||
-- Handle removal if in wrong schema
|
||||
IF vector_schema IS NOT NULL AND vector_schema != current_schema_name THEN
|
||||
BEGIN
|
||||
-- Vector exists in a different schema, drop it first
|
||||
RAISE WARNING 'pgvector found in schema "%" but need it in "%". Dropping and reinstalling...',
|
||||
vector_schema, current_schema_name;
|
||||
EXECUTE 'DROP EXTENSION IF EXISTS vector CASCADE';
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE EXCEPTION 'Failed to drop pgvector from schema "%": %. You may need to drop it manually.',
|
||||
vector_schema, SQLERRM;
|
||||
END;
|
||||
END IF;
|
||||
|
||||
-- Create extension in current schema (let it fail naturally if not available)
|
||||
EXECUTE format('CREATE EXTENSION IF NOT EXISTS vector SCHEMA %I', current_schema_name);
|
||||
CREATE EXTENSION IF NOT EXISTS "vector";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
||||
END $$;
|
||||
|
||||
-- CreateEnum
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
||||
-- These extensions are pre-installed by Supabase in specific schemas
|
||||
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
||||
|
||||
-- Create schemas (safe in both CI and Supabase)
|
||||
CREATE SCHEMA IF NOT EXISTS "extensions";
|
||||
|
||||
-- Extensions that exist in both CI and Supabase
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
-- Supabase-specific extensions (skip gracefully in CI)
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_net extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgjwt extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "graphql";
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
||||
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgsodium extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "vault";
|
||||
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
|
||||
-- Return to platform
|
||||
CREATE SCHEMA IF NOT EXISTS "platform";
|
||||
@@ -0,0 +1,7 @@
|
||||
-- Remove NodeExecution foreign key from PendingHumanReview
|
||||
-- The nodeExecId column remains as the primary key, but we remove the FK constraint
|
||||
-- to AgentNodeExecution since PendingHumanReview records can persist after node
|
||||
-- execution records are deleted.
|
||||
|
||||
-- Drop foreign key constraint that linked PendingHumanReview.nodeExecId to AgentNodeExecution.id
|
||||
ALTER TABLE "PendingHumanReview" DROP CONSTRAINT IF EXISTS "PendingHumanReview_nodeExecId_fkey";
|
||||
@@ -517,8 +517,6 @@ model AgentNodeExecution {
|
||||
|
||||
stats Json?
|
||||
|
||||
PendingHumanReview PendingHumanReview?
|
||||
|
||||
@@index([agentGraphExecutionId, agentNodeId, executionStatus])
|
||||
@@index([agentNodeId, executionStatus])
|
||||
@@index([addedTime, queuedTime])
|
||||
@@ -567,6 +565,7 @@ enum ReviewStatus {
|
||||
}
|
||||
|
||||
// Pending human reviews for Human-in-the-loop blocks
|
||||
// Also stores auto-approval records with special nodeExecId patterns (e.g., "auto_approve_{graph_exec_id}_{node_id}")
|
||||
model PendingHumanReview {
|
||||
nodeExecId String @id
|
||||
userId String
|
||||
@@ -585,7 +584,6 @@ model PendingHumanReview {
|
||||
reviewedAt DateTime?
|
||||
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
NodeExecution AgentNodeExecution @relation(fields: [nodeExecId], references: [id], onDelete: Cascade)
|
||||
GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([nodeExecId]) // One pending review per node execution
|
||||
|
||||
@@ -34,10 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Default output directory relative to repo root
|
||||
DEFAULT_OUTPUT_DIR = (
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "docs"
|
||||
/ "integrations"
|
||||
/ "block-integrations"
|
||||
Path(__file__).parent.parent.parent.parent / "docs" / "integrations"
|
||||
)
|
||||
|
||||
|
||||
@@ -424,14 +421,6 @@ def generate_block_markdown(
|
||||
lines.append("<!-- END MANUAL -->")
|
||||
lines.append("")
|
||||
|
||||
# Optional per-block extras (only include if has content)
|
||||
extras = manual_content.get("extras", "")
|
||||
if extras:
|
||||
lines.append("<!-- MANUAL: extras -->")
|
||||
lines.append(extras)
|
||||
lines.append("<!-- END MANUAL -->")
|
||||
lines.append("")
|
||||
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
|
||||
@@ -467,52 +456,25 @@ def get_block_file_mapping(blocks: list[BlockDoc]) -> dict[str, list[BlockDoc]]:
|
||||
return dict(file_mapping)
|
||||
|
||||
|
||||
def generate_overview_table(blocks: list[BlockDoc], block_dir_prefix: str = "") -> str:
|
||||
"""Generate the overview table markdown (blocks.md).
|
||||
|
||||
Args:
|
||||
blocks: List of block documentation objects
|
||||
block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
|
||||
"""
|
||||
def generate_overview_table(blocks: list[BlockDoc]) -> str:
|
||||
"""Generate the overview table markdown (blocks.md)."""
|
||||
lines = []
|
||||
|
||||
# GitBook YAML frontmatter
|
||||
lines.append("---")
|
||||
lines.append("layout:")
|
||||
lines.append(" width: default")
|
||||
lines.append(" title:")
|
||||
lines.append(" visible: true")
|
||||
lines.append(" description:")
|
||||
lines.append(" visible: true")
|
||||
lines.append(" tableOfContents:")
|
||||
lines.append(" visible: false")
|
||||
lines.append(" outline:")
|
||||
lines.append(" visible: true")
|
||||
lines.append(" pagination:")
|
||||
lines.append(" visible: true")
|
||||
lines.append(" metadata:")
|
||||
lines.append(" visible: true")
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
|
||||
lines.append("# AutoGPT Blocks Overview")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
'AutoGPT uses a modular approach with various "blocks" to handle different tasks. These blocks are the building blocks of AutoGPT workflows, allowing users to create complex automations by combining simple, specialized components.'
|
||||
)
|
||||
lines.append("")
|
||||
lines.append('{% hint style="info" %}')
|
||||
lines.append("**Creating Your Own Blocks**")
|
||||
lines.append("")
|
||||
lines.append("Want to create your own custom blocks? Check out our guides:")
|
||||
lines.append("")
|
||||
lines.append('!!! info "Creating Your Own Blocks"')
|
||||
lines.append(" Want to create your own custom blocks? Check out our guides:")
|
||||
lines.append(" ")
|
||||
lines.append(
|
||||
"* [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples"
|
||||
" - [Build your own Blocks](https://docs.agpt.co/platform/new_blocks/) - Step-by-step tutorial with examples"
|
||||
)
|
||||
lines.append(
|
||||
"* [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration"
|
||||
" - [Block SDK Guide](https://docs.agpt.co/platform/block-sdk-guide/) - Advanced SDK patterns with OAuth, webhooks, and provider configuration"
|
||||
)
|
||||
lines.append("{% endhint %}")
|
||||
lines.append("")
|
||||
lines.append(
|
||||
"Below is a comprehensive list of all available blocks, categorized by their primary function. Click on any block name to view its detailed documentation."
|
||||
@@ -575,8 +537,7 @@ def generate_overview_table(blocks: list[BlockDoc], block_dir_prefix: str = "")
|
||||
else "No description"
|
||||
)
|
||||
short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
|
||||
link_path = f"{block_dir_prefix}{file_path}"
|
||||
lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
|
||||
lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
|
||||
lines.append("")
|
||||
continue
|
||||
|
||||
@@ -602,55 +563,13 @@ def generate_overview_table(blocks: list[BlockDoc], block_dir_prefix: str = "")
|
||||
)
|
||||
short_desc = short_desc.replace("\n", " ").replace("|", "\\|")
|
||||
|
||||
link_path = f"{block_dir_prefix}{file_path}"
|
||||
lines.append(f"| [{block.name}]({link_path}#{anchor}) | {short_desc} |")
|
||||
lines.append(f"| [{block.name}]({file_path}#{anchor}) | {short_desc} |")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_summary_md(
|
||||
blocks: list[BlockDoc], root_dir: Path, block_dir_prefix: str = ""
|
||||
) -> str:
|
||||
"""Generate SUMMARY.md for GitBook navigation.
|
||||
|
||||
Args:
|
||||
blocks: List of block documentation objects
|
||||
root_dir: The root docs directory (e.g., docs/integrations/)
|
||||
block_dir_prefix: Prefix for block file links (e.g., "block-integrations/")
|
||||
"""
|
||||
lines = []
|
||||
lines.append("# Table of contents")
|
||||
lines.append("")
|
||||
lines.append("* [AutoGPT Blocks Overview](README.md)")
|
||||
lines.append("")
|
||||
|
||||
# Check for guides/ directory at the root level (docs/integrations/guides/)
|
||||
guides_dir = root_dir / "guides"
|
||||
if guides_dir.exists():
|
||||
lines.append("## Guides")
|
||||
lines.append("")
|
||||
for guide_file in sorted(guides_dir.glob("*.md")):
|
||||
# Use just the file name for title (replace hyphens/underscores with spaces)
|
||||
title = file_path_to_title(guide_file.stem.replace("-", "_") + ".md")
|
||||
lines.append(f"* [{title}](guides/{guide_file.name})")
|
||||
lines.append("")
|
||||
|
||||
lines.append("## Block Integrations")
|
||||
lines.append("")
|
||||
|
||||
file_mapping = get_block_file_mapping(blocks)
|
||||
for file_path in sorted(file_mapping.keys()):
|
||||
title = file_path_to_title(file_path)
|
||||
link_path = f"{block_dir_prefix}{file_path}"
|
||||
lines.append(f"* [{title}]({link_path})")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def load_all_blocks_for_docs() -> list[BlockDoc]:
|
||||
"""Load all blocks and extract documentation."""
|
||||
from backend.blocks import load_all_blocks
|
||||
@@ -734,16 +653,6 @@ def write_block_docs(
|
||||
)
|
||||
)
|
||||
|
||||
# Add file-level additional_content section if present
|
||||
file_additional = extract_manual_content(existing_content).get(
|
||||
"additional_content", ""
|
||||
)
|
||||
if file_additional:
|
||||
content_parts.append("<!-- MANUAL: additional_content -->")
|
||||
content_parts.append(file_additional)
|
||||
content_parts.append("<!-- END MANUAL -->")
|
||||
content_parts.append("")
|
||||
|
||||
full_content = file_header + "\n" + "\n".join(content_parts)
|
||||
generated_files[str(file_path)] = full_content
|
||||
|
||||
@@ -752,28 +661,14 @@ def write_block_docs(
|
||||
|
||||
full_path.write_text(full_content)
|
||||
|
||||
# Generate overview file at the parent directory (docs/integrations/)
|
||||
# with links prefixed to point into block-integrations/
|
||||
root_dir = output_dir.parent
|
||||
block_dir_name = output_dir.name # "block-integrations"
|
||||
block_dir_prefix = f"{block_dir_name}/"
|
||||
|
||||
overview_content = generate_overview_table(blocks, block_dir_prefix)
|
||||
overview_path = root_dir / "README.md"
|
||||
# Generate overview file
|
||||
overview_content = generate_overview_table(blocks)
|
||||
overview_path = output_dir / "README.md"
|
||||
generated_files["README.md"] = overview_content
|
||||
overview_path.write_text(overview_content)
|
||||
|
||||
if verbose:
|
||||
print(" Writing README.md (overview) to parent directory")
|
||||
|
||||
# Generate SUMMARY.md for GitBook navigation at the parent directory
|
||||
summary_content = generate_summary_md(blocks, root_dir, block_dir_prefix)
|
||||
summary_path = root_dir / "SUMMARY.md"
|
||||
generated_files["SUMMARY.md"] = summary_content
|
||||
summary_path.write_text(summary_content)
|
||||
|
||||
if verbose:
|
||||
print(" Writing SUMMARY.md (navigation) to parent directory")
|
||||
print(" Writing README.md (overview)")
|
||||
|
||||
return generated_files
|
||||
|
||||
@@ -853,16 +748,6 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
|
||||
elif block_match.group(1).strip() != expected_block_content.strip():
|
||||
mismatched_blocks.append(block.name)
|
||||
|
||||
# Add file-level additional_content to expected content (matches write_block_docs)
|
||||
file_additional = extract_manual_content(existing_content).get(
|
||||
"additional_content", ""
|
||||
)
|
||||
if file_additional:
|
||||
content_parts.append("<!-- MANUAL: additional_content -->")
|
||||
content_parts.append(file_additional)
|
||||
content_parts.append("<!-- END MANUAL -->")
|
||||
content_parts.append("")
|
||||
|
||||
expected_content = file_header + "\n" + "\n".join(content_parts)
|
||||
|
||||
if existing_content.strip() != expected_content.strip():
|
||||
@@ -872,15 +757,11 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
|
||||
out_of_sync_details.append((file_path, mismatched_blocks))
|
||||
all_match = False
|
||||
|
||||
# Check overview at the parent directory (docs/integrations/)
|
||||
root_dir = output_dir.parent
|
||||
block_dir_name = output_dir.name # "block-integrations"
|
||||
block_dir_prefix = f"{block_dir_name}/"
|
||||
|
||||
overview_path = root_dir / "README.md"
|
||||
# Check overview
|
||||
overview_path = output_dir / "README.md"
|
||||
if overview_path.exists():
|
||||
existing_overview = overview_path.read_text()
|
||||
expected_overview = generate_overview_table(blocks, block_dir_prefix)
|
||||
expected_overview = generate_overview_table(blocks)
|
||||
if existing_overview.strip() != expected_overview.strip():
|
||||
print("OUT OF SYNC: README.md (overview)")
|
||||
print(" The blocks overview table needs regeneration")
|
||||
@@ -891,21 +772,6 @@ def check_docs_in_sync(output_dir: Path, blocks: list[BlockDoc]) -> bool:
|
||||
out_of_sync_details.append(("README.md", ["overview table"]))
|
||||
all_match = False
|
||||
|
||||
# Check SUMMARY.md at the parent directory
|
||||
summary_path = root_dir / "SUMMARY.md"
|
||||
if summary_path.exists():
|
||||
existing_summary = summary_path.read_text()
|
||||
expected_summary = generate_summary_md(blocks, root_dir, block_dir_prefix)
|
||||
if existing_summary.strip() != expected_summary.strip():
|
||||
print("OUT OF SYNC: SUMMARY.md (navigation)")
|
||||
print(" The GitBook navigation needs regeneration")
|
||||
out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
|
||||
all_match = False
|
||||
else:
|
||||
print("MISSING: SUMMARY.md (navigation)")
|
||||
out_of_sync_details.append(("SUMMARY.md", ["navigation"]))
|
||||
all_match = False
|
||||
|
||||
# Check for unfilled manual sections
|
||||
unfilled_patterns = [
|
||||
"_Add a description of this category of blocks._",
|
||||
|
||||
@@ -29,4 +29,4 @@ NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
NEXT_PUBLIC_TURNSTILE=disabled
|
||||
|
||||
# PR previews
|
||||
NEXT_PUBLIC_PREVIEW_STEALING_DEV=
|
||||
NEXT_PUBLIC_PREVIEW_STEALING_DEV=
|
||||
@@ -175,8 +175,6 @@ While server components and actions are cool and cutting-edge, they introduce a
|
||||
|
||||
- Prefer [React Query](https://tanstack.com/query/latest/docs/framework/react/overview) for server state, colocated near consumers (see [state colocation](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster))
|
||||
- Co-locate UI state inside components/hooks; keep global state minimal
|
||||
- Avoid `useMemo` and `useCallback` unless you have a measured performance issue
|
||||
- Do not abuse `useEffect`; prefer state colocation and derive values directly when possible
|
||||
|
||||
### Styling and components
|
||||
|
||||
@@ -551,48 +549,9 @@ Files:
|
||||
Types:
|
||||
|
||||
- Prefer `interface` for object shapes
|
||||
- Component props should be `interface Props { ... }` (not exported)
|
||||
- Only use specific exported names (e.g., `export interface MyComponentProps`) when the interface needs to be used outside the component
|
||||
- Keep type definitions inline with the component - do not create separate `types.ts` files unless types are shared across multiple files
|
||||
- Component props should be `interface Props { ... }`
|
||||
- Use precise types; avoid `any` and unsafe casts
|
||||
|
||||
**Props naming examples:**
|
||||
|
||||
```tsx
|
||||
// ✅ Good - internal props, not exported
|
||||
interface Props {
|
||||
title: string;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function Modal({ title, onClose }: Props) {
|
||||
// ...
|
||||
}
|
||||
|
||||
// ✅ Good - exported when needed externally
|
||||
export interface ModalProps {
|
||||
title: string;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function Modal({ title, onClose }: ModalProps) {
|
||||
// ...
|
||||
}
|
||||
|
||||
// ❌ Bad - unnecessarily specific name for internal use
|
||||
interface ModalComponentProps {
|
||||
title: string;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
// ❌ Bad - separate types.ts file for single component
|
||||
// types.ts
|
||||
export interface ModalProps { ... }
|
||||
|
||||
// Modal.tsx
|
||||
import type { ModalProps } from './types';
|
||||
```
|
||||
|
||||
Parameters:
|
||||
|
||||
- If more than one parameter is needed, pass a single `Args` object for clarity
|
||||
|
||||
@@ -16,12 +16,6 @@ export default defineConfig({
|
||||
client: "react-query",
|
||||
httpClient: "fetch",
|
||||
indexFiles: false,
|
||||
mock: {
|
||||
type: "msw",
|
||||
baseUrl: "http://localhost:3000/api/proxy",
|
||||
generateEachHttpStatus: true,
|
||||
delay: 0,
|
||||
},
|
||||
override: {
|
||||
mutator: {
|
||||
path: "./mutators/custom-mutator.ts",
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
"types": "tsc --noEmit",
|
||||
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
|
||||
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
|
||||
"test:unit": "vitest run",
|
||||
"test:unit:watch": "vitest",
|
||||
"test:no-build": "playwright test",
|
||||
"gentests": "playwright codegen http://localhost:3000",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
@@ -120,7 +118,6 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "4.1.2",
|
||||
"happy-dom": "20.3.4",
|
||||
"@opentelemetry/instrumentation": "0.209.0",
|
||||
"@playwright/test": "1.56.1",
|
||||
"@storybook/addon-a11y": "9.1.5",
|
||||
@@ -130,8 +127,6 @@
|
||||
"@storybook/nextjs": "9.1.5",
|
||||
"@tanstack/eslint-plugin-query": "5.91.2",
|
||||
"@tanstack/react-query-devtools": "5.90.2",
|
||||
"@testing-library/dom": "10.4.1",
|
||||
"@testing-library/react": "16.3.2",
|
||||
"@types/canvas-confetti": "1.9.0",
|
||||
"@types/lodash": "4.17.20",
|
||||
"@types/negotiator": "0.6.4",
|
||||
@@ -140,7 +135,6 @@
|
||||
"@types/react-dom": "18.3.5",
|
||||
"@types/react-modal": "3.16.3",
|
||||
"@types/react-window": "1.8.8",
|
||||
"@vitejs/plugin-react": "5.1.2",
|
||||
"axe-playwright": "2.2.2",
|
||||
"chromatic": "13.3.3",
|
||||
"concurrently": "9.2.1",
|
||||
@@ -159,9 +153,7 @@
|
||||
"require-in-the-middle": "8.0.1",
|
||||
"storybook": "9.1.5",
|
||||
"tailwindcss": "3.4.17",
|
||||
"typescript": "5.9.3",
|
||||
"vite-tsconfig-paths": "6.0.4",
|
||||
"vitest": "4.0.17"
|
||||
"typescript": "5.9.3"
|
||||
},
|
||||
"msw": {
|
||||
"workerDirectory": [
|
||||
|
||||
1118
autogpt_platform/frontend/pnpm-lock.yaml
generated
1118
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -1,58 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useRef } from "react";
|
||||
|
||||
const LOGOUT_REDIRECT_DELAY_MS = 400;
|
||||
|
||||
function wait(ms: number): Promise<void> {
|
||||
return new Promise(function resolveAfterDelay(resolve) {
|
||||
setTimeout(resolve, ms);
|
||||
});
|
||||
}
|
||||
|
||||
export default function LogoutPage() {
|
||||
const { logOut } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
const router = useRouter();
|
||||
const hasStartedRef = useRef(false);
|
||||
|
||||
useEffect(
|
||||
function handleLogoutEffect() {
|
||||
if (hasStartedRef.current) return;
|
||||
hasStartedRef.current = true;
|
||||
|
||||
async function runLogout() {
|
||||
try {
|
||||
await logOut();
|
||||
} catch {
|
||||
toast({
|
||||
title: "Failed to log out. Redirecting to login.",
|
||||
variant: "destructive",
|
||||
});
|
||||
} finally {
|
||||
await wait(LOGOUT_REDIRECT_DELAY_MS);
|
||||
router.replace("/login");
|
||||
}
|
||||
}
|
||||
|
||||
void runLogout();
|
||||
},
|
||||
[logOut, router, toast],
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center px-4">
|
||||
<div className="flex flex-col items-center justify-center gap-4 py-8">
|
||||
<LoadingSpinner size="large" />
|
||||
<Text variant="body" className="text-center">
|
||||
Logging you out...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -9,7 +9,7 @@ export async function GET(request: Request) {
|
||||
const { searchParams, origin } = new URL(request.url);
|
||||
const code = searchParams.get("code");
|
||||
|
||||
let next = "/";
|
||||
let next = "/marketplace";
|
||||
|
||||
if (code) {
|
||||
const supabase = await getServerSupabase();
|
||||
|
||||
@@ -86,7 +86,6 @@ export function FloatingSafeModeToggle({
|
||||
const {
|
||||
currentHITLSafeMode,
|
||||
showHITLToggle,
|
||||
isHITLStateUndetermined,
|
||||
handleHITLToggle,
|
||||
currentSensitiveActionSafeMode,
|
||||
showSensitiveActionToggle,
|
||||
@@ -99,16 +98,9 @@ export function FloatingSafeModeToggle({
|
||||
return null;
|
||||
}
|
||||
|
||||
const showHITL = showHITLToggle && !isHITLStateUndetermined;
|
||||
const showSensitive = showSensitiveActionToggle;
|
||||
|
||||
if (!showHITL && !showSensitive) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("fixed z-50 flex flex-col gap-2", className)}>
|
||||
{showHITL && (
|
||||
{showHITLToggle && (
|
||||
<SafeModeButton
|
||||
isEnabled={currentHITLSafeMode}
|
||||
label="Human in the loop block approval"
|
||||
@@ -119,7 +111,7 @@ export function FloatingSafeModeToggle({
|
||||
fullWidth={fullWidth}
|
||||
/>
|
||||
)}
|
||||
{showSensitive && (
|
||||
{showSensitiveActionToggle && (
|
||||
<SafeModeButton
|
||||
isEnabled={currentSensitiveActionSafeMode}
|
||||
label="Sensitive actions blocks approval"
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { List } from "@phosphor-icons/react";
|
||||
import React, { useState } from "react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
|
||||
import { ChatLoadingState } from "./components/ChatLoadingState/ChatLoadingState";
|
||||
import { SessionsDrawer } from "./components/SessionsDrawer/SessionsDrawer";
|
||||
import { useChat } from "./useChat";
|
||||
|
||||
export interface ChatProps {
|
||||
className?: string;
|
||||
headerTitle?: React.ReactNode;
|
||||
showHeader?: boolean;
|
||||
showSessionInfo?: boolean;
|
||||
showNewChatButton?: boolean;
|
||||
onNewChat?: () => void;
|
||||
headerActions?: React.ReactNode;
|
||||
}
|
||||
|
||||
export function Chat({
|
||||
className,
|
||||
headerTitle = "AutoGPT Copilot",
|
||||
showHeader = true,
|
||||
showSessionInfo = true,
|
||||
showNewChatButton = true,
|
||||
onNewChat,
|
||||
headerActions,
|
||||
}: ChatProps) {
|
||||
const {
|
||||
messages,
|
||||
isLoading,
|
||||
isCreating,
|
||||
error,
|
||||
sessionId,
|
||||
createSession,
|
||||
clearSession,
|
||||
loadSession,
|
||||
} = useChat();
|
||||
|
||||
const [isSessionsDrawerOpen, setIsSessionsDrawerOpen] = useState(false);
|
||||
|
||||
const handleNewChat = () => {
|
||||
clearSession();
|
||||
onNewChat?.();
|
||||
};
|
||||
|
||||
const handleSelectSession = async (sessionId: string) => {
|
||||
try {
|
||||
await loadSession(sessionId);
|
||||
} catch (err) {
|
||||
console.error("Failed to load session:", err);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
{/* Header */}
|
||||
{showHeader && (
|
||||
<header className="shrink-0 border-t border-zinc-200 bg-white p-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
aria-label="View sessions"
|
||||
onClick={() => setIsSessionsDrawerOpen(true)}
|
||||
className="flex size-8 items-center justify-center rounded hover:bg-zinc-100"
|
||||
>
|
||||
<List width="1.25rem" height="1.25rem" />
|
||||
</button>
|
||||
{typeof headerTitle === "string" ? (
|
||||
<Text variant="h2" className="text-lg font-semibold">
|
||||
{headerTitle}
|
||||
</Text>
|
||||
) : (
|
||||
headerTitle
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
{showSessionInfo && sessionId && (
|
||||
<>
|
||||
{showNewChatButton && (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={handleNewChat}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{headerActions}
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
)}
|
||||
|
||||
{/* Main Content */}
|
||||
<main className="flex min-h-0 flex-1 flex-col overflow-hidden">
|
||||
{/* Loading State - show when explicitly loading/creating OR when we don't have a session yet and no error */}
|
||||
{(isLoading || isCreating || (!sessionId && !error)) && (
|
||||
<ChatLoadingState
|
||||
message={isCreating ? "Creating session..." : "Loading..."}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Error State */}
|
||||
{error && !isLoading && (
|
||||
<ChatErrorState error={error} onRetry={createSession} />
|
||||
)}
|
||||
|
||||
{/* Session Content */}
|
||||
{sessionId && !isLoading && !error && (
|
||||
<ChatContainer
|
||||
sessionId={sessionId}
|
||||
initialMessages={messages}
|
||||
className="flex-1"
|
||||
/>
|
||||
)}
|
||||
</main>
|
||||
|
||||
{/* Sessions Drawer */}
|
||||
<SessionsDrawer
|
||||
isOpen={isSessionsDrawerOpen}
|
||||
onClose={() => setIsSessionsDrawerOpen(false)}
|
||||
onSelectSession={handleSelectSession}
|
||||
currentSessionId={sessionId}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -21,7 +21,7 @@ export function AuthPromptWidget({
|
||||
message,
|
||||
sessionId,
|
||||
agentInfo,
|
||||
returnUrl = "/copilot/chat",
|
||||
returnUrl = "/chat",
|
||||
className,
|
||||
}: AuthPromptWidgetProps) {
|
||||
const router = useRouter();
|
||||
@@ -0,0 +1,88 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useCallback } from "react";
|
||||
import { usePageContext } from "../../usePageContext";
|
||||
import { ChatInput } from "../ChatInput/ChatInput";
|
||||
import { MessageList } from "../MessageList/MessageList";
|
||||
import { QuickActionsWelcome } from "../QuickActionsWelcome/QuickActionsWelcome";
|
||||
import { useChatContainer } from "./useChatContainer";
|
||||
|
||||
export interface ChatContainerProps {
|
||||
sessionId: string | null;
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
className,
|
||||
}: ChatContainerProps) {
|
||||
const { messages, streamingChunks, isStreaming, sendMessage } =
|
||||
useChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
});
|
||||
const { capturePageContext } = usePageContext();
|
||||
|
||||
// Wrap sendMessage to automatically capture page context
|
||||
const sendMessageWithContext = useCallback(
|
||||
async (content: string, isUserMessage: boolean = true) => {
|
||||
const context = capturePageContext();
|
||||
await sendMessage(content, isUserMessage, context);
|
||||
},
|
||||
[sendMessage, capturePageContext],
|
||||
);
|
||||
|
||||
const quickActions = [
|
||||
"Find agents for social media management",
|
||||
"Show me agents for content creation",
|
||||
"Help me automate my business",
|
||||
"What can you help me with?",
|
||||
];
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn("flex h-full min-h-0 flex-col", className)}
|
||||
style={{
|
||||
backgroundColor: "#ffffff",
|
||||
backgroundImage:
|
||||
"radial-gradient(#e5e5e5 0.5px, transparent 0.5px), radial-gradient(#e5e5e5 0.5px, #ffffff 0.5px)",
|
||||
backgroundSize: "20px 20px",
|
||||
backgroundPosition: "0 0, 10px 10px",
|
||||
}}
|
||||
>
|
||||
{/* Messages or Welcome Screen */}
|
||||
<div className="flex min-h-0 flex-1 flex-col overflow-hidden pb-24">
|
||||
{messages.length === 0 ? (
|
||||
<QuickActionsWelcome
|
||||
title="Welcome to AutoGPT Copilot"
|
||||
description="Start a conversation to discover and run AI agents."
|
||||
actions={quickActions}
|
||||
onActionClick={sendMessageWithContext}
|
||||
disabled={isStreaming || !sessionId}
|
||||
/>
|
||||
) : (
|
||||
<MessageList
|
||||
messages={messages}
|
||||
streamingChunks={streamingChunks}
|
||||
isStreaming={isStreaming}
|
||||
onSendMessage={sendMessageWithContext}
|
||||
className="flex-1"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Input - Always visible */}
|
||||
<div className="fixed bottom-0 left-0 right-0 z-50 border-t border-zinc-200 bg-white p-4">
|
||||
<ChatInput
|
||||
onSend={sendMessageWithContext}
|
||||
disabled={isStreaming || !sessionId}
|
||||
placeholder={
|
||||
sessionId ? "Type your message..." : "Creating session..."
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import { toast } from "sonner";
|
||||
import { StreamChunk } from "../../useChatStream";
|
||||
import type { HandlerDependencies } from "./handlers";
|
||||
import type { HandlerDependencies } from "./useChatContainer.handlers";
|
||||
import {
|
||||
handleError,
|
||||
handleLoginNeeded,
|
||||
@@ -9,30 +9,12 @@ import {
|
||||
handleTextEnded,
|
||||
handleToolCallStart,
|
||||
handleToolResponse,
|
||||
isRegionBlockedError,
|
||||
} from "./handlers";
|
||||
} from "./useChatContainer.handlers";
|
||||
|
||||
export function createStreamEventDispatcher(
|
||||
deps: HandlerDependencies,
|
||||
): (chunk: StreamChunk) => void {
|
||||
return function dispatchStreamEvent(chunk: StreamChunk): void {
|
||||
if (
|
||||
chunk.type === "text_chunk" ||
|
||||
chunk.type === "tool_call_start" ||
|
||||
chunk.type === "tool_response" ||
|
||||
chunk.type === "login_needed" ||
|
||||
chunk.type === "need_login" ||
|
||||
chunk.type === "error"
|
||||
) {
|
||||
if (!deps.hasResponseRef.current) {
|
||||
console.info("[ChatStream] First response chunk:", {
|
||||
type: chunk.type,
|
||||
sessionId: deps.sessionId,
|
||||
});
|
||||
}
|
||||
deps.hasResponseRef.current = true;
|
||||
}
|
||||
|
||||
switch (chunk.type) {
|
||||
case "text_chunk":
|
||||
handleTextChunk(chunk, deps);
|
||||
@@ -56,23 +38,15 @@ export function createStreamEventDispatcher(
|
||||
break;
|
||||
|
||||
case "stream_end":
|
||||
console.info("[ChatStream] Stream ended:", {
|
||||
sessionId: deps.sessionId,
|
||||
hasResponse: deps.hasResponseRef.current,
|
||||
chunkCount: deps.streamingChunksRef.current.length,
|
||||
});
|
||||
handleStreamEnd(chunk, deps);
|
||||
break;
|
||||
|
||||
case "error":
|
||||
const isRegionBlocked = isRegionBlockedError(chunk);
|
||||
handleError(chunk, deps);
|
||||
// Show toast at dispatcher level to avoid circular dependencies
|
||||
if (!isRegionBlocked) {
|
||||
toast.error("Chat Error", {
|
||||
description: chunk.message || chunk.content || "An error occurred",
|
||||
});
|
||||
}
|
||||
toast.error("Chat Error", {
|
||||
description: chunk.message || chunk.content || "An error occurred",
|
||||
});
|
||||
break;
|
||||
|
||||
case "usage":
|
||||
@@ -1,33 +1,6 @@
|
||||
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
||||
import type { ToolResult } from "@/types/chat";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
|
||||
export function hasSentInitialPrompt(sessionId: string): boolean {
|
||||
try {
|
||||
const sent = JSON.parse(
|
||||
sessionStorage.get(SessionKey.CHAT_SENT_INITIAL_PROMPTS) || "{}",
|
||||
);
|
||||
return sent[sessionId] === true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export function markInitialPromptSent(sessionId: string): void {
|
||||
try {
|
||||
const sent = JSON.parse(
|
||||
sessionStorage.get(SessionKey.CHAT_SENT_INITIAL_PROMPTS) || "{}",
|
||||
);
|
||||
sent[sessionId] = true;
|
||||
sessionStorage.set(
|
||||
SessionKey.CHAT_SENT_INITIAL_PROMPTS,
|
||||
JSON.stringify(sent),
|
||||
);
|
||||
} catch {
|
||||
// Ignore storage errors
|
||||
}
|
||||
}
|
||||
|
||||
export function removePageContext(content: string): string {
|
||||
// Remove "Page URL: ..." pattern at start of line (case insensitive, handles various formats)
|
||||
let cleaned = content.replace(/^\s*Page URL:\s*[^\n\r]*/gim, "");
|
||||
@@ -234,22 +207,12 @@ export function parseToolResponse(
|
||||
if (responseType === "setup_requirements") {
|
||||
return null;
|
||||
}
|
||||
if (responseType === "understanding_updated") {
|
||||
return {
|
||||
type: "tool_response",
|
||||
toolId,
|
||||
toolName,
|
||||
result: (parsedResult || result) as ToolResult,
|
||||
success: true,
|
||||
timestamp: timestamp || new Date(),
|
||||
};
|
||||
}
|
||||
}
|
||||
return {
|
||||
type: "tool_response",
|
||||
toolId,
|
||||
toolName,
|
||||
result: parsedResult ? (parsedResult as ToolResult) : result,
|
||||
result,
|
||||
success: true,
|
||||
timestamp: timestamp || new Date(),
|
||||
};
|
||||
@@ -7,30 +7,15 @@ import {
|
||||
parseToolResponse,
|
||||
} from "./helpers";
|
||||
|
||||
function isToolCallMessage(
|
||||
message: ChatMessageData,
|
||||
): message is Extract<ChatMessageData, { type: "tool_call" }> {
|
||||
return message.type === "tool_call";
|
||||
}
|
||||
|
||||
export interface HandlerDependencies {
|
||||
setHasTextChunks: Dispatch<SetStateAction<boolean>>;
|
||||
setStreamingChunks: Dispatch<SetStateAction<string[]>>;
|
||||
streamingChunksRef: MutableRefObject<string[]>;
|
||||
hasResponseRef: MutableRefObject<boolean>;
|
||||
setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
|
||||
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
||||
sessionId: string;
|
||||
}
|
||||
|
||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
||||
if (chunk.code === "MODEL_NOT_AVAILABLE_REGION") return true;
|
||||
const message = chunk.message || chunk.content;
|
||||
if (typeof message !== "string") return false;
|
||||
return message.toLowerCase().includes("not available in your region");
|
||||
}
|
||||
|
||||
export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||
if (!chunk.content) return;
|
||||
deps.setHasTextChunks(true);
|
||||
@@ -45,17 +30,16 @@ export function handleTextEnded(
|
||||
_chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
console.log("[Text Ended] Saving streamed text as assistant message");
|
||||
const completedText = deps.streamingChunksRef.current.join("");
|
||||
if (completedText.trim()) {
|
||||
deps.setMessages((prev) => {
|
||||
const assistantMessage: ChatMessageData = {
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: completedText,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
return [...prev, assistantMessage];
|
||||
});
|
||||
const assistantMessage: ChatMessageData = {
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: completedText,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
deps.setMessages((prev) => [...prev, assistantMessage]);
|
||||
}
|
||||
deps.setStreamingChunks([]);
|
||||
deps.streamingChunksRef.current = [];
|
||||
@@ -66,45 +50,30 @@ export function handleToolCallStart(
|
||||
chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
|
||||
const toolCallMessage: ChatMessageData = {
|
||||
type: "tool_call",
|
||||
toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
|
||||
toolName: chunk.tool_name || "Executing",
|
||||
toolName: chunk.tool_name || "Executing...",
|
||||
arguments: chunk.arguments || {},
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
function updateToolCallMessages(prev: ChatMessageData[]) {
|
||||
const existingIndex = prev.findIndex(function findToolCallIndex(msg) {
|
||||
return isToolCallMessage(msg) && msg.toolId === toolCallMessage.toolId;
|
||||
});
|
||||
if (existingIndex === -1) {
|
||||
return [...prev, toolCallMessage];
|
||||
}
|
||||
const nextMessages = [...prev];
|
||||
const existing = nextMessages[existingIndex];
|
||||
if (!isToolCallMessage(existing)) return prev;
|
||||
const nextArguments =
|
||||
toolCallMessage.arguments &&
|
||||
Object.keys(toolCallMessage.arguments).length > 0
|
||||
? toolCallMessage.arguments
|
||||
: existing.arguments;
|
||||
nextMessages[existingIndex] = {
|
||||
...existing,
|
||||
toolName: toolCallMessage.toolName || existing.toolName,
|
||||
arguments: nextArguments,
|
||||
timestamp: toolCallMessage.timestamp,
|
||||
};
|
||||
return nextMessages;
|
||||
}
|
||||
|
||||
deps.setMessages(updateToolCallMessages);
|
||||
deps.setMessages((prev) => [...prev, toolCallMessage]);
|
||||
console.log("[Tool Call Start]", {
|
||||
toolId: toolCallMessage.toolId,
|
||||
toolName: toolCallMessage.toolName,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
}
|
||||
|
||||
export function handleToolResponse(
|
||||
chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
console.log("[Tool Response] Received:", {
|
||||
toolId: chunk.tool_id,
|
||||
toolName: chunk.tool_name,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
let toolName = chunk.tool_name || "unknown";
|
||||
if (!chunk.tool_name || chunk.tool_name === "unknown") {
|
||||
deps.setMessages((prev) => {
|
||||
@@ -158,15 +127,22 @@ export function handleToolResponse(
|
||||
const toolCallIndex = prev.findIndex(
|
||||
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
||||
);
|
||||
const hasResponse = prev.some(
|
||||
(msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id,
|
||||
);
|
||||
if (hasResponse) return prev;
|
||||
if (toolCallIndex !== -1) {
|
||||
const newMessages = [...prev];
|
||||
newMessages.splice(toolCallIndex + 1, 0, responseMessage);
|
||||
newMessages[toolCallIndex] = responseMessage;
|
||||
console.log(
|
||||
"[Tool Response] Replaced tool_call with matching tool_id:",
|
||||
chunk.tool_id,
|
||||
"at index:",
|
||||
toolCallIndex,
|
||||
);
|
||||
return newMessages;
|
||||
}
|
||||
console.warn(
|
||||
"[Tool Response] No tool_call found with tool_id:",
|
||||
chunk.tool_id,
|
||||
"appending instead",
|
||||
);
|
||||
return [...prev, responseMessage];
|
||||
});
|
||||
}
|
||||
@@ -191,38 +167,55 @@ export function handleStreamEnd(
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
const completedContent = deps.streamingChunksRef.current.join("");
|
||||
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
||||
deps.setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: "No response received. Please try again.",
|
||||
timestamp: new Date(),
|
||||
},
|
||||
]);
|
||||
}
|
||||
// Only save message if there are uncommitted chunks
|
||||
// (text_ended already saved if there were tool calls)
|
||||
if (completedContent.trim()) {
|
||||
console.log(
|
||||
"[Stream End] Saving remaining streamed text as assistant message",
|
||||
);
|
||||
const assistantMessage: ChatMessageData = {
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: completedContent,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
deps.setMessages((prev) => [...prev, assistantMessage]);
|
||||
deps.setMessages((prev) => {
|
||||
const updated = [...prev, assistantMessage];
|
||||
console.log("[Stream End] Final state:", {
|
||||
localMessages: updated.map((m) => ({
|
||||
type: m.type,
|
||||
...(m.type === "message" && {
|
||||
role: m.role,
|
||||
contentLength: m.content.length,
|
||||
}),
|
||||
...(m.type === "tool_call" && {
|
||||
toolId: m.toolId,
|
||||
toolName: m.toolName,
|
||||
}),
|
||||
...(m.type === "tool_response" && {
|
||||
toolId: m.toolId,
|
||||
toolName: m.toolName,
|
||||
success: m.success,
|
||||
}),
|
||||
})),
|
||||
streamingChunks: deps.streamingChunksRef.current,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
return updated;
|
||||
});
|
||||
} else {
|
||||
console.log("[Stream End] No uncommitted chunks, message already saved");
|
||||
}
|
||||
deps.setStreamingChunks([]);
|
||||
deps.streamingChunksRef.current = [];
|
||||
deps.setHasTextChunks(false);
|
||||
deps.setIsStreamingInitiated(false);
|
||||
console.log("[Stream End] Stream complete, messages in local state");
|
||||
}
|
||||
|
||||
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||
const errorMessage = chunk.message || chunk.content || "An error occurred";
|
||||
console.error("Stream error:", errorMessage);
|
||||
if (isRegionBlockedError(chunk)) {
|
||||
deps.setIsRegionBlockedModalOpen(true);
|
||||
}
|
||||
deps.setIsStreamingInitiated(false);
|
||||
deps.setHasTextChunks(false);
|
||||
deps.setStreamingChunks([]);
|
||||
@@ -1,17 +1,14 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useCallback, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { useChatStream } from "../../useChatStream";
|
||||
import { usePageContext } from "../../usePageContext";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
||||
import {
|
||||
createUserMessage,
|
||||
filterAuthMessages,
|
||||
hasSentInitialPrompt,
|
||||
isToolCallArray,
|
||||
isValidMessage,
|
||||
markInitialPromptSent,
|
||||
parseToolResponse,
|
||||
removePageContext,
|
||||
} from "./helpers";
|
||||
@@ -19,45 +16,20 @@ import {
|
||||
interface Args {
|
||||
sessionId: string | null;
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
initialPrompt?: string;
|
||||
}
|
||||
|
||||
export function useChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
initialPrompt,
|
||||
}: Args) {
|
||||
export function useChatContainer({ sessionId, initialMessages }: Args) {
|
||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||
const [hasTextChunks, setHasTextChunks] = useState(false);
|
||||
const [isStreamingInitiated, setIsStreamingInitiated] = useState(false);
|
||||
const [isRegionBlockedModalOpen, setIsRegionBlockedModalOpen] =
|
||||
useState(false);
|
||||
const hasResponseRef = useRef(false);
|
||||
const streamingChunksRef = useRef<string[]>([]);
|
||||
const previousSessionIdRef = useRef<string | null>(null);
|
||||
const {
|
||||
error,
|
||||
sendMessage: sendStreamMessage,
|
||||
stopStreaming,
|
||||
} = useChatStream();
|
||||
const { error, sendMessage: sendStreamMessage } = useChatStream();
|
||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||
|
||||
useEffect(() => {
|
||||
if (sessionId !== previousSessionIdRef.current) {
|
||||
stopStreaming(previousSessionIdRef.current ?? undefined, true);
|
||||
previousSessionIdRef.current = sessionId;
|
||||
setMessages([]);
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
hasResponseRef.current = false;
|
||||
}
|
||||
}, [sessionId, stopStreaming]);
|
||||
|
||||
const allMessages = useMemo(() => {
|
||||
const processedInitialMessages: ChatMessageData[] = [];
|
||||
// Map to track tool calls by their ID so we can look up tool names for tool responses
|
||||
const toolCallMap = new Map<string, string>();
|
||||
|
||||
for (const msg of initialMessages) {
|
||||
@@ -73,9 +45,13 @@ export function useChatContainer({
|
||||
? new Date(msg.timestamp as string)
|
||||
: undefined;
|
||||
|
||||
// Remove page context from user messages when loading existing sessions
|
||||
if (role === "user") {
|
||||
content = removePageContext(content);
|
||||
if (!content.trim()) continue;
|
||||
// Skip user messages that become empty after removing page context
|
||||
if (!content.trim()) {
|
||||
continue;
|
||||
}
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
role: "user",
|
||||
@@ -85,15 +61,19 @@ export function useChatContainer({
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle assistant messages first (before tool messages) to build tool call map
|
||||
if (role === "assistant") {
|
||||
// Strip <thinking> tags from content
|
||||
content = content
|
||||
.replace(/<thinking>[\s\S]*?<\/thinking>/gi, "")
|
||||
.trim();
|
||||
|
||||
// If assistant has tool calls, create tool_call messages for each
|
||||
if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) {
|
||||
for (const toolCall of toolCalls) {
|
||||
const toolName = toolCall.function.name;
|
||||
const toolId = toolCall.id;
|
||||
// Store tool name for later lookup
|
||||
toolCallMap.set(toolId, toolName);
|
||||
|
||||
try {
|
||||
@@ -116,6 +96,7 @@ export function useChatContainer({
|
||||
});
|
||||
}
|
||||
}
|
||||
// Only add assistant message if there's content after stripping thinking tags
|
||||
if (content.trim()) {
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
@@ -125,6 +106,7 @@ export function useChatContainer({
|
||||
});
|
||||
}
|
||||
} else if (content.trim()) {
|
||||
// Assistant message without tool calls, but with content
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
@@ -135,6 +117,7 @@ export function useChatContainer({
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle tool messages - look up tool name from tool call map
|
||||
if (role === "tool") {
|
||||
const toolCallId = (msg.tool_call_id as string) || "";
|
||||
const toolName = toolCallMap.get(toolCallId) || "unknown";
|
||||
@@ -150,6 +133,7 @@ export function useChatContainer({
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle other message types (system, etc.)
|
||||
if (content.trim()) {
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
@@ -170,10 +154,9 @@ export function useChatContainer({
|
||||
context?: { url: string; content: string },
|
||||
) {
|
||||
if (!sessionId) {
|
||||
console.error("[useChatContainer] Cannot send message: no session ID");
|
||||
console.error("Cannot send message: no session ID");
|
||||
return;
|
||||
}
|
||||
setIsRegionBlockedModalOpen(false);
|
||||
if (isUserMessage) {
|
||||
const userMessage = createUserMessage(content);
|
||||
setMessages((prev) => [...filterAuthMessages(prev), userMessage]);
|
||||
@@ -184,19 +167,14 @@ export function useChatContainer({
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(true);
|
||||
hasResponseRef.current = false;
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
hasResponseRef,
|
||||
setMessages,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
});
|
||||
|
||||
try {
|
||||
await sendStreamMessage(
|
||||
sessionId,
|
||||
@@ -206,12 +184,8 @@ export function useChatContainer({
|
||||
context,
|
||||
);
|
||||
} catch (err) {
|
||||
console.error("[useChatContainer] Failed to send message:", err);
|
||||
console.error("Failed to send message:", err);
|
||||
setIsStreamingInitiated(false);
|
||||
|
||||
// Don't show error toast for AbortError (expected during cleanup)
|
||||
if (err instanceof Error && err.name === "AbortError") return;
|
||||
|
||||
const errorMessage =
|
||||
err instanceof Error ? err.message : "Failed to send message";
|
||||
toast.error("Failed to send message", {
|
||||
@@ -222,63 +196,11 @@ export function useChatContainer({
|
||||
[sessionId, sendStreamMessage],
|
||||
);
|
||||
|
||||
const handleStopStreaming = useCallback(() => {
|
||||
stopStreaming();
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
}, [stopStreaming]);
|
||||
|
||||
const { capturePageContext } = usePageContext();
|
||||
|
||||
// Send initial prompt if provided (for new sessions from homepage)
|
||||
useEffect(
|
||||
function handleInitialPrompt() {
|
||||
if (!initialPrompt || !sessionId) return;
|
||||
if (initialMessages.length > 0) return;
|
||||
if (hasSentInitialPrompt(sessionId)) return;
|
||||
|
||||
markInitialPromptSent(sessionId);
|
||||
const context = capturePageContext();
|
||||
sendMessage(initialPrompt, true, context);
|
||||
},
|
||||
[
|
||||
initialPrompt,
|
||||
sessionId,
|
||||
initialMessages.length,
|
||||
sendMessage,
|
||||
capturePageContext,
|
||||
],
|
||||
);
|
||||
|
||||
async function sendMessageWithContext(
|
||||
content: string,
|
||||
isUserMessage: boolean = true,
|
||||
) {
|
||||
const context = capturePageContext();
|
||||
await sendMessage(content, isUserMessage, context);
|
||||
}
|
||||
|
||||
function handleRegionModalOpenChange(open: boolean) {
|
||||
setIsRegionBlockedModalOpen(open);
|
||||
}
|
||||
|
||||
function handleRegionModalClose() {
|
||||
setIsRegionBlockedModalOpen(false);
|
||||
}
|
||||
|
||||
return {
|
||||
messages: allMessages,
|
||||
streamingChunks,
|
||||
isStreaming,
|
||||
error,
|
||||
isRegionBlockedModalOpen,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sendMessageWithContext,
|
||||
handleRegionModalOpenChange,
|
||||
handleRegionModalClose,
|
||||
sendMessage,
|
||||
stopStreaming: handleStopStreaming,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ArrowUpIcon } from "@phosphor-icons/react";
|
||||
import { useChatInput } from "./useChatInput";
|
||||
|
||||
export interface ChatInputProps {
|
||||
onSend: (message: string) => void;
|
||||
disabled?: boolean;
|
||||
placeholder?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
placeholder = "Type your message...",
|
||||
className,
|
||||
}: ChatInputProps) {
|
||||
const inputId = "chat-input";
|
||||
const { value, setValue, handleKeyDown, handleSend } = useChatInput({
|
||||
onSend,
|
||||
disabled,
|
||||
maxRows: 5,
|
||||
inputId,
|
||||
});
|
||||
|
||||
return (
|
||||
<div className={cn("relative flex-1", className)}>
|
||||
<Input
|
||||
id={inputId}
|
||||
label="Chat message input"
|
||||
hideLabel
|
||||
type="textarea"
|
||||
value={value}
|
||||
onChange={(e) => setValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={placeholder}
|
||||
disabled={disabled}
|
||||
rows={1}
|
||||
wrapperClassName="mb-0 relative"
|
||||
className="pr-12"
|
||||
/>
|
||||
<span id="chat-input-hint" className="sr-only">
|
||||
Press Enter to send, Shift+Enter for new line
|
||||
</span>
|
||||
|
||||
<button
|
||||
onClick={handleSend}
|
||||
disabled={disabled || !value.trim()}
|
||||
className={cn(
|
||||
"absolute right-3 top-1/2 flex h-8 w-8 -translate-y-1/2 items-center justify-center rounded-full",
|
||||
"border border-zinc-800 bg-zinc-800 text-white",
|
||||
"hover:border-zinc-900 hover:bg-zinc-900",
|
||||
"disabled:border-zinc-200 disabled:bg-zinc-200 disabled:text-white disabled:opacity-50",
|
||||
"transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-neutral-950",
|
||||
"disabled:pointer-events-none",
|
||||
)}
|
||||
aria-label="Send message"
|
||||
>
|
||||
<ArrowUpIcon className="h-3 w-3" weight="bold" />
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
import { KeyboardEvent, useCallback, useEffect, useState } from "react";
|
||||
|
||||
interface UseChatInputArgs {
|
||||
onSend: (message: string) => void;
|
||||
disabled?: boolean;
|
||||
maxRows?: number;
|
||||
inputId?: string;
|
||||
}
|
||||
|
||||
export function useChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
maxRows = 5,
|
||||
inputId = "chat-input",
|
||||
}: UseChatInputArgs) {
|
||||
const [value, setValue] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
||||
if (!textarea) return;
|
||||
textarea.style.height = "auto";
|
||||
const lineHeight = parseInt(
|
||||
window.getComputedStyle(textarea).lineHeight,
|
||||
10,
|
||||
);
|
||||
const maxHeight = lineHeight * maxRows;
|
||||
const newHeight = Math.min(textarea.scrollHeight, maxHeight);
|
||||
textarea.style.height = `${newHeight}px`;
|
||||
textarea.style.overflowY =
|
||||
textarea.scrollHeight > maxHeight ? "auto" : "hidden";
|
||||
}, [value, maxRows, inputId]);
|
||||
|
||||
const handleSend = useCallback(() => {
|
||||
if (disabled || !value.trim()) return;
|
||||
onSend(value.trim());
|
||||
setValue("");
|
||||
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
||||
if (textarea) {
|
||||
textarea.style.height = "auto";
|
||||
}
|
||||
}, [value, onSend, disabled, inputId]);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(event: KeyboardEvent<HTMLInputElement | HTMLTextAreaElement>) => {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
handleSend();
|
||||
}
|
||||
// Shift+Enter allows default behavior (new line) - no need to handle explicitly
|
||||
},
|
||||
[handleSend],
|
||||
);
|
||||
|
||||
return {
|
||||
value,
|
||||
setValue,
|
||||
handleKeyDown,
|
||||
handleSend,
|
||||
};
|
||||
}
|
||||
@@ -1,64 +1,48 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/store";
|
||||
import Avatar, {
|
||||
AvatarFallback,
|
||||
AvatarImage,
|
||||
} from "@/components/atoms/Avatar/Avatar";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
ArrowsClockwiseIcon,
|
||||
ArrowClockwise,
|
||||
CheckCircleIcon,
|
||||
CheckIcon,
|
||||
CopyIcon,
|
||||
RobotIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useCallback, useState } from "react";
|
||||
import { getToolActionPhrase } from "../../helpers";
|
||||
import { AgentCarouselMessage } from "../AgentCarouselMessage/AgentCarouselMessage";
|
||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||
import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
|
||||
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
|
||||
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
|
||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||
import { MessageBubble } from "../MessageBubble/MessageBubble";
|
||||
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
|
||||
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
|
||||
import { ToolResponseMessage } from "../ToolResponseMessage/ToolResponseMessage";
|
||||
import { UserChatBubble } from "../UserChatBubble/UserChatBubble";
|
||||
import { useChatMessage, type ChatMessageData } from "./useChatMessage";
|
||||
|
||||
function stripInternalReasoning(content: string): string {
|
||||
const cleaned = content.replace(
|
||||
/<internal_reasoning>[\s\S]*?<\/internal_reasoning>/gi,
|
||||
"",
|
||||
);
|
||||
return cleaned.replace(/\n{3,}/g, "\n\n").trim();
|
||||
}
|
||||
|
||||
function getDisplayContent(message: ChatMessageData, isUser: boolean): string {
|
||||
if (message.type !== "message") return "";
|
||||
if (isUser) return message.content;
|
||||
return stripInternalReasoning(message.content);
|
||||
}
|
||||
|
||||
export interface ChatMessageProps {
|
||||
message: ChatMessageData;
|
||||
messages?: ChatMessageData[];
|
||||
index?: number;
|
||||
isStreaming?: boolean;
|
||||
className?: string;
|
||||
onDismissLogin?: () => void;
|
||||
onDismissCredentials?: () => void;
|
||||
onSendMessage?: (content: string, isUserMessage?: boolean) => void;
|
||||
agentOutput?: ChatMessageData;
|
||||
isFinalMessage?: boolean;
|
||||
}
|
||||
|
||||
export function ChatMessage({
|
||||
message,
|
||||
messages = [],
|
||||
index = -1,
|
||||
isStreaming = false,
|
||||
className,
|
||||
onDismissCredentials,
|
||||
onSendMessage,
|
||||
agentOutput,
|
||||
isFinalMessage = true,
|
||||
}: ChatMessageProps) {
|
||||
const { user } = useSupabase();
|
||||
const router = useRouter();
|
||||
@@ -70,7 +54,14 @@ export function ChatMessage({
|
||||
isLoginNeeded,
|
||||
isCredentialsNeeded,
|
||||
} = useChatMessage(message);
|
||||
const displayContent = getDisplayContent(message, isUser);
|
||||
|
||||
const { data: profile } = useGetV2GetUserProfile({
|
||||
query: {
|
||||
select: (res) => (res.status === 200 ? res.data : null),
|
||||
enabled: isUser && !!user,
|
||||
queryKey: ["/api/store/profile", user?.id],
|
||||
},
|
||||
});
|
||||
|
||||
const handleAllCredentialsComplete = useCallback(
|
||||
function handleAllCredentialsComplete() {
|
||||
@@ -96,25 +87,17 @@ export function ChatMessage({
|
||||
}
|
||||
}
|
||||
|
||||
const handleCopy = useCallback(
|
||||
async function handleCopy() {
|
||||
if (message.type !== "message") return;
|
||||
if (!displayContent) return;
|
||||
const handleCopy = useCallback(async () => {
|
||||
if (message.type !== "message") return;
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(displayContent);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
} catch (error) {
|
||||
console.error("Failed to copy:", error);
|
||||
}
|
||||
},
|
||||
[displayContent, message],
|
||||
);
|
||||
|
||||
function isLongResponse(content: string): boolean {
|
||||
return content.split("\n").length > 5;
|
||||
}
|
||||
try {
|
||||
await navigator.clipboard.writeText(message.content);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
} catch (error) {
|
||||
console.error("Failed to copy:", error);
|
||||
}
|
||||
}, [message]);
|
||||
|
||||
const handleTryAgain = useCallback(() => {
|
||||
if (message.type !== "message" || !onSendMessage) return;
|
||||
@@ -186,45 +169,9 @@ export function ChatMessage({
|
||||
|
||||
// Render tool call messages
|
||||
if (isToolCall && message.type === "tool_call") {
|
||||
// Check if this tool call is currently streaming
|
||||
// A tool call is streaming if:
|
||||
// 1. isStreaming is true
|
||||
// 2. This is the last tool_call message
|
||||
// 3. There's no tool_response for this tool call yet
|
||||
const isToolCallStreaming =
|
||||
isStreaming &&
|
||||
index >= 0 &&
|
||||
(() => {
|
||||
// Find the last tool_call index
|
||||
let lastToolCallIndex = -1;
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].type === "tool_call") {
|
||||
lastToolCallIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Check if this is the last tool_call and there's no response yet
|
||||
if (index === lastToolCallIndex) {
|
||||
// Check if there's a tool_response for this tool call
|
||||
const hasResponse = messages
|
||||
.slice(index + 1)
|
||||
.some(
|
||||
(msg) =>
|
||||
msg.type === "tool_response" && msg.toolId === message.toolId,
|
||||
);
|
||||
return !hasResponse;
|
||||
}
|
||||
return false;
|
||||
})();
|
||||
|
||||
return (
|
||||
<div className={cn("px-4 py-2", className)}>
|
||||
<ToolCallMessage
|
||||
toolId={message.toolId}
|
||||
toolName={message.toolName}
|
||||
arguments={message.arguments}
|
||||
isStreaming={isToolCallStreaming}
|
||||
/>
|
||||
<ToolCallMessage toolName={message.toolName} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -271,11 +218,27 @@ export function ChatMessage({
|
||||
|
||||
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
|
||||
if (isToolResponse && message.type === "tool_response") {
|
||||
// Check if this is an agent_output that should be rendered inside assistant message
|
||||
if (message.result) {
|
||||
let parsedResult: Record<string, unknown> | null = null;
|
||||
try {
|
||||
parsedResult =
|
||||
typeof message.result === "string"
|
||||
? JSON.parse(message.result)
|
||||
: (message.result as Record<string, unknown>);
|
||||
} catch {
|
||||
parsedResult = null;
|
||||
}
|
||||
if (parsedResult?.type === "agent_output") {
|
||||
// Skip rendering - this will be rendered inside the assistant message
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("px-4 py-2", className)}>
|
||||
<ToolResponseMessage
|
||||
toolId={message.toolId}
|
||||
toolName={message.toolName}
|
||||
toolName={getToolActionPhrase(message.toolName)}
|
||||
result={message.result}
|
||||
/>
|
||||
</div>
|
||||
@@ -293,33 +256,40 @@ export function ChatMessage({
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full max-w-3xl gap-3">
|
||||
{!isUser && (
|
||||
<div className="flex-shrink-0">
|
||||
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-500">
|
||||
<RobotIcon className="h-4 w-4 text-indigo-50" />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
className={cn(
|
||||
"flex min-w-0 flex-1 flex-col",
|
||||
isUser && "items-end",
|
||||
)}
|
||||
>
|
||||
{isUser ? (
|
||||
<UserChatBubble>
|
||||
<MarkdownContent content={displayContent} />
|
||||
</UserChatBubble>
|
||||
) : (
|
||||
<AIChatBubble>
|
||||
<MarkdownContent content={displayContent} />
|
||||
{agentOutput && agentOutput.type === "tool_response" && (
|
||||
<MessageBubble variant={isUser ? "user" : "assistant"}>
|
||||
<MarkdownContent content={message.content} />
|
||||
{agentOutput &&
|
||||
agentOutput.type === "tool_response" &&
|
||||
!isUser && (
|
||||
<div className="mt-4">
|
||||
<ToolResponseMessage
|
||||
toolId={agentOutput.toolId}
|
||||
toolName={agentOutput.toolName || "Agent Output"}
|
||||
toolName={
|
||||
agentOutput.toolName
|
||||
? getToolActionPhrase(agentOutput.toolName)
|
||||
: "Agent Output"
|
||||
}
|
||||
result={agentOutput.result}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</AIChatBubble>
|
||||
)}
|
||||
</MessageBubble>
|
||||
<div
|
||||
className={cn(
|
||||
"flex gap-0",
|
||||
"mt-1 flex gap-1",
|
||||
isUser ? "justify-end" : "justify-start",
|
||||
)}
|
||||
>
|
||||
@@ -330,40 +300,37 @@ export function ChatMessage({
|
||||
onClick={handleTryAgain}
|
||||
aria-label="Try again"
|
||||
>
|
||||
<ArrowsClockwiseIcon className="size-4 text-zinc-600" />
|
||||
</Button>
|
||||
)}
|
||||
{!isUser && isFinalMessage && isLongResponse(displayContent) && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={handleCopy}
|
||||
aria-label="Copy message"
|
||||
className="p-1"
|
||||
>
|
||||
{copied ? (
|
||||
<CheckIcon className="size-4 text-green-600" />
|
||||
) : (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
className="size-3 text-zinc-600"
|
||||
>
|
||||
<rect width="14" height="14" x="8" y="8" rx="2" ry="2" />
|
||||
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2" />
|
||||
</svg>
|
||||
)}
|
||||
<ArrowClockwise className="size-3 text-neutral-500" />
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={handleCopy}
|
||||
aria-label="Copy message"
|
||||
>
|
||||
{copied ? (
|
||||
<CheckIcon className="size-3 text-green-600" />
|
||||
) : (
|
||||
<CopyIcon className="size-3 text-neutral-500" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isUser && (
|
||||
<div className="flex-shrink-0">
|
||||
<Avatar className="h-7 w-7">
|
||||
<AvatarImage
|
||||
src={profile?.avatar_url ?? ""}
|
||||
alt={profile?.username ?? "User"}
|
||||
/>
|
||||
<AvatarFallback className="rounded-lg bg-neutral-200 text-neutral-600">
|
||||
{profile?.username?.charAt(0)?.toUpperCase() || "U"}
|
||||
</AvatarFallback>
|
||||
</Avatar>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@@ -13,9 +13,10 @@ export function MessageBubble({
|
||||
className,
|
||||
}: MessageBubbleProps) {
|
||||
const userTheme = {
|
||||
bg: "bg-purple-100",
|
||||
border: "border-purple-100",
|
||||
text: "text-slate-900",
|
||||
bg: "bg-slate-900",
|
||||
border: "border-slate-800",
|
||||
gradient: "from-slate-900/30 via-slate-800/20 to-transparent",
|
||||
text: "text-slate-50",
|
||||
};
|
||||
|
||||
const assistantTheme = {
|
||||
@@ -39,7 +40,9 @@ export function MessageBubble({
|
||||
)}
|
||||
>
|
||||
{/* Gradient flare background */}
|
||||
<div className={cn("absolute inset-0 bg-gradient-to-br")} />
|
||||
<div
|
||||
className={cn("absolute inset-0 bg-gradient-to-br", theme.gradient)}
|
||||
/>
|
||||
<div
|
||||
className={cn(
|
||||
"relative z-10 transition-all duration-500 ease-in-out",
|
||||
@@ -0,0 +1,121 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ChatMessage } from "../ChatMessage/ChatMessage";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
import { StreamingMessage } from "../StreamingMessage/StreamingMessage";
|
||||
import { ThinkingMessage } from "../ThinkingMessage/ThinkingMessage";
|
||||
import { useMessageList } from "./useMessageList";
|
||||
|
||||
export interface MessageListProps {
|
||||
messages: ChatMessageData[];
|
||||
streamingChunks?: string[];
|
||||
isStreaming?: boolean;
|
||||
className?: string;
|
||||
onStreamComplete?: () => void;
|
||||
onSendMessage?: (content: string) => void;
|
||||
}
|
||||
|
||||
export function MessageList({
|
||||
messages,
|
||||
streamingChunks = [],
|
||||
isStreaming = false,
|
||||
className,
|
||||
onStreamComplete,
|
||||
onSendMessage,
|
||||
}: MessageListProps) {
|
||||
const { messagesEndRef, messagesContainerRef } = useMessageList({
|
||||
messageCount: messages.length,
|
||||
isStreaming,
|
||||
});
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={messagesContainerRef}
|
||||
className={cn(
|
||||
"flex-1 overflow-y-auto",
|
||||
"scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="mx-auto flex max-w-3xl flex-col py-4">
|
||||
{/* Render all persisted messages */}
|
||||
{messages.map((message, index) => {
|
||||
// Check if current message is an agent_output tool_response
|
||||
// and if previous message is an assistant message
|
||||
let agentOutput: ChatMessageData | undefined;
|
||||
|
||||
if (message.type === "tool_response" && message.result) {
|
||||
let parsedResult: Record<string, unknown> | null = null;
|
||||
try {
|
||||
parsedResult =
|
||||
typeof message.result === "string"
|
||||
? JSON.parse(message.result)
|
||||
: (message.result as Record<string, unknown>);
|
||||
} catch {
|
||||
parsedResult = null;
|
||||
}
|
||||
if (parsedResult?.type === "agent_output") {
|
||||
const prevMessage = messages[index - 1];
|
||||
if (
|
||||
prevMessage &&
|
||||
prevMessage.type === "message" &&
|
||||
prevMessage.role === "assistant"
|
||||
) {
|
||||
// This agent output will be rendered inside the previous assistant message
|
||||
// Skip rendering this message separately
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if next message is an agent_output tool_response to include in current assistant message
|
||||
if (message.type === "message" && message.role === "assistant") {
|
||||
const nextMessage = messages[index + 1];
|
||||
if (
|
||||
nextMessage &&
|
||||
nextMessage.type === "tool_response" &&
|
||||
nextMessage.result
|
||||
) {
|
||||
let parsedResult: Record<string, unknown> | null = null;
|
||||
try {
|
||||
parsedResult =
|
||||
typeof nextMessage.result === "string"
|
||||
? JSON.parse(nextMessage.result)
|
||||
: (nextMessage.result as Record<string, unknown>);
|
||||
} catch {
|
||||
parsedResult = null;
|
||||
}
|
||||
if (parsedResult?.type === "agent_output") {
|
||||
agentOutput = nextMessage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<ChatMessage
|
||||
key={index}
|
||||
message={message}
|
||||
onSendMessage={onSendMessage}
|
||||
agentOutput={agentOutput}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Render thinking message when streaming but no chunks yet */}
|
||||
{isStreaming && streamingChunks.length === 0 && <ThinkingMessage />}
|
||||
|
||||
{/* Render streaming message if active */}
|
||||
{isStreaming && streamingChunks.length > 0 && (
|
||||
<StreamingMessage
|
||||
chunks={streamingChunks}
|
||||
onComplete={onStreamComplete}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Invisible div to scroll to */}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -81,9 +81,9 @@ export function SessionsDrawer({
|
||||
</Text>
|
||||
</div>
|
||||
) : sessions.length === 0 ? (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<div className="flex items-center justify-center py-8">
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
You don't have previously started chats
|
||||
No sessions found
|
||||
</Text>
|
||||
</div>
|
||||
) : (
|
||||
@@ -1,6 +1,7 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||
import { RobotIcon } from "@phosphor-icons/react";
|
||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||
import { MessageBubble } from "../MessageBubble/MessageBubble";
|
||||
import { useStreamingMessage } from "./useStreamingMessage";
|
||||
|
||||
export interface StreamingMessageProps {
|
||||
@@ -24,10 +25,16 @@ export function StreamingMessage({
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full max-w-3xl gap-3">
|
||||
<div className="flex-shrink-0">
|
||||
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-600">
|
||||
<RobotIcon className="h-4 w-4 text-indigo-50" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex min-w-0 flex-1 flex-col">
|
||||
<AIChatBubble>
|
||||
<MessageBubble variant="assistant">
|
||||
<MarkdownContent content={displayText} />
|
||||
</AIChatBubble>
|
||||
</MessageBubble>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,70 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
import { RobotIcon } from "@phosphor-icons/react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { MessageBubble } from "../MessageBubble/MessageBubble";
|
||||
|
||||
export interface ThinkingMessageProps {
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
||||
const [showSlowLoader, setShowSlowLoader] = useState(false);
|
||||
const timerRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (timerRef.current === null) {
|
||||
timerRef.current = setTimeout(() => {
|
||||
setShowSlowLoader(true);
|
||||
}, 8000);
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (timerRef.current) {
|
||||
clearTimeout(timerRef.current);
|
||||
timerRef.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"group relative flex w-full justify-start gap-3 px-4 py-3",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full max-w-3xl gap-3">
|
||||
<div className="flex-shrink-0">
|
||||
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-500">
|
||||
<RobotIcon className="h-4 w-4 text-indigo-50" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex min-w-0 flex-1 flex-col">
|
||||
<MessageBubble variant="assistant">
|
||||
<div className="transition-all duration-500 ease-in-out">
|
||||
{showSlowLoader ? (
|
||||
<div className="flex flex-col items-center gap-3 py-2">
|
||||
<div className="loader" style={{ flexShrink: 0 }} />
|
||||
<p className="text-sm text-slate-700">
|
||||
Taking a bit longer to think, wait a moment please
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<span
|
||||
className="inline-block bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-clip-text text-transparent"
|
||||
style={{
|
||||
backgroundSize: "200% 100%",
|
||||
animation: "shimmer 2s ease-in-out infinite",
|
||||
}}
|
||||
>
|
||||
Thinking...
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</MessageBubble>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { WrenchIcon } from "@phosphor-icons/react";
|
||||
import { getToolActionPhrase } from "../../helpers";
|
||||
|
||||
export interface ToolCallMessageProps {
|
||||
toolName: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ToolCallMessage({ toolName, className }: ToolCallMessageProps) {
|
||||
return (
|
||||
<div className={cn("flex items-center justify-center gap-2", className)}>
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}...
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import "@/components/contextual/OutputRenderers";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { ToolResult } from "@/types/chat";
|
||||
import { WrenchIcon } from "@phosphor-icons/react";
|
||||
import { getToolActionPhrase } from "../../helpers";
|
||||
|
||||
export interface ToolResponseMessageProps {
|
||||
toolName: string;
|
||||
result?: ToolResult;
|
||||
success?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ToolResponseMessage({
|
||||
toolName,
|
||||
result,
|
||||
success: _success = true,
|
||||
className,
|
||||
}: ToolResponseMessageProps) {
|
||||
if (!result) {
|
||||
return (
|
||||
<div className={cn("flex items-center justify-center gap-2", className)}>
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}...
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
let parsedResult: Record<string, unknown> | null = null;
|
||||
try {
|
||||
parsedResult =
|
||||
typeof result === "string"
|
||||
? JSON.parse(result)
|
||||
: (result as Record<string, unknown>);
|
||||
} catch {
|
||||
parsedResult = null;
|
||||
}
|
||||
|
||||
if (parsedResult && typeof parsedResult === "object") {
|
||||
const responseType = parsedResult.type as string | undefined;
|
||||
|
||||
if (responseType === "agent_output") {
|
||||
const execution = parsedResult.execution as
|
||||
| {
|
||||
outputs?: Record<string, unknown[]>;
|
||||
}
|
||||
| null
|
||||
| undefined;
|
||||
const outputs = execution?.outputs || {};
|
||||
const message = parsedResult.message as string | undefined;
|
||||
|
||||
return (
|
||||
<div className={cn("space-y-4 px-4 py-2", className)}>
|
||||
<div className="flex items-center gap-2">
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}
|
||||
</Text>
|
||||
</div>
|
||||
{message && (
|
||||
<div className="rounded border p-4">
|
||||
<Text variant="small" className="text-neutral-600">
|
||||
{message}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
{Object.keys(outputs).length > 0 && (
|
||||
<div className="space-y-4">
|
||||
{Object.entries(outputs).map(([outputName, values]) =>
|
||||
values.map((value, index) => {
|
||||
const renderer = globalRegistry.getRenderer(value);
|
||||
if (renderer) {
|
||||
return (
|
||||
<OutputItem
|
||||
key={`${outputName}-${index}`}
|
||||
value={value}
|
||||
renderer={renderer}
|
||||
label={outputName}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div
|
||||
key={`${outputName}-${index}`}
|
||||
className="rounded border p-4"
|
||||
>
|
||||
<Text variant="large-medium" className="mb-2 capitalize">
|
||||
{outputName}
|
||||
</Text>
|
||||
<pre className="overflow-auto text-sm">
|
||||
{JSON.stringify(value, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
);
|
||||
}),
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (responseType === "block_output" && parsedResult.outputs) {
|
||||
const outputs = parsedResult.outputs as Record<string, unknown[]>;
|
||||
|
||||
return (
|
||||
<div className={cn("space-y-4 px-4 py-2", className)}>
|
||||
<div className="flex items-center gap-2">
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="space-y-4">
|
||||
{Object.entries(outputs).map(([outputName, values]) =>
|
||||
values.map((value, index) => {
|
||||
const renderer = globalRegistry.getRenderer(value);
|
||||
if (renderer) {
|
||||
return (
|
||||
<OutputItem
|
||||
key={`${outputName}-${index}`}
|
||||
value={value}
|
||||
renderer={renderer}
|
||||
label={outputName}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div
|
||||
key={`${outputName}-${index}`}
|
||||
className="rounded border p-4"
|
||||
>
|
||||
<Text variant="large-medium" className="mb-2 capitalize">
|
||||
{outputName}
|
||||
</Text>
|
||||
<pre className="overflow-auto text-sm">
|
||||
{JSON.stringify(value, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
);
|
||||
}),
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Handle other response types with a message field (e.g., understanding_updated)
|
||||
if (parsedResult.message && typeof parsedResult.message === "string") {
|
||||
// Format tool name from snake_case to Title Case
|
||||
const formattedToolName = toolName
|
||||
.split("_")
|
||||
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join(" ");
|
||||
|
||||
// Clean up message - remove incomplete user_name references
|
||||
let cleanedMessage = parsedResult.message;
|
||||
// Remove "Updated understanding with: user_name" pattern if user_name is just a placeholder
|
||||
cleanedMessage = cleanedMessage.replace(
|
||||
/Updated understanding with:\s*user_name\.?\s*/gi,
|
||||
"",
|
||||
);
|
||||
// Remove standalone user_name references
|
||||
cleanedMessage = cleanedMessage.replace(/\buser_name\b\.?\s*/gi, "");
|
||||
cleanedMessage = cleanedMessage.trim();
|
||||
|
||||
// Only show message if it has content after cleaning
|
||||
if (!cleanedMessage) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center gap-2 px-4 py-2",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{formattedToolName}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("space-y-2 px-4 py-2", className)}>
|
||||
<div className="flex items-center justify-center gap-2">
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{formattedToolName}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="rounded border p-4">
|
||||
<Text variant="small" className="text-neutral-600">
|
||||
{cleanedMessage}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const renderer = globalRegistry.getRenderer(result);
|
||||
if (renderer) {
|
||||
return (
|
||||
<div className={cn("px-4 py-2", className)}>
|
||||
<div className="mb-2 flex items-center gap-2">
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}
|
||||
</Text>
|
||||
</div>
|
||||
<OutputItem value={result} renderer={renderer} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("flex items-center justify-center gap-2", className)}>
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}...
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
/**
|
||||
* Maps internal tool names to user-friendly display names with emojis.
|
||||
* @deprecated Use getToolActionPhrase or getToolCompletionPhrase for status messages
|
||||
*
|
||||
* @param toolName - The internal tool name from the backend
|
||||
* @returns A user-friendly display name with an emoji prefix
|
||||
*/
|
||||
export function getToolDisplayName(toolName: string): string {
|
||||
const toolDisplayNames: Record<string, string> = {
|
||||
find_agent: "🔍 Search Marketplace",
|
||||
get_agent_details: "📋 Get Agent Details",
|
||||
check_credentials: "🔑 Check Credentials",
|
||||
setup_agent: "⚙️ Setup Agent",
|
||||
run_agent: "▶️ Run Agent",
|
||||
get_required_setup_info: "📝 Get Setup Requirements",
|
||||
};
|
||||
return toolDisplayNames[toolName] || toolName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps internal tool names to human-friendly action phrases (present continuous).
|
||||
* Used for tool call messages to indicate what action is currently happening.
|
||||
*
|
||||
* @param toolName - The internal tool name from the backend
|
||||
* @returns A human-friendly action phrase in present continuous tense
|
||||
*/
|
||||
export function getToolActionPhrase(toolName: string): string {
|
||||
const toolActionPhrases: Record<string, string> = {
|
||||
find_agent: "Looking for agents in the marketplace",
|
||||
agent_carousel: "Looking for agents in the marketplace",
|
||||
get_agent_details: "Learning about the agent",
|
||||
check_credentials: "Checking your credentials",
|
||||
setup_agent: "Setting up the agent",
|
||||
execution_started: "Running the agent",
|
||||
run_agent: "Running the agent",
|
||||
get_required_setup_info: "Getting setup requirements",
|
||||
schedule_agent: "Scheduling the agent to run",
|
||||
};
|
||||
|
||||
// Return mapped phrase or generate human-friendly fallback
|
||||
return toolActionPhrases[toolName] || toolName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps internal tool names to human-friendly completion phrases (past tense).
|
||||
* Used for tool response messages to indicate what action was completed.
|
||||
*
|
||||
* @param toolName - The internal tool name from the backend
|
||||
* @returns A human-friendly completion phrase in past tense
|
||||
*/
|
||||
export function getToolCompletionPhrase(toolName: string): string {
|
||||
const toolCompletionPhrases: Record<string, string> = {
|
||||
find_agent: "Finished searching the marketplace",
|
||||
get_agent_details: "Got agent details",
|
||||
check_credentials: "Checked credentials",
|
||||
setup_agent: "Agent setup complete",
|
||||
run_agent: "Agent execution started",
|
||||
get_required_setup_info: "Got setup requirements",
|
||||
};
|
||||
|
||||
// Return mapped phrase or generate human-friendly fallback
|
||||
return (
|
||||
toolCompletionPhrases[toolName] ||
|
||||
`Finished ${toolName.replace(/_/g, " ").replace("...", "")}`
|
||||
);
|
||||
}
|
||||
@@ -1,20 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
import { useChatStream } from "./useChatStream";
|
||||
|
||||
interface UseChatArgs {
|
||||
urlSessionId?: string | null;
|
||||
}
|
||||
|
||||
export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
export function useChat() {
|
||||
const hasCreatedSessionRef = useRef(false);
|
||||
const hasClaimedSessionRef = useRef(false);
|
||||
const { user } = useSupabase();
|
||||
const { sendMessage: sendStreamMessage } = useChatStream();
|
||||
const [showLoader, setShowLoader] = useState(false);
|
||||
|
||||
const {
|
||||
session,
|
||||
sessionId: sessionIdFromHook,
|
||||
@@ -22,16 +19,27 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
isLoading,
|
||||
isCreating,
|
||||
error,
|
||||
isSessionNotFound,
|
||||
createSession,
|
||||
claimSession,
|
||||
clearSession: clearSessionBase,
|
||||
loadSession,
|
||||
} = useChatSession({
|
||||
urlSessionId,
|
||||
urlSessionId: null,
|
||||
autoCreate: false,
|
||||
});
|
||||
|
||||
useEffect(
|
||||
function autoCreateSession() {
|
||||
if (!hasCreatedSessionRef.current && !isCreating && !sessionIdFromHook) {
|
||||
hasCreatedSessionRef.current = true;
|
||||
createSession().catch((_err) => {
|
||||
hasCreatedSessionRef.current = false;
|
||||
});
|
||||
}
|
||||
},
|
||||
[isCreating, sessionIdFromHook, createSession],
|
||||
);
|
||||
|
||||
useEffect(
|
||||
function autoClaimSession() {
|
||||
if (
|
||||
@@ -67,17 +75,6 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoading || isCreating) {
|
||||
const timer = setTimeout(() => {
|
||||
setShowLoader(true);
|
||||
}, 300);
|
||||
return () => clearTimeout(timer);
|
||||
} else {
|
||||
setShowLoader(false);
|
||||
}
|
||||
}, [isLoading, isCreating]);
|
||||
|
||||
useEffect(function monitorNetworkStatus() {
|
||||
function handleOnline() {
|
||||
toast.success("Connection restored", {
|
||||
@@ -102,6 +99,7 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
|
||||
function clearSession() {
|
||||
clearSessionBase();
|
||||
hasCreatedSessionRef.current = false;
|
||||
hasClaimedSessionRef.current = false;
|
||||
}
|
||||
|
||||
@@ -111,11 +109,9 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
isLoading,
|
||||
isCreating,
|
||||
error,
|
||||
isSessionNotFound,
|
||||
createSession,
|
||||
clearSession,
|
||||
loadSession,
|
||||
sessionId: sessionIdFromHook,
|
||||
showLoader,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
import {
|
||||
getGetV2GetSessionQueryKey,
|
||||
getGetV2GetSessionQueryOptions,
|
||||
postV2CreateSession,
|
||||
useGetV2GetSession,
|
||||
usePatchV2SessionAssignUser,
|
||||
usePostV2CreateSession,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { isValidUUID } from "@/lib/utils";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface UseChatSessionArgs {
|
||||
urlSessionId?: string | null;
|
||||
autoCreate?: boolean;
|
||||
}
|
||||
|
||||
export function useChatSession({
|
||||
urlSessionId,
|
||||
autoCreate = false,
|
||||
}: UseChatSessionArgs = {}) {
|
||||
const queryClient = useQueryClient();
|
||||
const [sessionId, setSessionId] = useState<string | null>(null);
|
||||
const [error, setError] = useState<Error | null>(null);
|
||||
const justCreatedSessionIdRef = useRef<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (urlSessionId) {
|
||||
if (!isValidUUID(urlSessionId)) {
|
||||
console.error("Invalid session ID format:", urlSessionId);
|
||||
toast.error("Invalid session ID", {
|
||||
description:
|
||||
"The session ID in the URL is not valid. Starting a new session...",
|
||||
});
|
||||
setSessionId(null);
|
||||
storage.clean(Key.CHAT_SESSION_ID);
|
||||
return;
|
||||
}
|
||||
setSessionId(urlSessionId);
|
||||
storage.set(Key.CHAT_SESSION_ID, urlSessionId);
|
||||
} else {
|
||||
const storedSessionId = storage.get(Key.CHAT_SESSION_ID);
|
||||
if (storedSessionId) {
|
||||
if (!isValidUUID(storedSessionId)) {
|
||||
console.error("Invalid stored session ID:", storedSessionId);
|
||||
storage.clean(Key.CHAT_SESSION_ID);
|
||||
setSessionId(null);
|
||||
} else {
|
||||
setSessionId(storedSessionId);
|
||||
}
|
||||
} else if (autoCreate) {
|
||||
setSessionId(null);
|
||||
}
|
||||
}
|
||||
}, [urlSessionId, autoCreate]);
|
||||
|
||||
const {
|
||||
mutateAsync: createSessionMutation,
|
||||
isPending: isCreating,
|
||||
error: createError,
|
||||
} = usePostV2CreateSession();
|
||||
|
||||
const {
|
||||
data: sessionData,
|
||||
isLoading: isLoadingSession,
|
||||
error: loadError,
|
||||
refetch,
|
||||
} = useGetV2GetSession(sessionId || "", {
|
||||
query: {
|
||||
enabled: !!sessionId,
|
||||
select: okData,
|
||||
staleTime: Infinity, // Never mark as stale
|
||||
refetchOnMount: false, // Don't refetch on component mount
|
||||
refetchOnWindowFocus: false, // Don't refetch when window regains focus
|
||||
refetchOnReconnect: false, // Don't refetch when network reconnects
|
||||
retry: 1,
|
||||
},
|
||||
});
|
||||
|
||||
const { mutateAsync: claimSessionMutation } = usePatchV2SessionAssignUser();
|
||||
|
||||
const session = useMemo(() => {
|
||||
if (sessionData) return sessionData;
|
||||
|
||||
if (sessionId && justCreatedSessionIdRef.current === sessionId) {
|
||||
return {
|
||||
id: sessionId,
|
||||
user_id: null,
|
||||
messages: [],
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
} as SessionDetailResponse;
|
||||
}
|
||||
return null;
|
||||
}, [sessionData, sessionId]);
|
||||
|
||||
const messages = session?.messages || [];
|
||||
const isLoading = isCreating || isLoadingSession;
|
||||
|
||||
useEffect(() => {
|
||||
if (createError) {
|
||||
setError(
|
||||
createError instanceof Error
|
||||
? createError
|
||||
: new Error("Failed to create session"),
|
||||
);
|
||||
} else if (loadError) {
|
||||
setError(
|
||||
loadError instanceof Error
|
||||
? loadError
|
||||
: new Error("Failed to load session"),
|
||||
);
|
||||
} else {
|
||||
setError(null);
|
||||
}
|
||||
}, [createError, loadError]);
|
||||
|
||||
const createSession = useCallback(
|
||||
async function createSession() {
|
||||
try {
|
||||
setError(null);
|
||||
const response = await postV2CreateSession({
|
||||
body: JSON.stringify({}),
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create session");
|
||||
}
|
||||
const newSessionId = response.data.id;
|
||||
setSessionId(newSessionId);
|
||||
storage.set(Key.CHAT_SESSION_ID, newSessionId);
|
||||
justCreatedSessionIdRef.current = newSessionId;
|
||||
setTimeout(() => {
|
||||
if (justCreatedSessionIdRef.current === newSessionId) {
|
||||
justCreatedSessionIdRef.current = null;
|
||||
}
|
||||
}, 10000);
|
||||
return newSessionId;
|
||||
} catch (err) {
|
||||
const error =
|
||||
err instanceof Error ? err : new Error("Failed to create session");
|
||||
setError(error);
|
||||
toast.error("Failed to create chat session", {
|
||||
description: error.message,
|
||||
});
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[createSessionMutation],
|
||||
);
|
||||
|
||||
const loadSession = useCallback(
|
||||
async function loadSession(id: string) {
|
||||
try {
|
||||
setError(null);
|
||||
// Invalidate the query cache for this session to force a fresh fetch
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(id),
|
||||
});
|
||||
// Set sessionId after invalidation to ensure the hook refetches
|
||||
setSessionId(id);
|
||||
storage.set(Key.CHAT_SESSION_ID, id);
|
||||
// Force fetch with fresh data (bypass cache)
|
||||
const queryOptions = getGetV2GetSessionQueryOptions(id, {
|
||||
query: {
|
||||
staleTime: 0, // Force fresh fetch
|
||||
retry: 1,
|
||||
},
|
||||
});
|
||||
const result = await queryClient.fetchQuery(queryOptions);
|
||||
if (!result || ("status" in result && result.status !== 200)) {
|
||||
console.warn("Session not found on server, clearing local state");
|
||||
storage.clean(Key.CHAT_SESSION_ID);
|
||||
setSessionId(null);
|
||||
throw new Error("Session not found");
|
||||
}
|
||||
} catch (err) {
|
||||
const error =
|
||||
err instanceof Error ? err : new Error("Failed to load session");
|
||||
setError(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[queryClient],
|
||||
);
|
||||
|
||||
const refreshSession = useCallback(
|
||||
async function refreshSession() {
|
||||
if (!sessionId) {
|
||||
console.log("[refreshSession] Skipping - no session ID");
|
||||
return;
|
||||
}
|
||||
try {
|
||||
setError(null);
|
||||
await refetch();
|
||||
} catch (err) {
|
||||
const error =
|
||||
err instanceof Error ? err : new Error("Failed to refresh session");
|
||||
setError(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[sessionId, refetch],
|
||||
);
|
||||
|
||||
const claimSession = useCallback(
|
||||
async function claimSession(id: string) {
|
||||
try {
|
||||
setError(null);
|
||||
await claimSessionMutation({ sessionId: id });
|
||||
if (justCreatedSessionIdRef.current === id) {
|
||||
justCreatedSessionIdRef.current = null;
|
||||
}
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(id),
|
||||
});
|
||||
await refetch();
|
||||
toast.success("Session claimed successfully", {
|
||||
description: "Your chat history has been saved to your account",
|
||||
});
|
||||
} catch (err: unknown) {
|
||||
const error =
|
||||
err instanceof Error ? err : new Error("Failed to claim session");
|
||||
const is404 =
|
||||
(typeof err === "object" &&
|
||||
err !== null &&
|
||||
"status" in err &&
|
||||
err.status === 404) ||
|
||||
(typeof err === "object" &&
|
||||
err !== null &&
|
||||
"response" in err &&
|
||||
typeof err.response === "object" &&
|
||||
err.response !== null &&
|
||||
"status" in err.response &&
|
||||
err.response.status === 404);
|
||||
if (!is404) {
|
||||
setError(error);
|
||||
toast.error("Failed to claim session", {
|
||||
description: error.message || "Unable to claim session",
|
||||
});
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[claimSessionMutation, queryClient, refetch],
|
||||
);
|
||||
|
||||
const clearSession = useCallback(function clearSession() {
|
||||
setSessionId(null);
|
||||
setError(null);
|
||||
storage.clean(Key.CHAT_SESSION_ID);
|
||||
justCreatedSessionIdRef.current = null;
|
||||
}, []);
|
||||
|
||||
return {
|
||||
session,
|
||||
sessionId,
|
||||
messages,
|
||||
isLoading,
|
||||
isCreating,
|
||||
error,
|
||||
createSession,
|
||||
loadSession,
|
||||
refreshSession,
|
||||
claimSession,
|
||||
clearSession,
|
||||
};
|
||||
}
|
||||
@@ -21,8 +21,6 @@ export interface StreamChunk {
|
||||
timestamp?: string;
|
||||
content?: string;
|
||||
message?: string;
|
||||
code?: string;
|
||||
details?: Record<string, unknown>;
|
||||
tool_id?: string;
|
||||
tool_name?: string;
|
||||
arguments?: ToolArguments;
|
||||
@@ -140,18 +138,8 @@ function normalizeStreamChunk(
|
||||
return { type: "stream_end" };
|
||||
case "start":
|
||||
case "text-start":
|
||||
return null;
|
||||
case "tool-input-start":
|
||||
const toolInputStart = chunk as Extract<
|
||||
VercelStreamChunk,
|
||||
{ type: "tool-input-start" }
|
||||
>;
|
||||
return {
|
||||
type: "tool_call_start",
|
||||
tool_id: toolInputStart.toolCallId,
|
||||
tool_name: toolInputStart.toolName,
|
||||
arguments: {},
|
||||
};
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,103 +149,22 @@ export function useChatStream() {
|
||||
const retryCountRef = useRef<number>(0);
|
||||
const retryTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const currentSessionIdRef = useRef<string | null>(null);
|
||||
const requestStartTimeRef = useRef<number | null>(null);
|
||||
|
||||
const stopStreaming = useCallback(
|
||||
(sessionId?: string, force: boolean = false) => {
|
||||
console.log("[useChatStream] stopStreaming called", {
|
||||
hasAbortController: !!abortControllerRef.current,
|
||||
isAborted: abortControllerRef.current?.signal.aborted,
|
||||
currentSessionId: currentSessionIdRef.current,
|
||||
requestedSessionId: sessionId,
|
||||
requestStartTime: requestStartTimeRef.current,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null,
|
||||
force,
|
||||
stack: new Error().stack,
|
||||
});
|
||||
|
||||
if (
|
||||
sessionId &&
|
||||
currentSessionIdRef.current &&
|
||||
currentSessionIdRef.current !== sessionId
|
||||
) {
|
||||
console.log(
|
||||
"[useChatStream] Session changed, aborting previous stream",
|
||||
{
|
||||
oldSessionId: currentSessionIdRef.current,
|
||||
newSessionId: sessionId,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
const controller = abortControllerRef.current;
|
||||
if (controller) {
|
||||
const timeSinceStart = requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null;
|
||||
|
||||
if (!force && timeSinceStart !== null && timeSinceStart < 100) {
|
||||
console.log(
|
||||
"[useChatStream] Request just started (<100ms), skipping abort to prevent race condition",
|
||||
{
|
||||
timeSinceStart,
|
||||
},
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const signal = controller.signal;
|
||||
|
||||
if (
|
||||
signal &&
|
||||
typeof signal.aborted === "boolean" &&
|
||||
!signal.aborted
|
||||
) {
|
||||
console.log("[useChatStream] Aborting stream");
|
||||
controller.abort();
|
||||
} else {
|
||||
console.log(
|
||||
"[useChatStream] Stream already aborted or signal invalid",
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
console.log(
|
||||
"[useChatStream] AbortError caught (expected during cleanup)",
|
||||
);
|
||||
} else {
|
||||
console.warn("[useChatStream] Error aborting stream:", error);
|
||||
}
|
||||
} finally {
|
||||
abortControllerRef.current = null;
|
||||
requestStartTimeRef.current = null;
|
||||
}
|
||||
}
|
||||
if (retryTimeoutRef.current) {
|
||||
clearTimeout(retryTimeoutRef.current);
|
||||
retryTimeoutRef.current = null;
|
||||
}
|
||||
setIsStreaming(false);
|
||||
},
|
||||
[],
|
||||
);
|
||||
const stopStreaming = useCallback(() => {
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
if (retryTimeoutRef.current) {
|
||||
clearTimeout(retryTimeoutRef.current);
|
||||
retryTimeoutRef.current = null;
|
||||
}
|
||||
setIsStreaming(false);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
console.log("[useChatStream] Component mounted");
|
||||
return () => {
|
||||
const sessionIdAtUnmount = currentSessionIdRef.current;
|
||||
console.log(
|
||||
"[useChatStream] Component unmounting, calling stopStreaming",
|
||||
{
|
||||
sessionIdAtUnmount,
|
||||
},
|
||||
);
|
||||
stopStreaming(undefined, false);
|
||||
currentSessionIdRef.current = null;
|
||||
stopStreaming();
|
||||
};
|
||||
}, [stopStreaming]);
|
||||
|
||||
@@ -270,32 +177,12 @@ export function useChatStream() {
|
||||
context?: { url: string; content: string },
|
||||
isRetry: boolean = false,
|
||||
) => {
|
||||
console.log("[useChatStream] sendMessage called", {
|
||||
sessionId,
|
||||
message: message.substring(0, 50),
|
||||
isUserMessage,
|
||||
isRetry,
|
||||
stack: new Error().stack,
|
||||
});
|
||||
|
||||
const previousSessionId = currentSessionIdRef.current;
|
||||
stopStreaming(sessionId, true);
|
||||
currentSessionIdRef.current = sessionId;
|
||||
stopStreaming();
|
||||
|
||||
const abortController = new AbortController();
|
||||
abortControllerRef.current = abortController;
|
||||
requestStartTimeRef.current = Date.now();
|
||||
console.log("[useChatStream] Created new AbortController", {
|
||||
sessionId,
|
||||
previousSessionId,
|
||||
requestStartTime: requestStartTimeRef.current,
|
||||
});
|
||||
|
||||
if (abortController.signal.aborted) {
|
||||
console.warn(
|
||||
"[useChatStream] AbortController was aborted before request started",
|
||||
);
|
||||
requestStartTimeRef.current = null;
|
||||
return Promise.reject(new Error("Request aborted"));
|
||||
}
|
||||
|
||||
@@ -323,34 +210,18 @@ export function useChatStream() {
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
console.info("[useChatStream] Stream response", {
|
||||
sessionId,
|
||||
status: response.status,
|
||||
ok: response.ok,
|
||||
contentType: response.headers.get("content-type"),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
console.warn("[useChatStream] Stream response error", {
|
||||
sessionId,
|
||||
status: response.status,
|
||||
errorText,
|
||||
});
|
||||
throw new Error(errorText || `HTTP ${response.status}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
console.warn("[useChatStream] Response body is null", { sessionId });
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
let receivedChunkCount = 0;
|
||||
let firstChunkAt: number | null = null;
|
||||
let loggedLineCount = 0;
|
||||
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
let didDispatchStreamEnd = false;
|
||||
@@ -374,13 +245,6 @@ export function useChatStream() {
|
||||
|
||||
if (done) {
|
||||
cleanup();
|
||||
console.info("[useChatStream] Stream closed", {
|
||||
sessionId,
|
||||
receivedChunkCount,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null,
|
||||
});
|
||||
dispatchStreamEnd();
|
||||
retryCountRef.current = 0;
|
||||
stopStreaming();
|
||||
@@ -395,23 +259,8 @@ export function useChatStream() {
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const data = line.slice(6);
|
||||
if (loggedLineCount < 3) {
|
||||
console.info("[useChatStream] Raw stream line", {
|
||||
sessionId,
|
||||
data:
|
||||
data.length > 300 ? `${data.slice(0, 300)}...` : data,
|
||||
});
|
||||
loggedLineCount += 1;
|
||||
}
|
||||
if (data === "[DONE]") {
|
||||
cleanup();
|
||||
console.info("[useChatStream] Stream done marker", {
|
||||
sessionId,
|
||||
receivedChunkCount,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null,
|
||||
});
|
||||
dispatchStreamEnd();
|
||||
retryCountRef.current = 0;
|
||||
stopStreaming();
|
||||
@@ -428,18 +277,6 @@ export function useChatStream() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!firstChunkAt) {
|
||||
firstChunkAt = Date.now();
|
||||
console.info("[useChatStream] First stream chunk", {
|
||||
sessionId,
|
||||
chunkType: chunk.type,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? firstChunkAt - requestStartTimeRef.current
|
||||
: null,
|
||||
});
|
||||
}
|
||||
receivedChunkCount += 1;
|
||||
|
||||
// Call the chunk handler
|
||||
onChunk(chunk);
|
||||
|
||||
@@ -447,13 +284,6 @@ export function useChatStream() {
|
||||
if (chunk.type === "stream_end") {
|
||||
didDispatchStreamEnd = true;
|
||||
cleanup();
|
||||
console.info("[useChatStream] Stream end chunk", {
|
||||
sessionId,
|
||||
receivedChunkCount,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null,
|
||||
});
|
||||
retryCountRef.current = 0;
|
||||
stopStreaming();
|
||||
resolve();
|
||||
@@ -477,9 +307,6 @@ export function useChatStream() {
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
cleanup();
|
||||
dispatchStreamEnd();
|
||||
stopStreaming();
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -525,10 +352,6 @@ export function useChatStream() {
|
||||
readStream();
|
||||
});
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
setIsStreaming(false);
|
||||
return Promise.resolve();
|
||||
}
|
||||
const streamError =
|
||||
err instanceof Error ? err : new Error("Failed to start stream");
|
||||
setError(streamError);
|
||||
27
autogpt_platform/frontend/src/app/(platform)/chat/page.tsx
Normal file
27
autogpt_platform/frontend/src/app/(platform)/chat/page.tsx
Normal file
@@ -0,0 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { Chat } from "./components/Chat/Chat";
|
||||
|
||||
export default function ChatPage() {
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const router = useRouter();
|
||||
|
||||
useEffect(() => {
|
||||
if (isChatEnabled === false) {
|
||||
router.push("/marketplace");
|
||||
}
|
||||
}, [isChatEnabled, router]);
|
||||
|
||||
if (isChatEnabled === null || isChatEnabled === false) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
<Chat className="flex-1" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { createContext, useContext, useRef, type ReactNode } from "react";
|
||||
|
||||
interface NewChatContextValue {
|
||||
onNewChatClick: () => void;
|
||||
setOnNewChatClick: (handler?: () => void) => void;
|
||||
performNewChat?: () => void;
|
||||
setPerformNewChat: (handler?: () => void) => void;
|
||||
}
|
||||
|
||||
const NewChatContext = createContext<NewChatContextValue | null>(null);
|
||||
|
||||
export function NewChatProvider({ children }: { children: ReactNode }) {
|
||||
const onNewChatRef = useRef<(() => void) | undefined>();
|
||||
const performNewChatRef = useRef<(() => void) | undefined>();
|
||||
const contextValueRef = useRef<NewChatContextValue>({
|
||||
onNewChatClick() {
|
||||
onNewChatRef.current?.();
|
||||
},
|
||||
setOnNewChatClick(handler?: () => void) {
|
||||
onNewChatRef.current = handler;
|
||||
},
|
||||
performNewChat() {
|
||||
performNewChatRef.current?.();
|
||||
},
|
||||
setPerformNewChat(handler?: () => void) {
|
||||
performNewChatRef.current = handler;
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<NewChatContext.Provider value={contextValueRef.current}>
|
||||
{children}
|
||||
</NewChatContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useNewChat() {
|
||||
return useContext(NewChatContext);
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
||||
import type { ReactNode } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { useNewChat } from "../../NewChatContext";
|
||||
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
|
||||
import { LoadingState } from "./components/LoadingState/LoadingState";
|
||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
||||
import { useCopilotShell } from "./useCopilotShell";
|
||||
|
||||
interface Props {
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export function CopilotShell({ children }: Props) {
|
||||
const {
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
isLoading,
|
||||
isLoggedIn,
|
||||
hasActiveSession,
|
||||
sessions,
|
||||
currentSessionId,
|
||||
handleSelectSession,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
handleNewChat,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
fetchNextPage,
|
||||
isReadyToShowContent,
|
||||
} = useCopilotShell();
|
||||
|
||||
const newChatContext = useNewChat();
|
||||
const handleNewChatClickWrapper =
|
||||
newChatContext?.onNewChatClick || handleNewChat;
|
||||
|
||||
useEffect(
|
||||
function registerNewChatHandler() {
|
||||
if (!newChatContext) return;
|
||||
newChatContext.setPerformNewChat(handleNewChat);
|
||||
return function cleanup() {
|
||||
newChatContext.setPerformNewChat(undefined);
|
||||
};
|
||||
},
|
||||
[newChatContext, handleNewChat],
|
||||
);
|
||||
|
||||
if (!isLoggedIn) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<ChatLoader />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex overflow-hidden bg-[#EFEFF0]"
|
||||
style={{ height: `calc(100vh - ${NAVBAR_HEIGHT_PX}px)` }}
|
||||
>
|
||||
{!isMobile && (
|
||||
<DesktopSidebar
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={handleSelectSession}
|
||||
onFetchNextPage={fetchNextPage}
|
||||
onNewChat={handleNewChatClickWrapper}
|
||||
hasActiveSession={Boolean(hasActiveSession)}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="relative flex min-h-0 flex-1 flex-col">
|
||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||
<div className="flex min-h-0 flex-1 flex-col">
|
||||
{isReadyToShowContent ? children : <LoadingState />}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isMobile && (
|
||||
<MobileDrawer
|
||||
isOpen={isDrawerOpen}
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={handleSelectSession}
|
||||
onFetchNextPage={fetchNextPage}
|
||||
onNewChat={handleNewChatClickWrapper}
|
||||
onClose={handleCloseDrawer}
|
||||
onOpenChange={handleDrawerOpenChange}
|
||||
hasActiveSession={Boolean(hasActiveSession)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Plus } from "@phosphor-icons/react";
|
||||
import { SessionsList } from "../SessionsList/SessionsList";
|
||||
|
||||
interface Props {
|
||||
sessions: SessionSummaryResponse[];
|
||||
currentSessionId: string | null;
|
||||
isLoading: boolean;
|
||||
hasNextPage: boolean;
|
||||
isFetchingNextPage: boolean;
|
||||
onSelectSession: (sessionId: string) => void;
|
||||
onFetchNextPage: () => void;
|
||||
onNewChat: () => void;
|
||||
hasActiveSession: boolean;
|
||||
}
|
||||
|
||||
export function DesktopSidebar({
|
||||
sessions,
|
||||
currentSessionId,
|
||||
isLoading,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
onSelectSession,
|
||||
onFetchNextPage,
|
||||
onNewChat,
|
||||
hasActiveSession,
|
||||
}: Props) {
|
||||
return (
|
||||
<aside className="flex h-full w-80 flex-col border-r border-zinc-100 bg-zinc-50">
|
||||
<div className="shrink-0 px-6 py-4">
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
|
||||
scrollbarStyles,
|
||||
)}
|
||||
>
|
||||
<SessionsList
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={onSelectSession}
|
||||
onFetchNextPage={onFetchNextPage}
|
||||
/>
|
||||
</div>
|
||||
{hasActiveSession && (
|
||||
<div className="shrink-0 bg-zinc-50 p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<Plus width="1rem" height="1rem" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</aside>
|
||||
);
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
||||
|
||||
export function LoadingState() {
|
||||
return (
|
||||
<div className="flex flex-1 items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
<ChatLoader />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
Loading your chats...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { PlusIcon, X } from "@phosphor-icons/react";
|
||||
import { Drawer } from "vaul";
|
||||
import { SessionsList } from "../SessionsList/SessionsList";
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
sessions: SessionSummaryResponse[];
|
||||
currentSessionId: string | null;
|
||||
isLoading: boolean;
|
||||
hasNextPage: boolean;
|
||||
isFetchingNextPage: boolean;
|
||||
onSelectSession: (sessionId: string) => void;
|
||||
onFetchNextPage: () => void;
|
||||
onNewChat: () => void;
|
||||
onClose: () => void;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
hasActiveSession: boolean;
|
||||
}
|
||||
|
||||
export function MobileDrawer({
|
||||
isOpen,
|
||||
sessions,
|
||||
currentSessionId,
|
||||
isLoading,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
onSelectSession,
|
||||
onFetchNextPage,
|
||||
onNewChat,
|
||||
onClose,
|
||||
onOpenChange,
|
||||
hasActiveSession,
|
||||
}: Props) {
|
||||
return (
|
||||
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
|
||||
<Drawer.Portal>
|
||||
<Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
|
||||
<Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
|
||||
<div className="shrink-0 border-b border-zinc-200 p-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<Drawer.Title className="text-lg font-semibold text-zinc-800">
|
||||
Your chats
|
||||
</Drawer.Title>
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Close sessions"
|
||||
onClick={onClose}
|
||||
>
|
||||
<X width="1.25rem" height="1.25rem" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
|
||||
scrollbarStyles,
|
||||
)}
|
||||
>
|
||||
<SessionsList
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={onSelectSession}
|
||||
onFetchNextPage={onFetchNextPage}
|
||||
/>
|
||||
</div>
|
||||
{hasActiveSession && (
|
||||
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon width="1rem" height="1rem" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</Drawer.Content>
|
||||
</Drawer.Portal>
|
||||
</Drawer.Root>
|
||||
);
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
import { useState } from "react";
|
||||
|
||||
export function useMobileDrawer() {
|
||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
||||
|
||||
function handleOpenDrawer() {
|
||||
setIsDrawerOpen(true);
|
||||
}
|
||||
|
||||
function handleCloseDrawer() {
|
||||
setIsDrawerOpen(false);
|
||||
}
|
||||
|
||||
function handleDrawerOpenChange(open: boolean) {
|
||||
setIsDrawerOpen(open);
|
||||
}
|
||||
|
||||
return {
|
||||
isDrawerOpen,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
};
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
||||
import { ListIcon } from "@phosphor-icons/react";
|
||||
|
||||
interface Props {
|
||||
onOpenDrawer: () => void;
|
||||
}
|
||||
|
||||
export function MobileHeader({ onOpenDrawer }: Props) {
|
||||
return (
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Open sessions"
|
||||
onClick={onOpenDrawer}
|
||||
className="fixed z-50 bg-white shadow-md"
|
||||
style={{ left: "1rem", top: `${NAVBAR_HEIGHT_PX + 20}px` }}
|
||||
>
|
||||
<ListIcon width="1.25rem" height="1.25rem" />
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { InfiniteList } from "@/components/molecules/InfiniteList/InfiniteList";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { getSessionTitle } from "../../helpers";
|
||||
|
||||
interface Props {
|
||||
sessions: SessionSummaryResponse[];
|
||||
currentSessionId: string | null;
|
||||
isLoading: boolean;
|
||||
hasNextPage: boolean;
|
||||
isFetchingNextPage: boolean;
|
||||
onSelectSession: (sessionId: string) => void;
|
||||
onFetchNextPage: () => void;
|
||||
}
|
||||
|
||||
export function SessionsList({
|
||||
sessions,
|
||||
currentSessionId,
|
||||
isLoading,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
onSelectSession,
|
||||
onFetchNextPage,
|
||||
}: Props) {
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{Array.from({ length: 5 }).map((_, i) => (
|
||||
<div key={i} className="rounded-lg px-3 py-2.5">
|
||||
<Skeleton className="h-5 w-full" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (sessions.length === 0) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
You don't have previous chats
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<InfiniteList
|
||||
items={sessions}
|
||||
hasMore={hasNextPage}
|
||||
isFetchingMore={isFetchingNextPage}
|
||||
onEndReached={onFetchNextPage}
|
||||
className="space-y-1"
|
||||
renderItem={(session) => {
|
||||
const isActive = session.id === currentSessionId;
|
||||
return (
|
||||
<button
|
||||
onClick={() => onSelectSession(session.id)}
|
||||
className={cn(
|
||||
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
||||
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"font-normal",
|
||||
isActive ? "text-zinc-600" : "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
{getSessionTitle(session)}
|
||||
</Text>
|
||||
</button>
|
||||
);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
const PAGE_SIZE = 50;
|
||||
|
||||
export interface UseSessionsPaginationArgs {
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
||||
const [offset, setOffset] = useState(0);
|
||||
const [accumulatedSessions, setAccumulatedSessions] = useState<
|
||||
SessionSummaryResponse[]
|
||||
>([]);
|
||||
const [totalCount, setTotalCount] = useState<number | null>(null);
|
||||
|
||||
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
|
||||
{ limit: PAGE_SIZE, offset },
|
||||
{
|
||||
query: {
|
||||
enabled: enabled && offset >= 0,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const responseData = okData(data);
|
||||
if (responseData) {
|
||||
const newSessions = responseData.sessions;
|
||||
const total = responseData.total;
|
||||
setTotalCount(total);
|
||||
|
||||
if (offset === 0) {
|
||||
setAccumulatedSessions(newSessions);
|
||||
} else {
|
||||
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
|
||||
}
|
||||
} else if (!enabled) {
|
||||
setAccumulatedSessions([]);
|
||||
setTotalCount(null);
|
||||
}
|
||||
}, [data, offset, enabled]);
|
||||
|
||||
const hasNextPage = useMemo(() => {
|
||||
if (totalCount === null) return false;
|
||||
return accumulatedSessions.length < totalCount;
|
||||
}, [accumulatedSessions.length, totalCount]);
|
||||
|
||||
const areAllSessionsLoaded = useMemo(() => {
|
||||
if (totalCount === null) return false;
|
||||
return (
|
||||
accumulatedSessions.length >= totalCount && !isFetching && !isLoading
|
||||
);
|
||||
}, [accumulatedSessions.length, totalCount, isFetching, isLoading]);
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
hasNextPage &&
|
||||
!isFetching &&
|
||||
!isLoading &&
|
||||
!isError &&
|
||||
totalCount !== null
|
||||
) {
|
||||
setOffset((prev) => prev + PAGE_SIZE);
|
||||
}
|
||||
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
|
||||
|
||||
function fetchNextPage() {
|
||||
if (hasNextPage && !isFetching) {
|
||||
setOffset((prev) => prev + PAGE_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
function reset() {
|
||||
setOffset(0);
|
||||
setAccumulatedSessions([]);
|
||||
setTotalCount(null);
|
||||
}
|
||||
|
||||
return {
|
||||
sessions: accumulatedSessions,
|
||||
isLoading,
|
||||
isFetching,
|
||||
hasNextPage,
|
||||
areAllSessionsLoaded,
|
||||
totalCount,
|
||||
fetchNextPage,
|
||||
reset,
|
||||
};
|
||||
}
|
||||
@@ -1,165 +0,0 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { format, formatDistanceToNow, isToday } from "date-fns";
|
||||
|
||||
export function convertSessionDetailToSummary(
|
||||
session: SessionDetailResponse,
|
||||
): SessionSummaryResponse {
|
||||
return {
|
||||
id: session.id,
|
||||
created_at: session.created_at,
|
||||
updated_at: session.updated_at,
|
||||
title: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
export function filterVisibleSessions(
|
||||
sessions: SessionSummaryResponse[],
|
||||
): SessionSummaryResponse[] {
|
||||
return sessions.filter(
|
||||
(session) => session.updated_at !== session.created_at,
|
||||
);
|
||||
}
|
||||
|
||||
export function getSessionTitle(session: SessionSummaryResponse): string {
|
||||
if (session.title) return session.title;
|
||||
const isNewSession = session.updated_at === session.created_at;
|
||||
if (isNewSession) {
|
||||
const createdDate = new Date(session.created_at);
|
||||
if (isToday(createdDate)) {
|
||||
return "Today";
|
||||
}
|
||||
return format(createdDate, "MMM d, yyyy");
|
||||
}
|
||||
return "Untitled Chat";
|
||||
}
|
||||
|
||||
export function getSessionUpdatedLabel(
|
||||
session: SessionSummaryResponse,
|
||||
): string {
|
||||
if (!session.updated_at) return "";
|
||||
return formatDistanceToNow(new Date(session.updated_at), { addSuffix: true });
|
||||
}
|
||||
|
||||
export function mergeCurrentSessionIntoList(
|
||||
accumulatedSessions: SessionSummaryResponse[],
|
||||
currentSessionId: string | null,
|
||||
currentSessionData: SessionDetailResponse | null | undefined,
|
||||
): SessionSummaryResponse[] {
|
||||
const filteredSessions: SessionSummaryResponse[] = [];
|
||||
|
||||
if (accumulatedSessions.length > 0) {
|
||||
const visibleSessions = filterVisibleSessions(accumulatedSessions);
|
||||
|
||||
if (currentSessionId) {
|
||||
const currentInAll = accumulatedSessions.find(
|
||||
(s) => s.id === currentSessionId,
|
||||
);
|
||||
if (currentInAll) {
|
||||
const isInVisible = visibleSessions.some(
|
||||
(s) => s.id === currentSessionId,
|
||||
);
|
||||
if (!isInVisible) {
|
||||
filteredSessions.push(currentInAll);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filteredSessions.push(...visibleSessions);
|
||||
}
|
||||
|
||||
if (currentSessionId && currentSessionData) {
|
||||
const isCurrentInList = filteredSessions.some(
|
||||
(s) => s.id === currentSessionId,
|
||||
);
|
||||
if (!isCurrentInList) {
|
||||
const summarySession = convertSessionDetailToSummary(currentSessionData);
|
||||
filteredSessions.unshift(summarySession);
|
||||
}
|
||||
}
|
||||
|
||||
return filteredSessions;
|
||||
}
|
||||
|
||||
export function getCurrentSessionId(
|
||||
searchParams: URLSearchParams,
|
||||
): string | null {
|
||||
return searchParams.get("sessionId");
|
||||
}
|
||||
|
||||
export function shouldAutoSelectSession(
|
||||
areAllSessionsLoaded: boolean,
|
||||
hasAutoSelectedSession: boolean,
|
||||
paramSessionId: string | null,
|
||||
visibleSessions: SessionSummaryResponse[],
|
||||
accumulatedSessions: SessionSummaryResponse[],
|
||||
isLoading: boolean,
|
||||
totalCount: number | null,
|
||||
): {
|
||||
shouldSelect: boolean;
|
||||
sessionIdToSelect: string | null;
|
||||
shouldCreate: boolean;
|
||||
} {
|
||||
if (!areAllSessionsLoaded || hasAutoSelectedSession) {
|
||||
return {
|
||||
shouldSelect: false,
|
||||
sessionIdToSelect: null,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (paramSessionId) {
|
||||
return {
|
||||
shouldSelect: false,
|
||||
sessionIdToSelect: null,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (visibleSessions.length > 0) {
|
||||
return {
|
||||
shouldSelect: true,
|
||||
sessionIdToSelect: visibleSessions[0].id,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
|
||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
|
||||
}
|
||||
|
||||
if (totalCount === 0) {
|
||||
return {
|
||||
shouldSelect: false,
|
||||
sessionIdToSelect: null,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
|
||||
}
|
||||
|
||||
export function checkReadyToShowContent(
|
||||
areAllSessionsLoaded: boolean,
|
||||
paramSessionId: string | null,
|
||||
accumulatedSessions: SessionSummaryResponse[],
|
||||
isCurrentSessionLoading: boolean,
|
||||
currentSessionData: SessionDetailResponse | null | undefined,
|
||||
hasAutoSelectedSession: boolean,
|
||||
): boolean {
|
||||
if (!areAllSessionsLoaded) return false;
|
||||
|
||||
if (paramSessionId) {
|
||||
const sessionFound = accumulatedSessions.some(
|
||||
(s) => s.id === paramSessionId,
|
||||
);
|
||||
return (
|
||||
sessionFound ||
|
||||
(!isCurrentSessionLoading &&
|
||||
currentSessionData !== undefined &&
|
||||
currentSessionData !== null)
|
||||
);
|
||||
}
|
||||
|
||||
return hasAutoSelectedSession;
|
||||
}
|
||||
@@ -1,172 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useGetV2GetSession,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
||||
import {
|
||||
checkReadyToShowContent,
|
||||
filterVisibleSessions,
|
||||
getCurrentSessionId,
|
||||
mergeCurrentSessionIntoList,
|
||||
} from "./helpers";
|
||||
|
||||
export function useCopilotShell() {
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
const searchParams = useSearchParams();
|
||||
const queryClient = useQueryClient();
|
||||
const breakpoint = useBreakpoint();
|
||||
const { isLoggedIn } = useSupabase();
|
||||
const isMobile =
|
||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||
|
||||
const isOnHomepage = pathname === "/copilot";
|
||||
const paramSessionId = searchParams.get("sessionId");
|
||||
|
||||
const {
|
||||
isDrawerOpen,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
} = useMobileDrawer();
|
||||
|
||||
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
|
||||
|
||||
const {
|
||||
sessions: accumulatedSessions,
|
||||
isLoading: isSessionsLoading,
|
||||
isFetching: isSessionsFetching,
|
||||
hasNextPage,
|
||||
areAllSessionsLoaded,
|
||||
fetchNextPage,
|
||||
reset: resetPagination,
|
||||
} = useSessionsPagination({
|
||||
enabled: paginationEnabled,
|
||||
});
|
||||
|
||||
const currentSessionId = getCurrentSessionId(searchParams);
|
||||
|
||||
const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
|
||||
useGetV2GetSession(currentSessionId || "", {
|
||||
query: {
|
||||
enabled: !!currentSessionId,
|
||||
select: okData,
|
||||
},
|
||||
});
|
||||
|
||||
const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
|
||||
const hasAutoSelectedRef = useRef(false);
|
||||
|
||||
// Mark as auto-selected when sessionId is in URL
|
||||
useEffect(() => {
|
||||
if (paramSessionId && !hasAutoSelectedRef.current) {
|
||||
hasAutoSelectedRef.current = true;
|
||||
setHasAutoSelectedSession(true);
|
||||
}
|
||||
}, [paramSessionId]);
|
||||
|
||||
// On homepage without sessionId, mark as ready immediately
|
||||
useEffect(() => {
|
||||
if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
|
||||
hasAutoSelectedRef.current = true;
|
||||
setHasAutoSelectedSession(true);
|
||||
}
|
||||
}, [isOnHomepage, paramSessionId]);
|
||||
|
||||
// Invalidate sessions list when navigating to homepage (to show newly created sessions)
|
||||
useEffect(() => {
|
||||
if (isOnHomepage && !paramSessionId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
}
|
||||
}, [isOnHomepage, paramSessionId, queryClient]);
|
||||
|
||||
// Reset pagination when query becomes disabled
|
||||
const prevPaginationEnabledRef = useRef(paginationEnabled);
|
||||
useEffect(() => {
|
||||
if (prevPaginationEnabledRef.current && !paginationEnabled) {
|
||||
resetPagination();
|
||||
resetAutoSelect();
|
||||
}
|
||||
prevPaginationEnabledRef.current = paginationEnabled;
|
||||
}, [paginationEnabled, resetPagination]);
|
||||
|
||||
const sessions = mergeCurrentSessionIntoList(
|
||||
accumulatedSessions,
|
||||
currentSessionId,
|
||||
currentSessionData,
|
||||
);
|
||||
|
||||
const visibleSessions = filterVisibleSessions(sessions);
|
||||
|
||||
const sidebarSelectedSessionId =
|
||||
isOnHomepage && !paramSessionId ? null : currentSessionId;
|
||||
|
||||
const isReadyToShowContent = isOnHomepage
|
||||
? true
|
||||
: checkReadyToShowContent(
|
||||
areAllSessionsLoaded,
|
||||
paramSessionId,
|
||||
accumulatedSessions,
|
||||
isCurrentSessionLoading,
|
||||
currentSessionData,
|
||||
hasAutoSelectedSession,
|
||||
);
|
||||
|
||||
function handleSelectSession(sessionId: string) {
|
||||
// Navigate using replaceState to avoid full page reload
|
||||
window.history.replaceState(null, "", `/copilot?sessionId=${sessionId}`);
|
||||
// Force a re-render by updating the URL through router
|
||||
router.replace(`/copilot?sessionId=${sessionId}`);
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleNewChat() {
|
||||
resetAutoSelect();
|
||||
resetPagination();
|
||||
// Invalidate and refetch sessions list to ensure newly created sessions appear
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
window.history.replaceState(null, "", "/copilot");
|
||||
router.replace("/copilot");
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function resetAutoSelect() {
|
||||
hasAutoSelectedRef.current = false;
|
||||
setHasAutoSelectedSession(false);
|
||||
}
|
||||
|
||||
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
|
||||
|
||||
return {
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
isLoggedIn,
|
||||
hasActiveSession:
|
||||
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
|
||||
isLoading,
|
||||
sessions: visibleSessions,
|
||||
currentSessionId: sidebarSelectedSessionId,
|
||||
handleSelectSession,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
handleNewChat,
|
||||
hasNextPage,
|
||||
isFetchingNextPage: isSessionsFetching,
|
||||
fetchNextPage,
|
||||
isReadyToShowContent,
|
||||
};
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
import type { User } from "@supabase/supabase-js";
|
||||
|
||||
export type PageState =
|
||||
| { type: "welcome" }
|
||||
| { type: "newChat" }
|
||||
| { type: "creating"; prompt: string }
|
||||
| { type: "chat"; sessionId: string; initialPrompt?: string };
|
||||
|
||||
export function getInitialPromptFromState(
|
||||
pageState: PageState,
|
||||
storedInitialPrompt: string | undefined,
|
||||
) {
|
||||
if (storedInitialPrompt) return storedInitialPrompt;
|
||||
if (pageState.type === "creating") return pageState.prompt;
|
||||
if (pageState.type === "chat") return pageState.initialPrompt;
|
||||
}
|
||||
|
||||
export function shouldResetToWelcome(pageState: PageState) {
|
||||
return (
|
||||
pageState.type !== "newChat" &&
|
||||
pageState.type !== "creating" &&
|
||||
pageState.type !== "welcome"
|
||||
);
|
||||
}
|
||||
|
||||
export function getGreetingName(user?: User | null): string {
|
||||
if (!user) return "there";
|
||||
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
||||
const fullName = metadata?.full_name;
|
||||
const name = metadata?.name;
|
||||
if (typeof fullName === "string" && fullName.trim()) {
|
||||
return fullName.split(" ")[0];
|
||||
}
|
||||
if (typeof name === "string" && name.trim()) {
|
||||
return name.split(" ")[0];
|
||||
}
|
||||
if (user.email) {
|
||||
return user.email.split("@")[0];
|
||||
}
|
||||
return "there";
|
||||
}
|
||||
|
||||
export function buildCopilotChatUrl(prompt: string): string {
|
||||
const trimmed = prompt.trim();
|
||||
if (!trimmed) return "/copilot/chat";
|
||||
const encoded = encodeURIComponent(trimmed);
|
||||
return `/copilot/chat?prompt=${encoded}`;
|
||||
}
|
||||
|
||||
export function getQuickActions(): string[] {
|
||||
return [
|
||||
"Show me what I can automate",
|
||||
"Design a custom workflow",
|
||||
"Help me with content creation",
|
||||
];
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
import type { ReactNode } from "react";
|
||||
import { NewChatProvider } from "./NewChatContext";
|
||||
import { CopilotShell } from "./components/CopilotShell/CopilotShell";
|
||||
|
||||
export default function CopilotLayout({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<NewChatProvider>
|
||||
<CopilotShell>{children}</CopilotShell>
|
||||
</NewChatProvider>
|
||||
);
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
export default function CopilotPage() {
|
||||
const { state, handlers } = useCopilotPage();
|
||||
const {
|
||||
greetingName,
|
||||
quickActions,
|
||||
isLoading,
|
||||
pageState,
|
||||
isNewChatModalOpen,
|
||||
isReady,
|
||||
} = state;
|
||||
const {
|
||||
handleQuickAction,
|
||||
startChatWithPrompt,
|
||||
handleSessionNotFound,
|
||||
handleStreamingChange,
|
||||
handleCancelNewChat,
|
||||
proceedWithNewChat,
|
||||
handleNewChatModalOpen,
|
||||
} = handlers;
|
||||
|
||||
if (!isReady) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Show Chat when we have an active session
|
||||
if (pageState.type === "chat") {
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
<Chat
|
||||
key={pageState.sessionId ?? "welcome"}
|
||||
className="flex-1"
|
||||
urlSessionId={pageState.sessionId}
|
||||
initialPrompt={pageState.initialPrompt}
|
||||
onSessionNotFound={handleSessionNotFound}
|
||||
onStreamingChange={handleStreamingChange}
|
||||
/>
|
||||
<Dialog
|
||||
title="Interrupt current chat?"
|
||||
styling={{ maxWidth: 300, width: "100%" }}
|
||||
controlled={{
|
||||
isOpen: isNewChatModalOpen,
|
||||
set: handleNewChatModalOpen,
|
||||
}}
|
||||
onClose={handleCancelNewChat}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text variant="body">
|
||||
The current chat response will be interrupted. Are you sure you
|
||||
want to start a new chat?
|
||||
</Text>
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={handleCancelNewChat}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
onClick={proceedWithNewChat}
|
||||
>
|
||||
Start new chat
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (pageState.type === "newChat") {
|
||||
return (
|
||||
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
<ChatLoader />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
Loading your chats...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show loading state while creating session and sending first message
|
||||
if (pageState.type === "creating") {
|
||||
return (
|
||||
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
<ChatLoader />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
Loading your chats...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show Welcome screen
|
||||
return (
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
||||
<div className="w-full text-center">
|
||||
{isLoading ? (
|
||||
<div className="mx-auto max-w-2xl">
|
||||
<Skeleton className="mx-auto mb-3 h-8 w-64" />
|
||||
<Skeleton className="mx-auto mb-8 h-6 w-80" />
|
||||
<div className="mb-8">
|
||||
<Skeleton className="mx-auto h-14 w-full rounded-lg" />
|
||||
</div>
|
||||
<div className="flex flex-wrap items-center justify-center gap-3">
|
||||
{Array.from({ length: 4 }).map((_, i) => (
|
||||
<Skeleton key={i} className="h-9 w-48 rounded-md" />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="mx-auto max-w-2xl">
|
||||
<Text
|
||||
variant="h3"
|
||||
className="mb-3 !text-[1.375rem] text-zinc-700"
|
||||
>
|
||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||
</Text>
|
||||
<Text variant="h3" className="mb-8 !font-normal">
|
||||
What do you want to automate?
|
||||
</Text>
|
||||
|
||||
<div className="mb-6">
|
||||
<ChatInput
|
||||
onSend={startChatWithPrompt}
|
||||
placeholder='You can search or just ask - e.g. "create a blog post outline"'
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{quickActions.map((action) => (
|
||||
<Button
|
||||
key={action}
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => handleQuickAction(action)}
|
||||
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
|
||||
>
|
||||
{action}
|
||||
</Button>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,266 +0,0 @@
|
||||
import { postV2CreateSession } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import {
|
||||
Flag,
|
||||
type FlagValues,
|
||||
useGetFlag,
|
||||
} from "@/services/feature-flags/use-get-flag";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { useFlags } from "launchdarkly-react-client-sdk";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useReducer } from "react";
|
||||
import { useNewChat } from "./NewChatContext";
|
||||
import { getGreetingName, getQuickActions, type PageState } from "./helpers";
|
||||
import { useCopilotURLState } from "./useCopilotURLState";
|
||||
|
||||
type CopilotState = {
|
||||
pageState: PageState;
|
||||
isStreaming: boolean;
|
||||
isNewChatModalOpen: boolean;
|
||||
initialPrompts: Record<string, string>;
|
||||
previousSessionId: string | null;
|
||||
};
|
||||
|
||||
type CopilotAction =
|
||||
| { type: "setPageState"; pageState: PageState }
|
||||
| { type: "setStreaming"; isStreaming: boolean }
|
||||
| { type: "setNewChatModalOpen"; isOpen: boolean }
|
||||
| { type: "setInitialPrompt"; sessionId: string; prompt: string }
|
||||
| { type: "setPreviousSessionId"; sessionId: string | null };
|
||||
|
||||
function isSamePageState(next: PageState, current: PageState) {
|
||||
if (next.type !== current.type) return false;
|
||||
if (next.type === "creating" && current.type === "creating") {
|
||||
return next.prompt === current.prompt;
|
||||
}
|
||||
if (next.type === "chat" && current.type === "chat") {
|
||||
return (
|
||||
next.sessionId === current.sessionId &&
|
||||
next.initialPrompt === current.initialPrompt
|
||||
);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
function copilotReducer(
|
||||
state: CopilotState,
|
||||
action: CopilotAction,
|
||||
): CopilotState {
|
||||
if (action.type === "setPageState") {
|
||||
if (isSamePageState(action.pageState, state.pageState)) return state;
|
||||
return { ...state, pageState: action.pageState };
|
||||
}
|
||||
if (action.type === "setStreaming") {
|
||||
if (action.isStreaming === state.isStreaming) return state;
|
||||
return { ...state, isStreaming: action.isStreaming };
|
||||
}
|
||||
if (action.type === "setNewChatModalOpen") {
|
||||
if (action.isOpen === state.isNewChatModalOpen) return state;
|
||||
return { ...state, isNewChatModalOpen: action.isOpen };
|
||||
}
|
||||
if (action.type === "setInitialPrompt") {
|
||||
if (state.initialPrompts[action.sessionId] === action.prompt) return state;
|
||||
return {
|
||||
...state,
|
||||
initialPrompts: {
|
||||
...state.initialPrompts,
|
||||
[action.sessionId]: action.prompt,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (action.type === "setPreviousSessionId") {
|
||||
if (state.previousSessionId === action.sessionId) return state;
|
||||
return { ...state, previousSessionId: action.sessionId };
|
||||
}
|
||||
return state;
|
||||
}
|
||||
|
||||
export function useCopilotPage() {
|
||||
const router = useRouter();
|
||||
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const flags = useFlags<FlagValues>();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
||||
const isFlagReady =
|
||||
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
||||
|
||||
const [state, dispatch] = useReducer(copilotReducer, {
|
||||
pageState: { type: "welcome" },
|
||||
isStreaming: false,
|
||||
isNewChatModalOpen: false,
|
||||
initialPrompts: {},
|
||||
previousSessionId: null,
|
||||
});
|
||||
|
||||
const newChatContext = useNewChat();
|
||||
const greetingName = getGreetingName(user);
|
||||
const quickActions = getQuickActions();
|
||||
|
||||
function setPageState(pageState: PageState) {
|
||||
dispatch({ type: "setPageState", pageState });
|
||||
}
|
||||
|
||||
function setInitialPrompt(sessionId: string, prompt: string) {
|
||||
dispatch({ type: "setInitialPrompt", sessionId, prompt });
|
||||
}
|
||||
|
||||
function setPreviousSessionId(sessionId: string | null) {
|
||||
dispatch({ type: "setPreviousSessionId", sessionId });
|
||||
}
|
||||
|
||||
const { setUrlSessionId } = useCopilotURLState({
|
||||
pageState: state.pageState,
|
||||
initialPrompts: state.initialPrompts,
|
||||
previousSessionId: state.previousSessionId,
|
||||
setPageState,
|
||||
setInitialPrompt,
|
||||
setPreviousSessionId,
|
||||
});
|
||||
|
||||
useEffect(
|
||||
function registerNewChatHandler() {
|
||||
if (!newChatContext) return;
|
||||
newChatContext.setOnNewChatClick(handleNewChatClick);
|
||||
return function cleanup() {
|
||||
newChatContext.setOnNewChatClick(undefined);
|
||||
};
|
||||
},
|
||||
[newChatContext, handleNewChatClick],
|
||||
);
|
||||
|
||||
useEffect(
|
||||
function transitionNewChatToWelcome() {
|
||||
if (state.pageState.type === "newChat") {
|
||||
function setWelcomeState() {
|
||||
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
|
||||
}
|
||||
|
||||
const timer = setTimeout(setWelcomeState, 300);
|
||||
|
||||
return function cleanup() {
|
||||
clearTimeout(timer);
|
||||
};
|
||||
}
|
||||
},
|
||||
[state.pageState.type],
|
||||
);
|
||||
|
||||
useEffect(
|
||||
function ensureAccess() {
|
||||
if (!isFlagReady) return;
|
||||
if (isChatEnabled === false) {
|
||||
router.replace(homepageRoute);
|
||||
}
|
||||
},
|
||||
[homepageRoute, isChatEnabled, isFlagReady, router],
|
||||
);
|
||||
|
||||
async function startChatWithPrompt(prompt: string) {
|
||||
if (!prompt?.trim()) return;
|
||||
if (state.pageState.type === "creating") return;
|
||||
|
||||
const trimmedPrompt = prompt.trim();
|
||||
dispatch({
|
||||
type: "setPageState",
|
||||
pageState: { type: "creating", prompt: trimmedPrompt },
|
||||
});
|
||||
|
||||
try {
|
||||
const sessionResponse = await postV2CreateSession({
|
||||
body: JSON.stringify({}),
|
||||
});
|
||||
|
||||
if (sessionResponse.status !== 200 || !sessionResponse.data?.id) {
|
||||
throw new Error("Failed to create session");
|
||||
}
|
||||
|
||||
const sessionId = sessionResponse.data.id;
|
||||
|
||||
dispatch({
|
||||
type: "setInitialPrompt",
|
||||
sessionId,
|
||||
prompt: trimmedPrompt,
|
||||
});
|
||||
|
||||
await setUrlSessionId(sessionId, { shallow: false });
|
||||
dispatch({
|
||||
type: "setPageState",
|
||||
pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt },
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("[CopilotPage] Failed to start chat:", error);
|
||||
toast({ title: "Failed to start chat", variant: "destructive" });
|
||||
Sentry.captureException(error);
|
||||
dispatch({ type: "setPageState", pageState: { type: "welcome" } });
|
||||
}
|
||||
}
|
||||
|
||||
function handleQuickAction(action: string) {
|
||||
startChatWithPrompt(action);
|
||||
}
|
||||
|
||||
function handleSessionNotFound() {
|
||||
router.replace("/copilot");
|
||||
}
|
||||
|
||||
function handleStreamingChange(isStreamingValue: boolean) {
|
||||
dispatch({ type: "setStreaming", isStreaming: isStreamingValue });
|
||||
}
|
||||
|
||||
async function proceedWithNewChat() {
|
||||
dispatch({ type: "setNewChatModalOpen", isOpen: false });
|
||||
if (newChatContext?.performNewChat) {
|
||||
newChatContext.performNewChat();
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await setUrlSessionId(null, { shallow: false });
|
||||
} catch (error) {
|
||||
console.error("[CopilotPage] Failed to clear session:", error);
|
||||
}
|
||||
router.replace("/copilot");
|
||||
}
|
||||
|
||||
function handleCancelNewChat() {
|
||||
dispatch({ type: "setNewChatModalOpen", isOpen: false });
|
||||
}
|
||||
|
||||
function handleNewChatModalOpen(isOpen: boolean) {
|
||||
dispatch({ type: "setNewChatModalOpen", isOpen });
|
||||
}
|
||||
|
||||
function handleNewChatClick() {
|
||||
if (state.isStreaming) {
|
||||
dispatch({ type: "setNewChatModalOpen", isOpen: true });
|
||||
} else {
|
||||
proceedWithNewChat();
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
state: {
|
||||
greetingName,
|
||||
quickActions,
|
||||
isLoading: isUserLoading,
|
||||
pageState: state.pageState,
|
||||
isNewChatModalOpen: state.isNewChatModalOpen,
|
||||
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
|
||||
},
|
||||
handlers: {
|
||||
handleQuickAction,
|
||||
startChatWithPrompt,
|
||||
handleSessionNotFound,
|
||||
handleStreamingChange,
|
||||
handleCancelNewChat,
|
||||
proceedWithNewChat,
|
||||
handleNewChatModalOpen,
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
import { parseAsString, useQueryState } from "nuqs";
|
||||
import { useLayoutEffect } from "react";
|
||||
import {
|
||||
getInitialPromptFromState,
|
||||
type PageState,
|
||||
shouldResetToWelcome,
|
||||
} from "./helpers";
|
||||
|
||||
interface UseCopilotUrlStateArgs {
|
||||
pageState: PageState;
|
||||
initialPrompts: Record<string, string>;
|
||||
previousSessionId: string | null;
|
||||
setPageState: (pageState: PageState) => void;
|
||||
setInitialPrompt: (sessionId: string, prompt: string) => void;
|
||||
setPreviousSessionId: (sessionId: string | null) => void;
|
||||
}
|
||||
|
||||
export function useCopilotURLState({
|
||||
pageState,
|
||||
initialPrompts,
|
||||
previousSessionId,
|
||||
setPageState,
|
||||
setInitialPrompt,
|
||||
setPreviousSessionId,
|
||||
}: UseCopilotUrlStateArgs) {
|
||||
const [urlSessionId, setUrlSessionId] = useQueryState(
|
||||
"sessionId",
|
||||
parseAsString,
|
||||
);
|
||||
|
||||
function syncSessionFromUrl() {
|
||||
if (urlSessionId) {
|
||||
if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
|
||||
setPreviousSessionId(urlSessionId);
|
||||
return;
|
||||
}
|
||||
|
||||
const storedInitialPrompt = initialPrompts[urlSessionId];
|
||||
const currentInitialPrompt = getInitialPromptFromState(
|
||||
pageState,
|
||||
storedInitialPrompt,
|
||||
);
|
||||
|
||||
if (currentInitialPrompt) {
|
||||
setInitialPrompt(urlSessionId, currentInitialPrompt);
|
||||
}
|
||||
|
||||
setPageState({
|
||||
type: "chat",
|
||||
sessionId: urlSessionId,
|
||||
initialPrompt: currentInitialPrompt,
|
||||
});
|
||||
setPreviousSessionId(urlSessionId);
|
||||
return;
|
||||
}
|
||||
|
||||
const wasInChat = previousSessionId !== null && pageState.type === "chat";
|
||||
setPreviousSessionId(null);
|
||||
if (wasInChat) {
|
||||
setPageState({ type: "newChat" });
|
||||
return;
|
||||
}
|
||||
|
||||
if (shouldResetToWelcome(pageState)) {
|
||||
setPageState({ type: "welcome" });
|
||||
}
|
||||
}
|
||||
|
||||
useLayoutEffect(syncSessionFromUrl, [
|
||||
urlSessionId,
|
||||
pageState.type,
|
||||
previousSessionId,
|
||||
initialPrompts,
|
||||
]);
|
||||
|
||||
return {
|
||||
urlSessionId,
|
||||
setUrlSessionId,
|
||||
};
|
||||
}
|
||||
@@ -1,8 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { Suspense } from "react";
|
||||
import { getErrorDetails } from "./helpers";
|
||||
@@ -11,8 +9,6 @@ function ErrorPageContent() {
|
||||
const searchParams = useSearchParams();
|
||||
const errorMessage = searchParams.get("message");
|
||||
const errorDetails = getErrorDetails(errorMessage);
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
function handleRetry() {
|
||||
// Auth-related errors should redirect to login
|
||||
@@ -29,8 +25,8 @@ function ErrorPageContent() {
|
||||
window.location.reload();
|
||||
}, 2000);
|
||||
} else {
|
||||
// For server/network errors, go to home
|
||||
window.location.href = homepageRoute;
|
||||
// For server/network errors, go to marketplace
|
||||
window.location.href = "/marketplace";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,10 @@ import {
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { ScheduleAgentModal } from "../ScheduleAgentModal/ScheduleAgentModal";
|
||||
import {
|
||||
AIAgentSafetyPopup,
|
||||
useAIAgentSafetyPopup,
|
||||
} from "./components/AIAgentSafetyPopup/AIAgentSafetyPopup";
|
||||
import { ModalHeader } from "./components/ModalHeader/ModalHeader";
|
||||
import { ModalRunSection } from "./components/ModalRunSection/ModalRunSection";
|
||||
import { RunActions } from "./components/RunActions/RunActions";
|
||||
@@ -83,8 +87,18 @@ export function RunAgentModal({
|
||||
|
||||
const [isScheduleModalOpen, setIsScheduleModalOpen] = useState(false);
|
||||
const [hasOverflow, setHasOverflow] = useState(false);
|
||||
const [isSafetyPopupOpen, setIsSafetyPopupOpen] = useState(false);
|
||||
const [pendingRunAction, setPendingRunAction] = useState<(() => void) | null>(
|
||||
null,
|
||||
);
|
||||
const contentRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const { shouldShowPopup, dismissPopup } = useAIAgentSafetyPopup(
|
||||
agent.id,
|
||||
agent.has_sensitive_action,
|
||||
agent.has_human_in_the_loop,
|
||||
);
|
||||
|
||||
const hasAnySetupFields =
|
||||
Object.keys(agentInputFields || {}).length > 0 ||
|
||||
Object.keys(agentCredentialsInputFields || {}).length > 0;
|
||||
@@ -165,6 +179,24 @@ export function RunAgentModal({
|
||||
onScheduleCreated?.(schedule);
|
||||
}
|
||||
|
||||
function handleRunWithSafetyCheck() {
|
||||
if (shouldShowPopup) {
|
||||
setPendingRunAction(() => handleRun);
|
||||
setIsSafetyPopupOpen(true);
|
||||
} else {
|
||||
handleRun();
|
||||
}
|
||||
}
|
||||
|
||||
function handleSafetyPopupAcknowledge() {
|
||||
setIsSafetyPopupOpen(false);
|
||||
dismissPopup();
|
||||
if (pendingRunAction) {
|
||||
pendingRunAction();
|
||||
setPendingRunAction(null);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Dialog
|
||||
@@ -180,7 +212,7 @@ export function RunAgentModal({
|
||||
|
||||
{/* Content */}
|
||||
{hasAnySetupFields ? (
|
||||
<div className="mt-4 pb-10">
|
||||
<div className="mt-10 pb-32">
|
||||
<RunAgentModalContextProvider
|
||||
value={{
|
||||
agent,
|
||||
@@ -248,7 +280,7 @@ export function RunAgentModal({
|
||||
)}
|
||||
<RunActions
|
||||
defaultRunType={defaultRunType}
|
||||
onRun={handleRun}
|
||||
onRun={handleRunWithSafetyCheck}
|
||||
isExecuting={isExecuting}
|
||||
isSettingUpTrigger={isSettingUpTrigger}
|
||||
isRunReady={allRequiredInputsAreSet}
|
||||
@@ -266,6 +298,12 @@ export function RunAgentModal({
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
|
||||
<AIAgentSafetyPopup
|
||||
agentId={agent.id}
|
||||
isOpen={isSafetyPopupOpen}
|
||||
onAcknowledge={handleSafetyPopupAcknowledge}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user