mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-23 05:57:58 -05:00
Compare commits
47 Commits
abhi/marke
...
feat/sensi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e95a9273e6 | ||
|
|
abd4a2b0b3 | ||
|
|
3a96455db5 | ||
|
|
4c68b87e00 | ||
|
|
f4296f8764 | ||
|
|
c0547bb1ed | ||
|
|
a396155d63 | ||
|
|
4c47b4df76 | ||
|
|
c46e563e6a | ||
|
|
a80f452ffe | ||
|
|
fd970c800c | ||
|
|
5fc1ec0ece | ||
|
|
9be3ec58ae | ||
|
|
e6ca904326 | ||
|
|
dbb56fa7aa | ||
|
|
0111820f61 | ||
|
|
1a1b1aa26d | ||
|
|
614ed8cf82 | ||
|
|
edd4c96aa6 | ||
|
|
cd231e2d69 | ||
|
|
399c472623 | ||
|
|
554e2beddf | ||
|
|
29fdda3fa8 | ||
|
|
67e6a8841c | ||
|
|
aea97db485 | ||
|
|
71a6969bbd | ||
|
|
e4c3f9995b | ||
|
|
3b58684abc | ||
|
|
e8d44a62fd | ||
|
|
be024da2a8 | ||
|
|
0df917e243 | ||
|
|
8688805a8c | ||
|
|
9bdda7dab0 | ||
|
|
7d377aabaa | ||
|
|
dfd7c64068 | ||
|
|
02089bc047 | ||
|
|
bed7b356bb | ||
|
|
4efc0ff502 | ||
|
|
4ad0528257 | ||
|
|
2f440ee80a | ||
|
|
2a55923ec0 | ||
|
|
ad50f57a2b | ||
|
|
aebd961ef5 | ||
|
|
bcccaa16cc | ||
|
|
d5ddc41b18 | ||
|
|
95eab5b7eb | ||
|
|
832d6e1696 |
38
.github/workflows/platform-frontend-ci.yml
vendored
38
.github/workflows/platform-frontend-ci.yml
vendored
@@ -128,7 +128,7 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
test:
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
@@ -258,39 +258,3 @@ jobs:
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Generate API client
|
||||
run: pnpm generate:api
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: pnpm test:unit
|
||||
|
||||
26
AGENTS.md
26
AGENTS.md
@@ -16,32 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
|
||||
- Format Python code with `poetry run format`.
|
||||
- Format frontend code using `pnpm format`.
|
||||
|
||||
|
||||
## Frontend guidelines:
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Add `usePageName.ts` hook for logic
|
||||
- Put sub-components in local `components/` folder
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||
- Use function declarations for components, arrow functions only for callbacks
|
||||
- No barrel files or `index.ts` re-exports
|
||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
||||
- Avoid comments at all times unless the code is very complex
|
||||
|
||||
## Testing
|
||||
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
|
||||
@@ -201,7 +201,7 @@ If you get any pushback or hit complex block conditions check the new_blocks gui
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
### Frontend guidelines:
|
||||
**Frontend feature development:**
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
@@ -217,14 +217,6 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||
- Use function declarations for components, arrow functions only for callbacks
|
||||
- No barrel files or `index.ts` re-exports
|
||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
||||
- Avoid comments at all times unless the code is very complex
|
||||
|
||||
### Security Implementation
|
||||
|
||||
|
||||
@@ -290,11 +290,6 @@ async def _cache_session(session: ChatSession) -> None:
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
|
||||
async def cache_chat_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session without persisting to the database."""
|
||||
await _cache_session(session)
|
||||
|
||||
|
||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from the database."""
|
||||
prisma_session = await chat_db.get_chat_session(session_id)
|
||||
|
||||
@@ -172,12 +172,12 @@ async def get_session(
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
@@ -222,8 +222,6 @@ async def stream_chat_post(
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
@@ -232,26 +230,7 @@ async def stream_chat_post(
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@@ -296,8 +275,6 @@ async def stream_chat_get(
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
message,
|
||||
@@ -305,26 +282,7 @@ async def stream_chat_get(
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@@ -1,20 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from asyncio import CancelledError
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import get_client, propagate_attributes
|
||||
from langfuse.openai import openai # type: ignore
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
APIStatusError,
|
||||
PermissionDeniedError,
|
||||
RateLimitError,
|
||||
)
|
||||
from openai import APIConnectionError, APIError, APIStatusError, RateLimitError
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
from backend.data.understanding import (
|
||||
@@ -29,7 +21,6 @@ from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
Usage,
|
||||
cache_chat_session,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
@@ -305,10 +296,6 @@ async def stream_chat_completion(
|
||||
content="",
|
||||
)
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_saved_assistant_message = False
|
||||
has_appended_streaming_message = False
|
||||
last_cache_time = 0.0
|
||||
last_cache_content_len = 0
|
||||
|
||||
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
|
||||
has_yielded_end = False
|
||||
@@ -345,23 +332,6 @@ async def stream_chat_completion(
|
||||
assert assistant_response.content is not None
|
||||
assistant_response.content += delta
|
||||
has_received_text = True
|
||||
if not has_appended_streaming_message:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_streaming_message = True
|
||||
current_time = time.monotonic()
|
||||
content_len = len(assistant_response.content)
|
||||
if (
|
||||
current_time - last_cache_time >= 1.0
|
||||
and content_len > last_cache_content_len
|
||||
):
|
||||
try:
|
||||
await cache_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to cache partial session {session.session_id}: {e}"
|
||||
)
|
||||
last_cache_time = current_time
|
||||
last_cache_content_len = content_len
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextEnd):
|
||||
# Emit text-end after text completes
|
||||
@@ -420,42 +390,10 @@ async def stream_chat_completion(
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
|
||||
# Save assistant message before yielding finish to ensure it's persisted
|
||||
# even if client disconnects immediately after receiving StreamFinish
|
||||
if not has_saved_assistant_message:
|
||||
messages_to_save_early: list[ChatMessage] = []
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = (
|
||||
accumulated_tool_calls
|
||||
)
|
||||
if not has_appended_streaming_message and (
|
||||
assistant_response.content
|
||||
or assistant_response.tool_calls
|
||||
):
|
||||
messages_to_save_early.append(assistant_response)
|
||||
messages_to_save_early.extend(tool_response_messages)
|
||||
|
||||
if messages_to_save_early:
|
||||
session.messages.extend(messages_to_save_early)
|
||||
logger.info(
|
||||
f"Saving assistant message before StreamFinish: "
|
||||
f"content_len={len(assistant_response.content or '')}, "
|
||||
f"tool_calls={len(assistant_response.tool_calls or [])}, "
|
||||
f"tool_responses={len(tool_response_messages)}"
|
||||
)
|
||||
if (
|
||||
messages_to_save_early
|
||||
or has_appended_streaming_message
|
||||
):
|
||||
await upsert_chat_session(session)
|
||||
has_saved_assistant_message = True
|
||||
|
||||
has_yielded_end = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamError):
|
||||
has_yielded_error = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamUsage):
|
||||
session.usage.append(
|
||||
Usage(
|
||||
@@ -475,27 +413,6 @@ async def stream_chat_completion(
|
||||
langfuse.update_current_trace(output=str(tool_response_messages))
|
||||
langfuse.update_current_span(output=str(tool_response_messages))
|
||||
|
||||
except CancelledError:
|
||||
if not has_saved_assistant_message:
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if assistant_response.content:
|
||||
assistant_response.content = (
|
||||
f"{assistant_response.content}\n\n[interrupted]"
|
||||
)
|
||||
else:
|
||||
assistant_response.content = "[interrupted]"
|
||||
if not has_appended_streaming_message:
|
||||
session.messages.append(assistant_response)
|
||||
if tool_response_messages:
|
||||
session.messages.extend(tool_response_messages)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to save interrupted session {session.session_id}: {e}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
||||
|
||||
@@ -517,19 +434,14 @@ async def stream_chat_completion(
|
||||
# Add assistant message if it has content or tool calls
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if not has_appended_streaming_message and (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
):
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
|
||||
if not has_saved_assistant_message:
|
||||
if messages_to_save:
|
||||
session.messages.extend(messages_to_save)
|
||||
if messages_to_save or has_appended_streaming_message:
|
||||
await upsert_chat_session(session)
|
||||
session.messages.extend(messages_to_save)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
if not has_yielded_error:
|
||||
error_message = str(e)
|
||||
@@ -560,49 +472,38 @@ async def stream_chat_completion(
|
||||
return # Exit after retry to avoid double-saving in finally block
|
||||
|
||||
# Normal completion path - save session and handle tool call continuation
|
||||
# Only save if we haven't already saved when StreamFinish was received
|
||||
if not has_saved_assistant_message:
|
||||
logger.info(
|
||||
f"Normal completion path: session={session.session_id}, "
|
||||
f"current message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
logger.info(
|
||||
f"Normal completion path: session={session.session_id}, "
|
||||
f"current message_count={len(session.messages)}"
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
)
|
||||
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
logger.info(
|
||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||
f"total_to_save={len(messages_to_save)}"
|
||||
)
|
||||
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
logger.info(
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if not has_appended_streaming_message and (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
):
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
logger.info(
|
||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||
f"total_to_save={len(messages_to_save)}"
|
||||
)
|
||||
|
||||
if messages_to_save:
|
||||
session.messages.extend(messages_to_save)
|
||||
logger.info(
|
||||
f"Extended session messages, new message_count={len(session.messages)}"
|
||||
)
|
||||
if messages_to_save or has_appended_streaming_message:
|
||||
await upsert_chat_session(session)
|
||||
else:
|
||||
logger.info(
|
||||
"Assistant message already saved when StreamFinish was received, "
|
||||
"skipping duplicate save"
|
||||
)
|
||||
session.messages.extend(messages_to_save)
|
||||
logger.info(
|
||||
f"Extended session messages, new message_count={len(session.messages)}"
|
||||
)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# If we did a tool call, stream the chat completion again to get the next response
|
||||
if has_done_tool_call:
|
||||
@@ -644,12 +545,6 @@ def _is_retryable_error(error: Exception) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_region_blocked_error(error: Exception) -> bool:
|
||||
if isinstance(error, PermissionDeniedError):
|
||||
return "not available in your region" in str(error).lower()
|
||||
return "not available in your region" in str(error).lower()
|
||||
|
||||
|
||||
async def _stream_chat_chunks(
|
||||
session: ChatSession,
|
||||
tools: list[ChatCompletionToolParam],
|
||||
@@ -842,18 +737,7 @@ async def _stream_chat_chunks(
|
||||
f"Error in stream (not retrying): {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
error_code = None
|
||||
error_text = str(e)
|
||||
if _is_region_blocked_error(e):
|
||||
error_code = "MODEL_NOT_AVAILABLE_REGION"
|
||||
error_text = (
|
||||
"This model is not available in your region. "
|
||||
"Please connect via VPN and try again."
|
||||
)
|
||||
error_response = StreamError(
|
||||
errorText=error_text,
|
||||
code=error_code,
|
||||
)
|
||||
error_response = StreamError(errorText=str(e))
|
||||
yield error_response
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
@@ -107,6 +107,13 @@ class ReviewItem(BaseModel):
|
||||
reviewed_data: SafeJsonData | None = Field(
|
||||
None, description="Optional edited data (ignored if approved=False)"
|
||||
)
|
||||
auto_approve_future: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true and this review is approved, future executions of this same "
|
||||
"block (node) will be automatically approved. This only affects approved reviews."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("reviewed_data")
|
||||
@classmethod
|
||||
@@ -174,6 +181,9 @@ class ReviewRequest(BaseModel):
|
||||
This request must include ALL pending reviews for a graph execution.
|
||||
Each review will be either approved (with optional data modifications)
|
||||
or rejected (data ignored). The execution will resume only after ALL reviews are processed.
|
||||
|
||||
Each review item can individually specify whether to auto-approve future executions
|
||||
of the same block via the `auto_approve_future` field on ReviewItem.
|
||||
"""
|
||||
|
||||
reviews: List[ReviewItem] = Field(
|
||||
|
||||
@@ -8,6 +8,12 @@ from prisma.enums import ReviewStatus
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.rest_api import handle_internal_http_error
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
NodeExecutionResult,
|
||||
)
|
||||
from backend.data.graph import GraphSettings
|
||||
|
||||
from .model import PendingHumanReviewModel
|
||||
from .routes import router
|
||||
@@ -15,20 +21,24 @@ from .routes import router
|
||||
# Using a fixed timestamp for reproducible tests
|
||||
FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router, prefix="/api/review")
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create FastAPI app for testing"""
|
||||
test_app = fastapi.FastAPI()
|
||||
test_app.include_router(router, prefix="/api/review")
|
||||
test_app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
return test_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module"""
|
||||
@pytest.fixture
|
||||
def client(app, mock_jwt_user):
|
||||
"""Create test client with auth overrides"""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
with fastapi.testclient.TestClient(app) as test_client:
|
||||
yield test_client
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@@ -55,6 +65,7 @@ def sample_pending_review(test_user_id: str) -> PendingHumanReviewModel:
|
||||
|
||||
|
||||
def test_get_pending_reviews_empty(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
@@ -73,6 +84,7 @@ def test_get_pending_reviews_empty(
|
||||
|
||||
|
||||
def test_get_pending_reviews_with_data(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
@@ -95,6 +107,7 @@ def test_get_pending_reviews_with_data(
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_success(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
snapshot: Snapshot,
|
||||
@@ -123,6 +136,7 @@ def test_get_pending_reviews_for_execution_success(
|
||||
|
||||
|
||||
def test_get_pending_reviews_for_execution_not_available(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""Test access denied when user doesn't own the execution"""
|
||||
@@ -138,6 +152,7 @@ def test_get_pending_reviews_for_execution_not_available(
|
||||
|
||||
|
||||
def test_process_review_action_approve_success(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
@@ -145,6 +160,12 @@ def test_process_review_action_approve_success(
|
||||
"""Test successful review approval"""
|
||||
# Mock the route functions
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
@@ -173,6 +194,14 @@ def test_process_review_action_approve_success(
|
||||
)
|
||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
@@ -202,6 +231,7 @@ def test_process_review_action_approve_success(
|
||||
|
||||
|
||||
def test_process_review_action_reject_success(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
@@ -209,6 +239,20 @@ def test_process_review_action_reject_success(
|
||||
"""Test successful review rejection"""
|
||||
# Mock the route functions
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
@@ -262,6 +306,7 @@ def test_process_review_action_reject_success(
|
||||
|
||||
|
||||
def test_process_review_action_mixed_success(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
@@ -288,6 +333,12 @@ def test_process_review_action_mixed_success(
|
||||
|
||||
# Mock the route functions
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
@@ -337,6 +388,14 @@ def test_process_review_action_mixed_success(
|
||||
"test_node_456": rejected_review,
|
||||
}
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
@@ -369,6 +428,7 @@ def test_process_review_action_mixed_success(
|
||||
|
||||
|
||||
def test_process_review_action_empty_request(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
@@ -386,10 +446,45 @@ def test_process_review_action_empty_request(
|
||||
|
||||
|
||||
def test_process_review_action_review_not_found(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error when review is not found"""
|
||||
# Create a review with the nonexistent_node ID so the route can find the graph_exec_id
|
||||
nonexistent_review = PendingHumanReviewModel(
|
||||
node_exec_id="nonexistent_node",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "test"},
|
||||
instructions="Review",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=None,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=None,
|
||||
reviewed_at=None,
|
||||
)
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = nonexistent_review
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock the functions that extract graph execution ID from the request
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
@@ -422,11 +517,26 @@ def test_process_review_action_review_not_found(
|
||||
|
||||
|
||||
def test_process_review_action_partial_failure(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test handling of partial failures in review processing"""
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
@@ -456,16 +566,50 @@ def test_process_review_action_partial_failure(
|
||||
|
||||
|
||||
def test_process_review_action_invalid_node_exec_id(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test failure when trying to process review with invalid node execution ID"""
|
||||
# Create a review with the invalid-node-format ID so the route can find the graph_exec_id
|
||||
invalid_review = PendingHumanReviewModel(
|
||||
node_exec_id="invalid-node-format",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "test"},
|
||||
instructions="Review",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=None,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=None,
|
||||
reviewed_at=None,
|
||||
)
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = invalid_review
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock the route functions
|
||||
mock_get_reviews_for_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_reviews_for_execution"
|
||||
)
|
||||
mock_get_reviews_for_execution.return_value = [sample_pending_review]
|
||||
mock_get_reviews_for_execution.return_value = [invalid_review]
|
||||
|
||||
# Mock validation failure - this should return 400, not 500
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
@@ -490,3 +634,595 @@ def test_process_review_action_invalid_node_exec_id(
|
||||
# Should be a 400 Bad Request, not 500 Internal Server Error
|
||||
assert response.status_code == 400
|
||||
assert "Invalid node execution ID format" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_process_review_action_auto_approve_creates_auto_approval_records(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test that auto_approve_future_actions flag creates auto-approval records"""
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
# Mock process_all_reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
approved_review = PendingHumanReviewModel(
|
||||
node_exec_id="test_node_123",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "test payload"},
|
||||
instructions="Please review",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message="Approved",
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
)
|
||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||
|
||||
# Mock get_node_execution to return node_id
|
||||
mock_get_node_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_execution"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_id = "test_node_def_456"
|
||||
mock_get_node_execution.return_value = mock_node_exec
|
||||
|
||||
# Mock create_auto_approval_record
|
||||
mock_create_auto_approval = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.create_auto_approval_record"
|
||||
)
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock has_pending_reviews_for_graph_exec
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
# Mock get_graph_settings to return custom settings
|
||||
mock_get_settings = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_settings"
|
||||
)
|
||||
mock_get_settings.return_value = GraphSettings(
|
||||
human_in_the_loop_safe_mode=True,
|
||||
sensitive_action_safe_mode=True,
|
||||
)
|
||||
|
||||
# Mock get_user_by_id to prevent database access
|
||||
mock_get_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_user_by_id"
|
||||
)
|
||||
mock_user = mocker.Mock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock add_graph_execution
|
||||
mock_add_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.add_graph_execution"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"reviews": [
|
||||
{
|
||||
"node_exec_id": "test_node_123",
|
||||
"approved": True,
|
||||
"message": "Approved",
|
||||
"auto_approve_future": True,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/api/review/action", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify process_all_reviews_for_execution was called (without auto_approve param)
|
||||
mock_process_all_reviews.assert_called_once()
|
||||
|
||||
# Verify create_auto_approval_record was called for the approved review
|
||||
mock_create_auto_approval.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="test_node_def_456",
|
||||
payload={"data": "test payload"},
|
||||
)
|
||||
|
||||
# Verify get_graph_settings was called with correct parameters
|
||||
mock_get_settings.assert_called_once_with(
|
||||
user_id=test_user_id, graph_id="test_graph_789"
|
||||
)
|
||||
|
||||
# Verify add_graph_execution was called with proper ExecutionContext
|
||||
mock_add_execution.assert_called_once()
|
||||
call_kwargs = mock_add_execution.call_args.kwargs
|
||||
execution_context = call_kwargs["execution_context"]
|
||||
|
||||
assert isinstance(execution_context, ExecutionContext)
|
||||
assert execution_context.human_in_the_loop_safe_mode is True
|
||||
assert execution_context.sensitive_action_safe_mode is True
|
||||
|
||||
|
||||
def test_process_review_action_without_auto_approve_still_loads_settings(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test that execution context is created with settings even without auto-approve"""
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = sample_pending_review
|
||||
|
||||
# Mock process_all_reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
approved_review = PendingHumanReviewModel(
|
||||
node_exec_id="test_node_123",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "test payload"},
|
||||
instructions="Please review",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message="Approved",
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
)
|
||||
mock_process_all_reviews.return_value = {"test_node_123": approved_review}
|
||||
|
||||
# Mock create_auto_approval_record - should NOT be called when auto_approve is False
|
||||
mock_create_auto_approval = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.create_auto_approval_record"
|
||||
)
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock has_pending_reviews_for_graph_exec
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
# Mock get_graph_settings with sensitive_action_safe_mode enabled
|
||||
mock_get_settings = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_settings"
|
||||
)
|
||||
mock_get_settings.return_value = GraphSettings(
|
||||
human_in_the_loop_safe_mode=False,
|
||||
sensitive_action_safe_mode=True,
|
||||
)
|
||||
|
||||
# Mock get_user_by_id to prevent database access
|
||||
mock_get_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_user_by_id"
|
||||
)
|
||||
mock_user = mocker.Mock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock add_graph_execution
|
||||
mock_add_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.add_graph_execution"
|
||||
)
|
||||
|
||||
# Request WITHOUT auto_approve_future (defaults to False)
|
||||
request_data = {
|
||||
"reviews": [
|
||||
{
|
||||
"node_exec_id": "test_node_123",
|
||||
"approved": True,
|
||||
"message": "Approved",
|
||||
# auto_approve_future defaults to False
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/api/review/action", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify process_all_reviews_for_execution was called
|
||||
mock_process_all_reviews.assert_called_once()
|
||||
|
||||
# Verify create_auto_approval_record was NOT called (auto_approve_future=False)
|
||||
mock_create_auto_approval.assert_not_called()
|
||||
|
||||
# Verify settings were loaded
|
||||
mock_get_settings.assert_called_once()
|
||||
|
||||
# Verify ExecutionContext has proper settings
|
||||
mock_add_execution.assert_called_once()
|
||||
call_kwargs = mock_add_execution.call_args.kwargs
|
||||
execution_context = call_kwargs["execution_context"]
|
||||
|
||||
assert isinstance(execution_context, ExecutionContext)
|
||||
assert execution_context.human_in_the_loop_safe_mode is False
|
||||
assert execution_context.sensitive_action_safe_mode is True
|
||||
|
||||
|
||||
def test_process_review_action_auto_approve_only_applies_to_approved_reviews(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test that auto_approve record is created only for approved reviews"""
|
||||
# Create two reviews - one approved, one rejected
|
||||
approved_review = PendingHumanReviewModel(
|
||||
node_exec_id="node_exec_approved",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "approved"},
|
||||
instructions="Review",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
)
|
||||
rejected_review = PendingHumanReviewModel(
|
||||
node_exec_id="node_exec_rejected",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
payload={"data": "rejected"},
|
||||
instructions="Review",
|
||||
editable=True,
|
||||
status=ReviewStatus.REJECTED,
|
||||
review_message="Rejected",
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
)
|
||||
|
||||
# Mock get_pending_review_by_node_exec_id (called to find the graph_exec_id)
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
mock_get_reviews_for_user.return_value = approved_review
|
||||
|
||||
# Mock process_all_reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.return_value = {
|
||||
"node_exec_approved": approved_review,
|
||||
"node_exec_rejected": rejected_review,
|
||||
}
|
||||
|
||||
# Mock get_node_execution to return node_id (only called for approved review)
|
||||
mock_get_node_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_execution"
|
||||
)
|
||||
mock_node_exec = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node_exec.node_id = "test_node_def_approved"
|
||||
mock_get_node_execution.return_value = mock_node_exec
|
||||
|
||||
# Mock create_auto_approval_record
|
||||
mock_create_auto_approval = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.create_auto_approval_record"
|
||||
)
|
||||
|
||||
# Mock get_graph_execution_meta to return execution in REVIEW status
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock has_pending_reviews_for_graph_exec
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
# Mock get_graph_settings
|
||||
mock_get_settings = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_settings"
|
||||
)
|
||||
mock_get_settings.return_value = GraphSettings()
|
||||
|
||||
# Mock get_user_by_id to prevent database access
|
||||
mock_get_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_user_by_id"
|
||||
)
|
||||
mock_user = mocker.Mock()
|
||||
mock_user.timezone = "UTC"
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock add_graph_execution
|
||||
mock_add_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.add_graph_execution"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"reviews": [
|
||||
{
|
||||
"node_exec_id": "node_exec_approved",
|
||||
"approved": True,
|
||||
"auto_approve_future": True,
|
||||
},
|
||||
{
|
||||
"node_exec_id": "node_exec_rejected",
|
||||
"approved": False,
|
||||
"auto_approve_future": True, # Should be ignored since rejected
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/api/review/action", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify process_all_reviews_for_execution was called
|
||||
mock_process_all_reviews.assert_called_once()
|
||||
|
||||
# Verify create_auto_approval_record was called ONLY for the approved review
|
||||
# (not for the rejected one)
|
||||
mock_create_auto_approval.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="test_node_def_approved",
|
||||
payload={"data": "approved"},
|
||||
)
|
||||
|
||||
# Verify get_node_execution was called only for approved review
|
||||
mock_get_node_execution.assert_called_once_with("node_exec_approved")
|
||||
|
||||
# Verify ExecutionContext was created (auto-approval is now DB-based)
|
||||
call_kwargs = mock_add_execution.call_args.kwargs
|
||||
execution_context = call_kwargs["execution_context"]
|
||||
assert isinstance(execution_context, ExecutionContext)
|
||||
|
||||
|
||||
def test_process_review_action_per_review_auto_approve_granularity(
|
||||
client: fastapi.testclient.TestClient,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
sample_pending_review: PendingHumanReviewModel,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test that auto-approval can be set per-review (granular control)"""
|
||||
# Mock get_pending_review_by_node_exec_id - return different reviews based on node_exec_id
|
||||
mock_get_reviews_for_user = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_pending_review_by_node_exec_id"
|
||||
)
|
||||
|
||||
# Create a mapping of node_exec_id to review
|
||||
review_map = {
|
||||
"node_1_auto": PendingHumanReviewModel(
|
||||
node_exec_id="node_1_auto",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node1"},
|
||||
instructions="Review 1",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
),
|
||||
"node_2_manual": PendingHumanReviewModel(
|
||||
node_exec_id="node_2_manual",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node2"},
|
||||
instructions="Review 2",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
),
|
||||
"node_3_auto": PendingHumanReviewModel(
|
||||
node_exec_id="node_3_auto",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node3"},
|
||||
instructions="Review 3",
|
||||
editable=True,
|
||||
status=ReviewStatus.WAITING,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
),
|
||||
}
|
||||
|
||||
# Use side_effect to return different reviews based on node_exec_id parameter
|
||||
def mock_get_review_by_id(node_exec_id: str, _user_id: str):
|
||||
return review_map.get(node_exec_id)
|
||||
|
||||
mock_get_reviews_for_user.side_effect = mock_get_review_by_id
|
||||
|
||||
# Mock process_all_reviews - return 3 approved reviews
|
||||
mock_process_all_reviews = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.process_all_reviews_for_execution"
|
||||
)
|
||||
mock_process_all_reviews.return_value = {
|
||||
"node_1_auto": PendingHumanReviewModel(
|
||||
node_exec_id="node_1_auto",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node1"},
|
||||
instructions="Review 1",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
),
|
||||
"node_2_manual": PendingHumanReviewModel(
|
||||
node_exec_id="node_2_manual",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node2"},
|
||||
instructions="Review 2",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
),
|
||||
"node_3_auto": PendingHumanReviewModel(
|
||||
node_exec_id="node_3_auto",
|
||||
user_id=test_user_id,
|
||||
graph_exec_id="test_graph_exec",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
payload={"data": "node3"},
|
||||
instructions="Review 3",
|
||||
editable=True,
|
||||
status=ReviewStatus.APPROVED,
|
||||
review_message=None,
|
||||
was_edited=False,
|
||||
processed=False,
|
||||
created_at=FIXED_NOW,
|
||||
updated_at=FIXED_NOW,
|
||||
reviewed_at=FIXED_NOW,
|
||||
),
|
||||
}
|
||||
|
||||
# Mock get_node_execution
|
||||
mock_get_node_execution = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_node_execution"
|
||||
)
|
||||
|
||||
def mock_get_node(node_exec_id: str):
|
||||
mock_node = mocker.Mock(spec=NodeExecutionResult)
|
||||
mock_node.node_id = f"node_def_{node_exec_id}"
|
||||
return mock_node
|
||||
|
||||
mock_get_node_execution.side_effect = mock_get_node
|
||||
|
||||
# Mock create_auto_approval_record
|
||||
mock_create_auto_approval = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.create_auto_approval_record"
|
||||
)
|
||||
|
||||
# Mock get_graph_execution_meta
|
||||
mock_get_graph_exec = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_execution_meta"
|
||||
)
|
||||
mock_graph_exec_meta = mocker.Mock()
|
||||
mock_graph_exec_meta.status = ExecutionStatus.REVIEW
|
||||
mock_get_graph_exec.return_value = mock_graph_exec_meta
|
||||
|
||||
# Mock has_pending_reviews_for_graph_exec
|
||||
mock_has_pending = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec"
|
||||
)
|
||||
mock_has_pending.return_value = False
|
||||
|
||||
# Mock settings and execution
|
||||
mock_get_settings = mocker.patch(
|
||||
"backend.api.features.executions.review.routes.get_graph_settings"
|
||||
)
|
||||
mock_get_settings.return_value = GraphSettings(
|
||||
human_in_the_loop_safe_mode=False, sensitive_action_safe_mode=False
|
||||
)
|
||||
|
||||
mocker.patch("backend.api.features.executions.review.routes.add_graph_execution")
|
||||
mocker.patch("backend.api.features.executions.review.routes.get_user_by_id")
|
||||
|
||||
# Request with granular auto-approval:
|
||||
# - node_1_auto: auto_approve_future=True
|
||||
# - node_2_manual: auto_approve_future=False (explicit)
|
||||
# - node_3_auto: auto_approve_future=True
|
||||
request_data = {
|
||||
"reviews": [
|
||||
{
|
||||
"node_exec_id": "node_1_auto",
|
||||
"approved": True,
|
||||
"auto_approve_future": True,
|
||||
},
|
||||
{
|
||||
"node_exec_id": "node_2_manual",
|
||||
"approved": True,
|
||||
"auto_approve_future": False, # Don't auto-approve this one
|
||||
},
|
||||
{
|
||||
"node_exec_id": "node_3_auto",
|
||||
"approved": True,
|
||||
"auto_approve_future": True,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
response = client.post("/api/review/action", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify create_auto_approval_record was called ONLY for reviews with auto_approve_future=True
|
||||
assert mock_create_auto_approval.call_count == 2
|
||||
|
||||
# Check that it was called for node_1 and node_3, but NOT node_2
|
||||
call_args_list = [call.kwargs for call in mock_create_auto_approval.call_args_list]
|
||||
node_ids_with_auto_approval = [args["node_id"] for args in call_args_list]
|
||||
|
||||
assert "node_def_node_1_auto" in node_ids_with_auto_approval
|
||||
assert "node_def_node_3_auto" in node_ids_with_auto_approval
|
||||
assert "node_def_node_2_manual" not in node_ids_with_auto_approval
|
||||
|
||||
@@ -5,13 +5,23 @@ import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.data.execution import get_graph_execution_meta
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
get_graph_execution_meta,
|
||||
get_node_execution,
|
||||
)
|
||||
from backend.data.graph import get_graph_settings
|
||||
from backend.data.human_review import (
|
||||
create_auto_approval_record,
|
||||
get_pending_review_by_node_exec_id,
|
||||
get_pending_reviews_for_execution,
|
||||
get_pending_reviews_for_user,
|
||||
has_pending_reviews_for_graph_exec,
|
||||
process_all_reviews_for_execution,
|
||||
)
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
@@ -127,17 +137,80 @@ async def process_review_action(
|
||||
detail="At least one review must be provided",
|
||||
)
|
||||
|
||||
# Build review decisions map
|
||||
# Get graph execution ID by looking up all requested reviews
|
||||
# Use direct lookup to avoid pagination issues (can't miss reviews beyond first page)
|
||||
# Also validate that all reviews belong to the same execution
|
||||
matching_review = None
|
||||
graph_exec_ids: set[str] = set()
|
||||
|
||||
for node_exec_id in all_request_node_ids:
|
||||
review = await get_pending_review_by_node_exec_id(node_exec_id, user_id)
|
||||
if not review:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No pending review found for node execution {node_exec_id}",
|
||||
)
|
||||
if matching_review is None:
|
||||
matching_review = review
|
||||
graph_exec_ids.add(review.graph_exec_id)
|
||||
|
||||
# Ensure all reviews belong to the same execution
|
||||
if len(graph_exec_ids) > 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="All reviews in a single request must belong to the same execution.",
|
||||
)
|
||||
|
||||
# Safety check (matching_review should never be None here due to validation above)
|
||||
if matching_review is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal error: No matching review found despite validation",
|
||||
)
|
||||
|
||||
graph_exec_id = matching_review.graph_exec_id
|
||||
|
||||
# Validate execution status before processing reviews
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
# Only allow processing reviews if execution is paused for review
|
||||
# or incomplete (partial execution with some reviews already processed)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||
f"Current status: {graph_exec_meta.status}",
|
||||
)
|
||||
|
||||
# Build review decisions map and track which reviews requested auto-approval
|
||||
# Auto-approved reviews use original data (no modifications allowed)
|
||||
review_decisions = {}
|
||||
auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag
|
||||
|
||||
for review in request.reviews:
|
||||
review_status = (
|
||||
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
|
||||
)
|
||||
# If this review requested auto-approval, don't allow data modifications
|
||||
reviewed_data = None if review.auto_approve_future else review.reviewed_data
|
||||
review_decisions[review.node_exec_id] = (
|
||||
review_status,
|
||||
review.reviewed_data,
|
||||
reviewed_data,
|
||||
review.message,
|
||||
)
|
||||
auto_approve_requests[review.node_exec_id] = review.auto_approve_future
|
||||
|
||||
# Process all reviews
|
||||
updated_reviews = await process_all_reviews_for_execution(
|
||||
@@ -145,6 +218,32 @@ async def process_review_action(
|
||||
review_decisions=review_decisions,
|
||||
)
|
||||
|
||||
# Create auto-approval records for approved reviews that requested it
|
||||
# Note: Processing sequentially to avoid event loop issues in tests
|
||||
for node_exec_id, review_result in updated_reviews.items():
|
||||
# Only create auto-approval if:
|
||||
# 1. This review was approved
|
||||
# 2. The review requested auto-approval
|
||||
if review_result.status == ReviewStatus.APPROVED and auto_approve_requests.get(
|
||||
node_exec_id, False
|
||||
):
|
||||
try:
|
||||
node_exec = await get_node_execution(node_exec_id)
|
||||
if node_exec:
|
||||
await create_auto_approval_record(
|
||||
user_id=user_id,
|
||||
graph_exec_id=review_result.graph_exec_id,
|
||||
graph_id=review_result.graph_id,
|
||||
graph_version=review_result.graph_version,
|
||||
node_id=node_exec.node_id,
|
||||
payload=review_result.payload,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for {node_exec_id}",
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
# Count results
|
||||
approved_count = sum(
|
||||
1
|
||||
@@ -157,22 +256,37 @@ async def process_review_action(
|
||||
if review.status == ReviewStatus.REJECTED
|
||||
)
|
||||
|
||||
# Resume execution if we processed some reviews
|
||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
||||
if updated_reviews:
|
||||
# Get graph execution ID from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
graph_exec_id = first_review.graph_exec_id
|
||||
|
||||
# Check if any pending reviews remain for this execution
|
||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||
|
||||
if not still_has_pending:
|
||||
# Resume execution
|
||||
# Get the graph_id from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
|
||||
try:
|
||||
# Fetch user and settings to build complete execution context
|
||||
user = await get_user_by_id(user_id)
|
||||
settings = await get_graph_settings(
|
||||
user_id=user_id, graph_id=first_review.graph_id
|
||||
)
|
||||
|
||||
# Preserve user's timezone preference when resuming execution
|
||||
user_timezone = (
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
user_timezone=user_timezone,
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
graph_id=first_review.graph_id,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Resumed execution {graph_exec_id}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -6,6 +6,7 @@ Handles generation and storage of OpenAI embeddings for all content types
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
@@ -21,6 +22,11 @@ from backend.util.json import dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable to track errors logged in the current task/operation
|
||||
# This prevents spamming the same error multiple times when processing batches
|
||||
_logged_errors: contextvars.ContextVar[set[str]] = contextvars.ContextVar(
|
||||
"_logged_errors"
|
||||
)
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
@@ -31,6 +37,42 @@ EMBEDDING_DIM = 1536
|
||||
EMBEDDING_MAX_TOKENS = 8191
|
||||
|
||||
|
||||
def log_once_per_task(error_key: str, log_fn, message: str, **kwargs) -> bool:
|
||||
"""
|
||||
Log an error/warning only once per task/operation to avoid log spam.
|
||||
|
||||
Uses contextvars to track what has been logged in the current async context.
|
||||
Useful when processing batches where the same error might occur for many items.
|
||||
|
||||
Args:
|
||||
error_key: Unique identifier for this error type
|
||||
log_fn: Logger function to call (e.g., logger.error, logger.warning)
|
||||
message: Message to log
|
||||
**kwargs: Additional arguments to pass to log_fn
|
||||
|
||||
Returns:
|
||||
True if the message was logged, False if it was suppressed (already logged)
|
||||
|
||||
Example:
|
||||
log_once_per_task("missing_api_key", logger.error, "API key not set")
|
||||
"""
|
||||
# Get current logged errors, or create a new set if this is the first call in this context
|
||||
logged = _logged_errors.get(None)
|
||||
if logged is None:
|
||||
logged = set()
|
||||
_logged_errors.set(logged)
|
||||
|
||||
if error_key in logged:
|
||||
return False
|
||||
|
||||
# Log the message with a note that it will only appear once
|
||||
log_fn(f"{message} (This message will only be shown once per task.)", **kwargs)
|
||||
|
||||
# Mark as logged
|
||||
logged.add(error_key)
|
||||
return True
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
@@ -73,7 +115,11 @@ async def generate_embedding(text: str) -> list[float] | None:
|
||||
try:
|
||||
client = get_openai_client()
|
||||
if not client:
|
||||
logger.error("openai_internal_api_key not set, cannot generate embedding")
|
||||
log_once_per_task(
|
||||
"openai_api_key_missing",
|
||||
logger.error,
|
||||
"openai_internal_api_key not set, cannot generate embeddings",
|
||||
)
|
||||
return None
|
||||
|
||||
# Truncate text to token limit using tiktoken
|
||||
@@ -290,7 +336,12 @@ async def ensure_embedding(
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(f"Could not generate embedding for version {version_id}")
|
||||
log_once_per_task(
|
||||
"embedding_generation_failed",
|
||||
logger.warning,
|
||||
"Could not generate embeddings (missing API key or service unavailable). "
|
||||
"Embedding generation is disabled for this task.",
|
||||
)
|
||||
return False
|
||||
|
||||
# Store the embedding with metadata using new function
|
||||
@@ -609,8 +660,11 @@ async def ensure_content_embedding(
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
if embedding is None:
|
||||
logger.warning(
|
||||
f"Could not generate embedding for {content_type}:{content_id}"
|
||||
log_once_per_task(
|
||||
"embedding_generation_failed",
|
||||
logger.warning,
|
||||
"Could not generate embeddings (missing API key or service unavailable). "
|
||||
"Embedding generation is disabled for this task.",
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@@ -116,6 +116,7 @@ class PrintToConsoleBlock(Block):
|
||||
input_schema=PrintToConsoleBlock.Input,
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
is_sensitive_action=True,
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, Optional
|
||||
from prisma.enums import ReviewStatus
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.human_review import ReviewResult
|
||||
from backend.executor.manager import async_update_node_execution_status
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
@@ -28,6 +28,11 @@ class ReviewDecision(BaseModel):
|
||||
class HITLReviewHelper:
|
||||
"""Helper class for Human-In-The-Loop review operations."""
|
||||
|
||||
@staticmethod
|
||||
async def check_approval(**kwargs) -> Optional[ReviewResult]:
|
||||
"""Check if there's an existing approval for this node execution."""
|
||||
return await get_database_manager_async_client().check_approval(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
|
||||
"""Create or retrieve a human review from the database."""
|
||||
@@ -55,11 +60,11 @@ class HITLReviewHelper:
|
||||
async def _handle_review_request(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewResult]:
|
||||
@@ -69,11 +74,11 @@ class HITLReviewHelper:
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_id: ID of the node in the graph definition
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
@@ -83,15 +88,41 @@ class HITLReviewHelper:
|
||||
Raises:
|
||||
Exception: If review creation or status update fails
|
||||
"""
|
||||
# Skip review if safe mode is disabled - return auto-approved result
|
||||
if not execution_context.human_in_the_loop_safe_mode:
|
||||
# Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode)
|
||||
# are handled by the caller:
|
||||
# - HITL blocks check human_in_the_loop_safe_mode in their run() method
|
||||
# - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review()
|
||||
# This function only handles checking for existing approvals.
|
||||
|
||||
# Check if this node has already been approved (normal or auto-approval)
|
||||
if approval_result := await HITLReviewHelper.check_approval(
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
node_id=node_id,
|
||||
user_id=user_id,
|
||||
input_data=input_data,
|
||||
):
|
||||
logger.info(
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - "
|
||||
f"found existing approval"
|
||||
)
|
||||
# Return a new ReviewResult with the current node_exec_id but approved status
|
||||
# For auto-approvals, always use current input_data
|
||||
# For normal approvals, use approval_result.data unless it's None
|
||||
is_auto_approval = approval_result.node_exec_id != node_exec_id
|
||||
approved_data = (
|
||||
input_data
|
||||
if is_auto_approval
|
||||
else (
|
||||
approval_result.data
|
||||
if approval_result.data is not None
|
||||
else input_data
|
||||
)
|
||||
)
|
||||
return ReviewResult(
|
||||
data=input_data,
|
||||
data=approved_data,
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="Auto-approved (safe mode disabled)",
|
||||
message=approval_result.message,
|
||||
processed=True,
|
||||
node_exec_id=node_exec_id,
|
||||
)
|
||||
@@ -129,11 +160,11 @@ class HITLReviewHelper:
|
||||
async def handle_review_decision(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
) -> Optional[ReviewDecision]:
|
||||
@@ -143,11 +174,11 @@ class HITLReviewHelper:
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_id: ID of the node in the graph definition
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
graph_version: Version of the graph
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
|
||||
@@ -158,11 +189,11 @@ class HITLReviewHelper:
|
||||
review_result = await HITLReviewHelper._handle_review_request(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
)
|
||||
|
||||
@@ -97,6 +97,7 @@ class HumanInTheLoopBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -115,11 +116,11 @@ class HumanInTheLoopBlock(Block):
|
||||
decision = await self.handle_review_decision(
|
||||
input_data=input_data.data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=input_data.editable,
|
||||
)
|
||||
|
||||
@@ -441,6 +441,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
static_output: bool = False,
|
||||
block_type: BlockType = BlockType.STANDARD,
|
||||
webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
|
||||
is_sensitive_action: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the block with the given schema.
|
||||
@@ -473,8 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.static_output = static_output
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.is_sensitive_action: bool = False
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -622,6 +623,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
input_data: BlockInput,
|
||||
*,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -648,11 +650,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
decision = await HITLReviewHelper.handle_review_decision(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
execution_context=execution_context,
|
||||
block_name=self.name,
|
||||
editable=True,
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from backend.api.features.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
SafeJsonData,
|
||||
)
|
||||
from backend.data.execution import get_graph_execution_meta
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -32,6 +33,125 @@ class ReviewResult(BaseModel):
|
||||
node_exec_id: str
|
||||
|
||||
|
||||
def get_auto_approve_key(graph_exec_id: str, node_id: str) -> str:
|
||||
"""Generate the special nodeExecId key for auto-approval records."""
|
||||
return f"auto_approve_{graph_exec_id}_{node_id}"
|
||||
|
||||
|
||||
async def check_approval(
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
input_data: SafeJsonData | None = None,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Check if there's an existing approval for this node execution.
|
||||
|
||||
Checks both:
|
||||
1. Normal approval by node_exec_id (previous run of the same node execution)
|
||||
2. Auto-approval by special key pattern "auto_approve_{graph_exec_id}_{node_id}"
|
||||
|
||||
Args:
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
node_id: ID of the node definition (not execution)
|
||||
user_id: ID of the user (for data isolation)
|
||||
input_data: Current input data (used for auto-approvals to avoid stale data)
|
||||
|
||||
Returns:
|
||||
ReviewResult if approval found (either normal or auto), None otherwise
|
||||
"""
|
||||
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
|
||||
|
||||
# Check for either normal approval or auto-approval in a single query
|
||||
existing_review = await PendingHumanReview.prisma().find_first(
|
||||
where={
|
||||
"OR": [
|
||||
{"nodeExecId": node_exec_id},
|
||||
{"nodeExecId": auto_approve_key},
|
||||
],
|
||||
"status": ReviewStatus.APPROVED,
|
||||
"userId": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
if existing_review:
|
||||
is_auto_approval = existing_review.nodeExecId == auto_approve_key
|
||||
logger.info(
|
||||
f"Found {'auto-' if is_auto_approval else ''}approval for node {node_id} "
|
||||
f"(exec: {node_exec_id}) in execution {graph_exec_id}"
|
||||
)
|
||||
# For auto-approvals, use current input_data to avoid replaying stale payload
|
||||
# For normal approvals, use the stored payload (which may have been edited)
|
||||
return ReviewResult(
|
||||
data=(
|
||||
input_data
|
||||
if is_auto_approval and input_data is not None
|
||||
else existing_review.payload
|
||||
),
|
||||
status=ReviewStatus.APPROVED,
|
||||
message=(
|
||||
"Auto-approved (user approved all future actions for this node)"
|
||||
if is_auto_approval
|
||||
else existing_review.reviewMessage or ""
|
||||
),
|
||||
processed=True,
|
||||
node_exec_id=existing_review.nodeExecId,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def create_auto_approval_record(
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_id: str,
|
||||
payload: SafeJsonData,
|
||||
) -> None:
|
||||
"""
|
||||
Create an auto-approval record for a node in this execution.
|
||||
|
||||
This is stored as a PendingHumanReview with a special nodeExecId pattern
|
||||
and status=APPROVED, so future executions of the same node can skip review.
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph execution doesn't belong to the user
|
||||
"""
|
||||
# Validate that the graph execution belongs to this user (defense in depth)
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise ValueError(
|
||||
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
|
||||
)
|
||||
|
||||
auto_approve_key = get_auto_approve_key(graph_exec_id, node_id)
|
||||
|
||||
await PendingHumanReview.prisma().upsert(
|
||||
where={"nodeExecId": auto_approve_key},
|
||||
data={
|
||||
"create": {
|
||||
"nodeExecId": auto_approve_key,
|
||||
"userId": user_id,
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"payload": SafeJson(payload),
|
||||
"instructions": "Auto-approval record",
|
||||
"editable": False,
|
||||
"status": ReviewStatus.APPROVED,
|
||||
"processed": True,
|
||||
"reviewedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
"update": {}, # Already exists, no update needed
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_or_create_human_review(
|
||||
user_id: str,
|
||||
node_exec_id: str,
|
||||
@@ -108,6 +228,33 @@ async def get_or_create_human_review(
|
||||
)
|
||||
|
||||
|
||||
async def get_pending_review_by_node_exec_id(
|
||||
node_exec_id: str, user_id: str
|
||||
) -> Optional["PendingHumanReviewModel"]:
|
||||
"""
|
||||
Get a pending review by its node execution ID.
|
||||
|
||||
Args:
|
||||
node_exec_id: The node execution ID to look up
|
||||
user_id: User ID for authorization (only returns if review belongs to this user)
|
||||
|
||||
Returns:
|
||||
The pending review if found and belongs to user, None otherwise
|
||||
"""
|
||||
review = await PendingHumanReview.prisma().find_first(
|
||||
where={
|
||||
"nodeExecId": node_exec_id,
|
||||
"userId": user_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
}
|
||||
)
|
||||
|
||||
if not review:
|
||||
return None
|
||||
|
||||
return PendingHumanReviewModel.from_db(review)
|
||||
|
||||
|
||||
async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool:
|
||||
"""
|
||||
Check if a graph execution has any pending reviews.
|
||||
@@ -256,3 +403,44 @@ async def update_review_processed_status(node_exec_id: str, processed: bool) ->
|
||||
await PendingHumanReview.prisma().update(
|
||||
where={"nodeExecId": node_exec_id}, data={"processed": processed}
|
||||
)
|
||||
|
||||
|
||||
async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str) -> int:
|
||||
"""
|
||||
Cancel all pending reviews for a graph execution (e.g., when execution is stopped).
|
||||
|
||||
Marks all WAITING reviews as REJECTED with a message indicating the execution was stopped.
|
||||
|
||||
Args:
|
||||
graph_exec_id: The graph execution ID
|
||||
user_id: User ID who owns the execution (for security validation)
|
||||
|
||||
Returns:
|
||||
Number of reviews cancelled
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph execution doesn't belong to the user
|
||||
"""
|
||||
# Validate user ownership before cancelling reviews
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise ValueError(
|
||||
f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}"
|
||||
)
|
||||
|
||||
result = await PendingHumanReview.prisma().update_many(
|
||||
where={
|
||||
"graphExecId": graph_exec_id,
|
||||
"userId": user_id,
|
||||
"status": ReviewStatus.WAITING,
|
||||
},
|
||||
data={
|
||||
"status": ReviewStatus.REJECTED,
|
||||
"reviewMessage": "Execution was stopped by user",
|
||||
"processed": True,
|
||||
"reviewedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -46,8 +46,8 @@ async def test_get_or_create_human_review_new(
|
||||
sample_db_review.status = ReviewStatus.WAITING
|
||||
sample_db_review.processed = False
|
||||
|
||||
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
|
||||
result = await get_or_create_human_review(
|
||||
user_id="test-user-123",
|
||||
@@ -75,8 +75,8 @@ async def test_get_or_create_human_review_approved(
|
||||
sample_db_review.processed = False
|
||||
sample_db_review.reviewMessage = "Looks good"
|
||||
|
||||
mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma")
|
||||
mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review)
|
||||
|
||||
result = await get_or_create_human_review(
|
||||
user_id="test-user-123",
|
||||
|
||||
@@ -50,6 +50,8 @@ from backend.data.graph import (
|
||||
validate_graph_execution_permissions,
|
||||
)
|
||||
from backend.data.human_review import (
|
||||
cancel_pending_reviews_for_execution,
|
||||
check_approval,
|
||||
get_or_create_human_review,
|
||||
has_pending_reviews_for_graph_exec,
|
||||
update_review_processed_status,
|
||||
@@ -190,6 +192,8 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
|
||||
# Human In The Loop
|
||||
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
|
||||
check_approval = _(check_approval)
|
||||
get_or_create_human_review = _(get_or_create_human_review)
|
||||
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
|
||||
update_review_processed_status = _(update_review_processed_status)
|
||||
@@ -313,6 +317,8 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
|
||||
# Human In The Loop
|
||||
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
|
||||
check_approval = d.check_approval
|
||||
get_or_create_human_review = d.get_or_create_human_review
|
||||
update_review_processed_status = d.update_review_processed_status
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data import human_review as human_review_db
|
||||
from backend.data import onboarding as onboarding_db
|
||||
from backend.data import user as user_db
|
||||
from backend.data.block import (
|
||||
@@ -749,9 +750,27 @@ async def stop_graph_execution(
|
||||
if graph_exec.status in [
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.REVIEW,
|
||||
]:
|
||||
# If the graph is still on the queue, we can prevent them from being executed
|
||||
# by setting the status to TERMINATED.
|
||||
# If the graph is queued/incomplete/paused for review, terminate immediately
|
||||
# No need to wait for executor since it's not actively running
|
||||
|
||||
# If graph is in REVIEW status, clean up pending reviews before terminating
|
||||
if graph_exec.status == ExecutionStatus.REVIEW:
|
||||
# Use human_review_db if Prisma connected, else database manager
|
||||
review_db = (
|
||||
human_review_db
|
||||
if prisma.is_connected()
|
||||
else get_database_manager_async_client()
|
||||
)
|
||||
# Mark all pending reviews as rejected/cancelled
|
||||
cancelled_count = await review_db.cancel_pending_reviews_for_execution(
|
||||
graph_exec_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"Cancelled {cancelled_count} pending review(s) for stopped execution {graph_exec_id}"
|
||||
)
|
||||
|
||||
graph_exec.status = ExecutionStatus.TERMINATED
|
||||
|
||||
await asyncio.gather(
|
||||
@@ -887,9 +906,28 @@ async def add_graph_execution(
|
||||
nodes_to_skip=nodes_to_skip,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
|
||||
logger.info(f"Queueing execution {graph_exec.id}")
|
||||
|
||||
# Update execution status to QUEUED BEFORE publishing to prevent race condition
|
||||
# where two concurrent requests could both publish the same execution
|
||||
updated_exec = await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.QUEUED,
|
||||
)
|
||||
|
||||
# Verify the status update succeeded (prevents duplicate queueing in race conditions)
|
||||
# If another request already updated the status, this execution will not be QUEUED
|
||||
if not updated_exec or updated_exec.status != ExecutionStatus.QUEUED:
|
||||
logger.warning(
|
||||
f"Skipping queue publish for execution {graph_exec.id} - "
|
||||
f"status update failed or execution already queued by another request"
|
||||
)
|
||||
return graph_exec
|
||||
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
|
||||
# Publish to execution queue for executor to pick up
|
||||
# This happens AFTER status update to ensure only one request publishes
|
||||
exec_queue = await get_async_execution_queue()
|
||||
await exec_queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
@@ -897,13 +935,6 @@ async def add_graph_execution(
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
# Update execution status to QUEUED
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=graph_exec.status,
|
||||
)
|
||||
except BaseException as e:
|
||||
err = str(e) or type(e).__name__
|
||||
if not graph_exec:
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from backend.data.dynamic_fields import merge_execution_input, parse_execution_output
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.util.mock import MockObject
|
||||
|
||||
|
||||
@@ -346,6 +347,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Mock the queue and event bus
|
||||
@@ -611,6 +613,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = []
|
||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||
|
||||
# Track what's passed to to_graph_execution_entry
|
||||
captured_kwargs = {}
|
||||
@@ -670,3 +673,232 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
# Verify nodes_to_skip was passed to to_graph_execution_entry
|
||||
assert "nodes_to_skip" in captured_kwargs
|
||||
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_in_review_status_cancels_pending_reviews(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Test that stopping an execution in REVIEW status cancels pending reviews."""
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
from backend.executor.utils import stop_graph_execution
|
||||
|
||||
user_id = "test-user"
|
||||
graph_exec_id = "test-exec-123"
|
||||
|
||||
# Mock graph execution in REVIEW status
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_graph_exec.id = graph_exec_id
|
||||
mock_graph_exec.status = ExecutionStatus.REVIEW
|
||||
|
||||
# Mock dependencies
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_queue_client = mocker.AsyncMock()
|
||||
mock_get_queue.return_value = mock_queue_client
|
||||
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_prisma.is_connected.return_value = True
|
||||
|
||||
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||
return_value=2 # 2 reviews cancelled
|
||||
)
|
||||
|
||||
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
|
||||
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
|
||||
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
mock_get_child_executions = mocker.patch(
|
||||
"backend.executor.utils._get_child_executions"
|
||||
)
|
||||
mock_get_child_executions.return_value = [] # No children
|
||||
|
||||
# Call stop_graph_execution with timeout to allow status check
|
||||
await stop_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
wait_timeout=1.0, # Wait to allow status check
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
# Verify pending reviews were cancelled
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||
graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Verify execution status was updated to TERMINATED
|
||||
mock_execution_db.update_graph_execution_stats.assert_called_once()
|
||||
call_kwargs = mock_execution_db.update_graph_execution_stats.call_args[1]
|
||||
assert call_kwargs["graph_exec_id"] == graph_exec_id
|
||||
assert call_kwargs["status"] == ExecutionStatus.TERMINATED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_with_database_manager_when_prisma_disconnected(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Test that stop uses database manager when Prisma is not connected."""
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
from backend.executor.utils import stop_graph_execution
|
||||
|
||||
user_id = "test-user"
|
||||
graph_exec_id = "test-exec-456"
|
||||
|
||||
# Mock graph execution in REVIEW status
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_graph_exec.id = graph_exec_id
|
||||
mock_graph_exec.status = ExecutionStatus.REVIEW
|
||||
|
||||
# Mock dependencies
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_queue_client = mocker.AsyncMock()
|
||||
mock_get_queue.return_value = mock_queue_client
|
||||
|
||||
# Prisma is NOT connected
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_prisma.is_connected.return_value = False
|
||||
|
||||
# Mock database manager client
|
||||
mock_get_db_manager = mocker.patch(
|
||||
"backend.executor.utils.get_database_manager_async_client"
|
||||
)
|
||||
mock_db_manager = mocker.AsyncMock()
|
||||
mock_db_manager.get_graph_execution_meta = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_db_manager.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||
return_value=3 # 3 reviews cancelled
|
||||
)
|
||||
mock_db_manager.update_graph_execution_stats = mocker.AsyncMock()
|
||||
mock_get_db_manager.return_value = mock_db_manager
|
||||
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
mock_get_child_executions = mocker.patch(
|
||||
"backend.executor.utils._get_child_executions"
|
||||
)
|
||||
mock_get_child_executions.return_value = [] # No children
|
||||
|
||||
# Call stop_graph_execution with timeout
|
||||
await stop_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
wait_timeout=1.0,
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
# Verify database manager was used for cancel_pending_reviews
|
||||
mock_db_manager.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||
graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Verify execution status was updated via database manager
|
||||
mock_db_manager.update_graph_execution_stats.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_graph_execution_cascades_to_child_with_reviews(
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
"""Test that stopping parent execution cascades to children and cancels their reviews."""
|
||||
from backend.data.execution import ExecutionStatus, GraphExecutionMeta
|
||||
from backend.executor.utils import stop_graph_execution
|
||||
|
||||
user_id = "test-user"
|
||||
parent_exec_id = "parent-exec"
|
||||
child_exec_id = "child-exec"
|
||||
|
||||
# Mock parent execution in RUNNING status
|
||||
mock_parent_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_parent_exec.id = parent_exec_id
|
||||
mock_parent_exec.status = ExecutionStatus.RUNNING
|
||||
|
||||
# Mock child execution in REVIEW status
|
||||
mock_child_exec = mocker.MagicMock(spec=GraphExecutionMeta)
|
||||
mock_child_exec.id = child_exec_id
|
||||
mock_child_exec.status = ExecutionStatus.REVIEW
|
||||
|
||||
# Mock dependencies
|
||||
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
|
||||
mock_queue_client = mocker.AsyncMock()
|
||||
mock_get_queue.return_value = mock_queue_client
|
||||
|
||||
mock_prisma = mocker.patch("backend.executor.utils.prisma")
|
||||
mock_prisma.is_connected.return_value = True
|
||||
|
||||
mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db")
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock(
|
||||
return_value=1 # 1 child review cancelled
|
||||
)
|
||||
|
||||
# Mock execution_db to return different status based on which execution is queried
|
||||
mock_execution_db = mocker.patch("backend.executor.utils.execution_db")
|
||||
|
||||
# Track call count to simulate status transition
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def get_exec_meta_side_effect(execution_id, user_id):
|
||||
call_count["count"] += 1
|
||||
if execution_id == parent_exec_id:
|
||||
# After a few calls (child processing happens), transition parent to TERMINATED
|
||||
# This simulates the executor service processing the stop request
|
||||
if call_count["count"] > 3:
|
||||
mock_parent_exec.status = ExecutionStatus.TERMINATED
|
||||
return mock_parent_exec
|
||||
elif execution_id == child_exec_id:
|
||||
return mock_child_exec
|
||||
return None
|
||||
|
||||
mock_execution_db.get_graph_execution_meta = mocker.AsyncMock(
|
||||
side_effect=get_exec_meta_side_effect
|
||||
)
|
||||
mock_execution_db.update_graph_execution_stats = mocker.AsyncMock()
|
||||
|
||||
mock_get_event_bus = mocker.patch(
|
||||
"backend.executor.utils.get_async_execution_event_bus"
|
||||
)
|
||||
mock_event_bus = mocker.MagicMock()
|
||||
mock_event_bus.publish = mocker.AsyncMock()
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
# Mock _get_child_executions to return the child
|
||||
mock_get_child_executions = mocker.patch(
|
||||
"backend.executor.utils._get_child_executions"
|
||||
)
|
||||
|
||||
def get_children_side_effect(parent_id):
|
||||
if parent_id == parent_exec_id:
|
||||
return [mock_child_exec]
|
||||
return []
|
||||
|
||||
mock_get_child_executions.side_effect = get_children_side_effect
|
||||
|
||||
# Call stop_graph_execution on parent with cascade=True
|
||||
await stop_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=parent_exec_id,
|
||||
wait_timeout=1.0,
|
||||
cascade=True,
|
||||
)
|
||||
|
||||
# Verify child reviews were cancelled
|
||||
mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with(
|
||||
child_exec_id, user_id
|
||||
)
|
||||
|
||||
# Verify both parent and child status updates
|
||||
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
-- Remove NodeExecution foreign key from PendingHumanReview
|
||||
-- The nodeExecId column remains as the primary key, but we remove the FK constraint
|
||||
-- to AgentNodeExecution since PendingHumanReview records can persist after node
|
||||
-- execution records are deleted.
|
||||
|
||||
-- Drop foreign key constraint that linked PendingHumanReview.nodeExecId to AgentNodeExecution.id
|
||||
ALTER TABLE "PendingHumanReview" DROP CONSTRAINT IF EXISTS "PendingHumanReview_nodeExecId_fkey";
|
||||
@@ -517,8 +517,6 @@ model AgentNodeExecution {
|
||||
|
||||
stats Json?
|
||||
|
||||
PendingHumanReview PendingHumanReview?
|
||||
|
||||
@@index([agentGraphExecutionId, agentNodeId, executionStatus])
|
||||
@@index([agentNodeId, executionStatus])
|
||||
@@index([addedTime, queuedTime])
|
||||
@@ -567,6 +565,7 @@ enum ReviewStatus {
|
||||
}
|
||||
|
||||
// Pending human reviews for Human-in-the-loop blocks
|
||||
// Also stores auto-approval records with special nodeExecId patterns (e.g., "auto_approve_{graph_exec_id}_{node_id}")
|
||||
model PendingHumanReview {
|
||||
nodeExecId String @id
|
||||
userId String
|
||||
@@ -585,7 +584,6 @@ model PendingHumanReview {
|
||||
reviewedAt DateTime?
|
||||
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
NodeExecution AgentNodeExecution @relation(fields: [nodeExecId], references: [id], onDelete: Cascade)
|
||||
GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([nodeExecId]) // One pending review per node execution
|
||||
|
||||
@@ -29,4 +29,4 @@ NEXT_PUBLIC_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
NEXT_PUBLIC_TURNSTILE=disabled
|
||||
|
||||
# PR previews
|
||||
NEXT_PUBLIC_PREVIEW_STEALING_DEV=
|
||||
NEXT_PUBLIC_PREVIEW_STEALING_DEV=
|
||||
@@ -175,8 +175,6 @@ While server components and actions are cool and cutting-edge, they introduce a
|
||||
|
||||
- Prefer [React Query](https://tanstack.com/query/latest/docs/framework/react/overview) for server state, colocated near consumers (see [state colocation](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster))
|
||||
- Co-locate UI state inside components/hooks; keep global state minimal
|
||||
- Avoid `useMemo` and `useCallback` unless you have a measured performance issue
|
||||
- Do not abuse `useEffect`; prefer state colocation and derive values directly when possible
|
||||
|
||||
### Styling and components
|
||||
|
||||
@@ -551,48 +549,9 @@ Files:
|
||||
Types:
|
||||
|
||||
- Prefer `interface` for object shapes
|
||||
- Component props should be `interface Props { ... }` (not exported)
|
||||
- Only use specific exported names (e.g., `export interface MyComponentProps`) when the interface needs to be used outside the component
|
||||
- Keep type definitions inline with the component - do not create separate `types.ts` files unless types are shared across multiple files
|
||||
- Component props should be `interface Props { ... }`
|
||||
- Use precise types; avoid `any` and unsafe casts
|
||||
|
||||
**Props naming examples:**
|
||||
|
||||
```tsx
|
||||
// ✅ Good - internal props, not exported
|
||||
interface Props {
|
||||
title: string;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function Modal({ title, onClose }: Props) {
|
||||
// ...
|
||||
}
|
||||
|
||||
// ✅ Good - exported when needed externally
|
||||
export interface ModalProps {
|
||||
title: string;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function Modal({ title, onClose }: ModalProps) {
|
||||
// ...
|
||||
}
|
||||
|
||||
// ❌ Bad - unnecessarily specific name for internal use
|
||||
interface ModalComponentProps {
|
||||
title: string;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
// ❌ Bad - separate types.ts file for single component
|
||||
// types.ts
|
||||
export interface ModalProps { ... }
|
||||
|
||||
// Modal.tsx
|
||||
import type { ModalProps } from './types';
|
||||
```
|
||||
|
||||
Parameters:
|
||||
|
||||
- If more than one parameter is needed, pass a single `Args` object for clarity
|
||||
|
||||
@@ -16,12 +16,6 @@ export default defineConfig({
|
||||
client: "react-query",
|
||||
httpClient: "fetch",
|
||||
indexFiles: false,
|
||||
mock: {
|
||||
type: "msw",
|
||||
baseUrl: "http://localhost:3000/api/proxy",
|
||||
generateEachHttpStatus: true,
|
||||
delay: 0,
|
||||
},
|
||||
override: {
|
||||
mutator: {
|
||||
path: "./mutators/custom-mutator.ts",
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
"types": "tsc --noEmit",
|
||||
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
|
||||
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
|
||||
"test:unit": "vitest run",
|
||||
"test:unit:watch": "vitest",
|
||||
"test:no-build": "playwright test",
|
||||
"gentests": "playwright codegen http://localhost:3000",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
@@ -129,9 +127,6 @@
|
||||
"@storybook/nextjs": "9.1.5",
|
||||
"@tanstack/eslint-plugin-query": "5.91.2",
|
||||
"@tanstack/react-query-devtools": "5.90.2",
|
||||
"@testing-library/dom": "10.4.1",
|
||||
"@testing-library/jest-dom": "6.9.1",
|
||||
"@testing-library/react": "16.3.2",
|
||||
"@types/canvas-confetti": "1.9.0",
|
||||
"@types/lodash": "4.17.20",
|
||||
"@types/negotiator": "0.6.4",
|
||||
@@ -140,7 +135,6 @@
|
||||
"@types/react-dom": "18.3.5",
|
||||
"@types/react-modal": "3.16.3",
|
||||
"@types/react-window": "1.8.8",
|
||||
"@vitejs/plugin-react": "5.1.2",
|
||||
"axe-playwright": "2.2.2",
|
||||
"chromatic": "13.3.3",
|
||||
"concurrently": "9.2.1",
|
||||
@@ -148,7 +142,6 @@
|
||||
"eslint": "8.57.1",
|
||||
"eslint-config-next": "15.5.7",
|
||||
"eslint-plugin-storybook": "9.1.5",
|
||||
"happy-dom": "20.3.4",
|
||||
"import-in-the-middle": "2.0.2",
|
||||
"msw": "2.11.6",
|
||||
"msw-storybook-addon": "2.0.6",
|
||||
@@ -160,9 +153,7 @@
|
||||
"require-in-the-middle": "8.0.1",
|
||||
"storybook": "9.1.5",
|
||||
"tailwindcss": "3.4.17",
|
||||
"typescript": "5.9.3",
|
||||
"vite-tsconfig-paths": "6.0.4",
|
||||
"vitest": "4.0.17"
|
||||
"typescript": "5.9.3"
|
||||
},
|
||||
"msw": {
|
||||
"workerDirectory": [
|
||||
|
||||
1121
autogpt_platform/frontend/pnpm-lock.yaml
generated
1121
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -1,58 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useRef } from "react";
|
||||
|
||||
const LOGOUT_REDIRECT_DELAY_MS = 400;
|
||||
|
||||
function wait(ms: number): Promise<void> {
|
||||
return new Promise(function resolveAfterDelay(resolve) {
|
||||
setTimeout(resolve, ms);
|
||||
});
|
||||
}
|
||||
|
||||
export default function LogoutPage() {
|
||||
const { logOut } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
const router = useRouter();
|
||||
const hasStartedRef = useRef(false);
|
||||
|
||||
useEffect(
|
||||
function handleLogoutEffect() {
|
||||
if (hasStartedRef.current) return;
|
||||
hasStartedRef.current = true;
|
||||
|
||||
async function runLogout() {
|
||||
try {
|
||||
await logOut();
|
||||
} catch {
|
||||
toast({
|
||||
title: "Failed to log out. Redirecting to login.",
|
||||
variant: "destructive",
|
||||
});
|
||||
} finally {
|
||||
await wait(LOGOUT_REDIRECT_DELAY_MS);
|
||||
router.replace("/login");
|
||||
}
|
||||
}
|
||||
|
||||
void runLogout();
|
||||
},
|
||||
[logOut, router, toast],
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen items-center justify-center px-4">
|
||||
<div className="flex flex-col items-center justify-center gap-4 py-8">
|
||||
<LoadingSpinner size="large" />
|
||||
<Text variant="body" className="text-center">
|
||||
Logging you out...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -9,7 +9,7 @@ export async function GET(request: Request) {
|
||||
const { searchParams, origin } = new URL(request.url);
|
||||
const code = searchParams.get("code");
|
||||
|
||||
let next = "/";
|
||||
let next = "/marketplace";
|
||||
|
||||
if (code) {
|
||||
const supabase = await getServerSupabase();
|
||||
|
||||
@@ -86,7 +86,6 @@ export function FloatingSafeModeToggle({
|
||||
const {
|
||||
currentHITLSafeMode,
|
||||
showHITLToggle,
|
||||
isHITLStateUndetermined,
|
||||
handleHITLToggle,
|
||||
currentSensitiveActionSafeMode,
|
||||
showSensitiveActionToggle,
|
||||
@@ -99,16 +98,9 @@ export function FloatingSafeModeToggle({
|
||||
return null;
|
||||
}
|
||||
|
||||
const showHITL = showHITLToggle && !isHITLStateUndetermined;
|
||||
const showSensitive = showSensitiveActionToggle;
|
||||
|
||||
if (!showHITL && !showSensitive) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("fixed z-50 flex flex-col gap-2", className)}>
|
||||
{showHITL && (
|
||||
{showHITLToggle && (
|
||||
<SafeModeButton
|
||||
isEnabled={currentHITLSafeMode}
|
||||
label="Human in the loop block approval"
|
||||
@@ -119,7 +111,7 @@ export function FloatingSafeModeToggle({
|
||||
fullWidth={fullWidth}
|
||||
/>
|
||||
)}
|
||||
{showSensitive && (
|
||||
{showSensitiveActionToggle && (
|
||||
<SafeModeButton
|
||||
isEnabled={currentSensitiveActionSafeMode}
|
||||
label="Sensitive actions blocks approval"
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { List } from "@phosphor-icons/react";
|
||||
import React, { useState } from "react";
|
||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
|
||||
import { ChatLoadingState } from "./components/ChatLoadingState/ChatLoadingState";
|
||||
import { SessionsDrawer } from "./components/SessionsDrawer/SessionsDrawer";
|
||||
import { useChat } from "./useChat";
|
||||
|
||||
export interface ChatProps {
|
||||
className?: string;
|
||||
headerTitle?: React.ReactNode;
|
||||
showHeader?: boolean;
|
||||
showSessionInfo?: boolean;
|
||||
showNewChatButton?: boolean;
|
||||
onNewChat?: () => void;
|
||||
headerActions?: React.ReactNode;
|
||||
}
|
||||
|
||||
export function Chat({
|
||||
className,
|
||||
headerTitle = "AutoGPT Copilot",
|
||||
showHeader = true,
|
||||
showSessionInfo = true,
|
||||
showNewChatButton = true,
|
||||
onNewChat,
|
||||
headerActions,
|
||||
}: ChatProps) {
|
||||
const {
|
||||
messages,
|
||||
isLoading,
|
||||
isCreating,
|
||||
error,
|
||||
sessionId,
|
||||
createSession,
|
||||
clearSession,
|
||||
loadSession,
|
||||
} = useChat();
|
||||
|
||||
const [isSessionsDrawerOpen, setIsSessionsDrawerOpen] = useState(false);
|
||||
|
||||
const handleNewChat = () => {
|
||||
clearSession();
|
||||
onNewChat?.();
|
||||
};
|
||||
|
||||
const handleSelectSession = async (sessionId: string) => {
|
||||
try {
|
||||
await loadSession(sessionId);
|
||||
} catch (err) {
|
||||
console.error("Failed to load session:", err);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
{/* Header */}
|
||||
{showHeader && (
|
||||
<header className="shrink-0 border-t border-zinc-200 bg-white p-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
aria-label="View sessions"
|
||||
onClick={() => setIsSessionsDrawerOpen(true)}
|
||||
className="flex size-8 items-center justify-center rounded hover:bg-zinc-100"
|
||||
>
|
||||
<List width="1.25rem" height="1.25rem" />
|
||||
</button>
|
||||
{typeof headerTitle === "string" ? (
|
||||
<Text variant="h2" className="text-lg font-semibold">
|
||||
{headerTitle}
|
||||
</Text>
|
||||
) : (
|
||||
headerTitle
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
{showSessionInfo && sessionId && (
|
||||
<>
|
||||
{showNewChatButton && (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={handleNewChat}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{headerActions}
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
)}
|
||||
|
||||
{/* Main Content */}
|
||||
<main className="flex min-h-0 flex-1 flex-col overflow-hidden">
|
||||
{/* Loading State - show when explicitly loading/creating OR when we don't have a session yet and no error */}
|
||||
{(isLoading || isCreating || (!sessionId && !error)) && (
|
||||
<ChatLoadingState
|
||||
message={isCreating ? "Creating session..." : "Loading..."}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Error State */}
|
||||
{error && !isLoading && (
|
||||
<ChatErrorState error={error} onRetry={createSession} />
|
||||
)}
|
||||
|
||||
{/* Session Content */}
|
||||
{sessionId && !isLoading && !error && (
|
||||
<ChatContainer
|
||||
sessionId={sessionId}
|
||||
initialMessages={messages}
|
||||
className="flex-1"
|
||||
/>
|
||||
)}
|
||||
</main>
|
||||
|
||||
{/* Sessions Drawer */}
|
||||
<SessionsDrawer
|
||||
isOpen={isSessionsDrawerOpen}
|
||||
onClose={() => setIsSessionsDrawerOpen(false)}
|
||||
onSelectSession={handleSelectSession}
|
||||
currentSessionId={sessionId}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -21,7 +21,7 @@ export function AuthPromptWidget({
|
||||
message,
|
||||
sessionId,
|
||||
agentInfo,
|
||||
returnUrl = "/copilot/chat",
|
||||
returnUrl = "/chat",
|
||||
className,
|
||||
}: AuthPromptWidgetProps) {
|
||||
const router = useRouter();
|
||||
@@ -0,0 +1,88 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useCallback } from "react";
|
||||
import { usePageContext } from "../../usePageContext";
|
||||
import { ChatInput } from "../ChatInput/ChatInput";
|
||||
import { MessageList } from "../MessageList/MessageList";
|
||||
import { QuickActionsWelcome } from "../QuickActionsWelcome/QuickActionsWelcome";
|
||||
import { useChatContainer } from "./useChatContainer";
|
||||
|
||||
export interface ChatContainerProps {
|
||||
sessionId: string | null;
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
className,
|
||||
}: ChatContainerProps) {
|
||||
const { messages, streamingChunks, isStreaming, sendMessage } =
|
||||
useChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
});
|
||||
const { capturePageContext } = usePageContext();
|
||||
|
||||
// Wrap sendMessage to automatically capture page context
|
||||
const sendMessageWithContext = useCallback(
|
||||
async (content: string, isUserMessage: boolean = true) => {
|
||||
const context = capturePageContext();
|
||||
await sendMessage(content, isUserMessage, context);
|
||||
},
|
||||
[sendMessage, capturePageContext],
|
||||
);
|
||||
|
||||
const quickActions = [
|
||||
"Find agents for social media management",
|
||||
"Show me agents for content creation",
|
||||
"Help me automate my business",
|
||||
"What can you help me with?",
|
||||
];
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn("flex h-full min-h-0 flex-col", className)}
|
||||
style={{
|
||||
backgroundColor: "#ffffff",
|
||||
backgroundImage:
|
||||
"radial-gradient(#e5e5e5 0.5px, transparent 0.5px), radial-gradient(#e5e5e5 0.5px, #ffffff 0.5px)",
|
||||
backgroundSize: "20px 20px",
|
||||
backgroundPosition: "0 0, 10px 10px",
|
||||
}}
|
||||
>
|
||||
{/* Messages or Welcome Screen */}
|
||||
<div className="flex min-h-0 flex-1 flex-col overflow-hidden pb-24">
|
||||
{messages.length === 0 ? (
|
||||
<QuickActionsWelcome
|
||||
title="Welcome to AutoGPT Copilot"
|
||||
description="Start a conversation to discover and run AI agents."
|
||||
actions={quickActions}
|
||||
onActionClick={sendMessageWithContext}
|
||||
disabled={isStreaming || !sessionId}
|
||||
/>
|
||||
) : (
|
||||
<MessageList
|
||||
messages={messages}
|
||||
streamingChunks={streamingChunks}
|
||||
isStreaming={isStreaming}
|
||||
onSendMessage={sendMessageWithContext}
|
||||
className="flex-1"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Input - Always visible */}
|
||||
<div className="fixed bottom-0 left-0 right-0 z-50 border-t border-zinc-200 bg-white p-4">
|
||||
<ChatInput
|
||||
onSend={sendMessageWithContext}
|
||||
disabled={isStreaming || !sessionId}
|
||||
placeholder={
|
||||
sessionId ? "Type your message..." : "Creating session..."
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import { toast } from "sonner";
|
||||
import { StreamChunk } from "../../useChatStream";
|
||||
import type { HandlerDependencies } from "./handlers";
|
||||
import type { HandlerDependencies } from "./useChatContainer.handlers";
|
||||
import {
|
||||
handleError,
|
||||
handleLoginNeeded,
|
||||
@@ -9,30 +9,12 @@ import {
|
||||
handleTextEnded,
|
||||
handleToolCallStart,
|
||||
handleToolResponse,
|
||||
isRegionBlockedError,
|
||||
} from "./handlers";
|
||||
} from "./useChatContainer.handlers";
|
||||
|
||||
export function createStreamEventDispatcher(
|
||||
deps: HandlerDependencies,
|
||||
): (chunk: StreamChunk) => void {
|
||||
return function dispatchStreamEvent(chunk: StreamChunk): void {
|
||||
if (
|
||||
chunk.type === "text_chunk" ||
|
||||
chunk.type === "tool_call_start" ||
|
||||
chunk.type === "tool_response" ||
|
||||
chunk.type === "login_needed" ||
|
||||
chunk.type === "need_login" ||
|
||||
chunk.type === "error"
|
||||
) {
|
||||
if (!deps.hasResponseRef.current) {
|
||||
console.info("[ChatStream] First response chunk:", {
|
||||
type: chunk.type,
|
||||
sessionId: deps.sessionId,
|
||||
});
|
||||
}
|
||||
deps.hasResponseRef.current = true;
|
||||
}
|
||||
|
||||
switch (chunk.type) {
|
||||
case "text_chunk":
|
||||
handleTextChunk(chunk, deps);
|
||||
@@ -56,23 +38,15 @@ export function createStreamEventDispatcher(
|
||||
break;
|
||||
|
||||
case "stream_end":
|
||||
console.info("[ChatStream] Stream ended:", {
|
||||
sessionId: deps.sessionId,
|
||||
hasResponse: deps.hasResponseRef.current,
|
||||
chunkCount: deps.streamingChunksRef.current.length,
|
||||
});
|
||||
handleStreamEnd(chunk, deps);
|
||||
break;
|
||||
|
||||
case "error":
|
||||
const isRegionBlocked = isRegionBlockedError(chunk);
|
||||
handleError(chunk, deps);
|
||||
// Show toast at dispatcher level to avoid circular dependencies
|
||||
if (!isRegionBlocked) {
|
||||
toast.error("Chat Error", {
|
||||
description: chunk.message || chunk.content || "An error occurred",
|
||||
});
|
||||
}
|
||||
toast.error("Chat Error", {
|
||||
description: chunk.message || chunk.content || "An error occurred",
|
||||
});
|
||||
break;
|
||||
|
||||
case "usage":
|
||||
@@ -1,33 +1,6 @@
|
||||
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
||||
import type { ToolResult } from "@/types/chat";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
|
||||
export function hasSentInitialPrompt(sessionId: string): boolean {
|
||||
try {
|
||||
const sent = JSON.parse(
|
||||
sessionStorage.get(SessionKey.CHAT_SENT_INITIAL_PROMPTS) || "{}",
|
||||
);
|
||||
return sent[sessionId] === true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export function markInitialPromptSent(sessionId: string): void {
|
||||
try {
|
||||
const sent = JSON.parse(
|
||||
sessionStorage.get(SessionKey.CHAT_SENT_INITIAL_PROMPTS) || "{}",
|
||||
);
|
||||
sent[sessionId] = true;
|
||||
sessionStorage.set(
|
||||
SessionKey.CHAT_SENT_INITIAL_PROMPTS,
|
||||
JSON.stringify(sent),
|
||||
);
|
||||
} catch {
|
||||
// Ignore storage errors
|
||||
}
|
||||
}
|
||||
|
||||
export function removePageContext(content: string): string {
|
||||
// Remove "Page URL: ..." pattern at start of line (case insensitive, handles various formats)
|
||||
let cleaned = content.replace(/^\s*Page URL:\s*[^\n\r]*/gim, "");
|
||||
@@ -234,22 +207,12 @@ export function parseToolResponse(
|
||||
if (responseType === "setup_requirements") {
|
||||
return null;
|
||||
}
|
||||
if (responseType === "understanding_updated") {
|
||||
return {
|
||||
type: "tool_response",
|
||||
toolId,
|
||||
toolName,
|
||||
result: (parsedResult || result) as ToolResult,
|
||||
success: true,
|
||||
timestamp: timestamp || new Date(),
|
||||
};
|
||||
}
|
||||
}
|
||||
return {
|
||||
type: "tool_response",
|
||||
toolId,
|
||||
toolName,
|
||||
result: parsedResult ? (parsedResult as ToolResult) : result,
|
||||
result,
|
||||
success: true,
|
||||
timestamp: timestamp || new Date(),
|
||||
};
|
||||
@@ -7,30 +7,15 @@ import {
|
||||
parseToolResponse,
|
||||
} from "./helpers";
|
||||
|
||||
function isToolCallMessage(
|
||||
message: ChatMessageData,
|
||||
): message is Extract<ChatMessageData, { type: "tool_call" }> {
|
||||
return message.type === "tool_call";
|
||||
}
|
||||
|
||||
export interface HandlerDependencies {
|
||||
setHasTextChunks: Dispatch<SetStateAction<boolean>>;
|
||||
setStreamingChunks: Dispatch<SetStateAction<string[]>>;
|
||||
streamingChunksRef: MutableRefObject<string[]>;
|
||||
hasResponseRef: MutableRefObject<boolean>;
|
||||
setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
|
||||
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
||||
sessionId: string;
|
||||
}
|
||||
|
||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
||||
if (chunk.code === "MODEL_NOT_AVAILABLE_REGION") return true;
|
||||
const message = chunk.message || chunk.content;
|
||||
if (typeof message !== "string") return false;
|
||||
return message.toLowerCase().includes("not available in your region");
|
||||
}
|
||||
|
||||
export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||
if (!chunk.content) return;
|
||||
deps.setHasTextChunks(true);
|
||||
@@ -45,17 +30,16 @@ export function handleTextEnded(
|
||||
_chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
console.log("[Text Ended] Saving streamed text as assistant message");
|
||||
const completedText = deps.streamingChunksRef.current.join("");
|
||||
if (completedText.trim()) {
|
||||
deps.setMessages((prev) => {
|
||||
const assistantMessage: ChatMessageData = {
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: completedText,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
return [...prev, assistantMessage];
|
||||
});
|
||||
const assistantMessage: ChatMessageData = {
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: completedText,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
deps.setMessages((prev) => [...prev, assistantMessage]);
|
||||
}
|
||||
deps.setStreamingChunks([]);
|
||||
deps.streamingChunksRef.current = [];
|
||||
@@ -66,45 +50,30 @@ export function handleToolCallStart(
|
||||
chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
|
||||
const toolCallMessage: ChatMessageData = {
|
||||
type: "tool_call",
|
||||
toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
|
||||
toolName: chunk.tool_name || "Executing",
|
||||
toolName: chunk.tool_name || "Executing...",
|
||||
arguments: chunk.arguments || {},
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
function updateToolCallMessages(prev: ChatMessageData[]) {
|
||||
const existingIndex = prev.findIndex(function findToolCallIndex(msg) {
|
||||
return isToolCallMessage(msg) && msg.toolId === toolCallMessage.toolId;
|
||||
});
|
||||
if (existingIndex === -1) {
|
||||
return [...prev, toolCallMessage];
|
||||
}
|
||||
const nextMessages = [...prev];
|
||||
const existing = nextMessages[existingIndex];
|
||||
if (!isToolCallMessage(existing)) return prev;
|
||||
const nextArguments =
|
||||
toolCallMessage.arguments &&
|
||||
Object.keys(toolCallMessage.arguments).length > 0
|
||||
? toolCallMessage.arguments
|
||||
: existing.arguments;
|
||||
nextMessages[existingIndex] = {
|
||||
...existing,
|
||||
toolName: toolCallMessage.toolName || existing.toolName,
|
||||
arguments: nextArguments,
|
||||
timestamp: toolCallMessage.timestamp,
|
||||
};
|
||||
return nextMessages;
|
||||
}
|
||||
|
||||
deps.setMessages(updateToolCallMessages);
|
||||
deps.setMessages((prev) => [...prev, toolCallMessage]);
|
||||
console.log("[Tool Call Start]", {
|
||||
toolId: toolCallMessage.toolId,
|
||||
toolName: toolCallMessage.toolName,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
}
|
||||
|
||||
export function handleToolResponse(
|
||||
chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
console.log("[Tool Response] Received:", {
|
||||
toolId: chunk.tool_id,
|
||||
toolName: chunk.tool_name,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
let toolName = chunk.tool_name || "unknown";
|
||||
if (!chunk.tool_name || chunk.tool_name === "unknown") {
|
||||
deps.setMessages((prev) => {
|
||||
@@ -158,15 +127,22 @@ export function handleToolResponse(
|
||||
const toolCallIndex = prev.findIndex(
|
||||
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
||||
);
|
||||
const hasResponse = prev.some(
|
||||
(msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id,
|
||||
);
|
||||
if (hasResponse) return prev;
|
||||
if (toolCallIndex !== -1) {
|
||||
const newMessages = [...prev];
|
||||
newMessages.splice(toolCallIndex + 1, 0, responseMessage);
|
||||
newMessages[toolCallIndex] = responseMessage;
|
||||
console.log(
|
||||
"[Tool Response] Replaced tool_call with matching tool_id:",
|
||||
chunk.tool_id,
|
||||
"at index:",
|
||||
toolCallIndex,
|
||||
);
|
||||
return newMessages;
|
||||
}
|
||||
console.warn(
|
||||
"[Tool Response] No tool_call found with tool_id:",
|
||||
chunk.tool_id,
|
||||
"appending instead",
|
||||
);
|
||||
return [...prev, responseMessage];
|
||||
});
|
||||
}
|
||||
@@ -191,38 +167,55 @@ export function handleStreamEnd(
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
const completedContent = deps.streamingChunksRef.current.join("");
|
||||
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
||||
deps.setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: "No response received. Please try again.",
|
||||
timestamp: new Date(),
|
||||
},
|
||||
]);
|
||||
}
|
||||
// Only save message if there are uncommitted chunks
|
||||
// (text_ended already saved if there were tool calls)
|
||||
if (completedContent.trim()) {
|
||||
console.log(
|
||||
"[Stream End] Saving remaining streamed text as assistant message",
|
||||
);
|
||||
const assistantMessage: ChatMessageData = {
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: completedContent,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
deps.setMessages((prev) => [...prev, assistantMessage]);
|
||||
deps.setMessages((prev) => {
|
||||
const updated = [...prev, assistantMessage];
|
||||
console.log("[Stream End] Final state:", {
|
||||
localMessages: updated.map((m) => ({
|
||||
type: m.type,
|
||||
...(m.type === "message" && {
|
||||
role: m.role,
|
||||
contentLength: m.content.length,
|
||||
}),
|
||||
...(m.type === "tool_call" && {
|
||||
toolId: m.toolId,
|
||||
toolName: m.toolName,
|
||||
}),
|
||||
...(m.type === "tool_response" && {
|
||||
toolId: m.toolId,
|
||||
toolName: m.toolName,
|
||||
success: m.success,
|
||||
}),
|
||||
})),
|
||||
streamingChunks: deps.streamingChunksRef.current,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
return updated;
|
||||
});
|
||||
} else {
|
||||
console.log("[Stream End] No uncommitted chunks, message already saved");
|
||||
}
|
||||
deps.setStreamingChunks([]);
|
||||
deps.streamingChunksRef.current = [];
|
||||
deps.setHasTextChunks(false);
|
||||
deps.setIsStreamingInitiated(false);
|
||||
console.log("[Stream End] Stream complete, messages in local state");
|
||||
}
|
||||
|
||||
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||
const errorMessage = chunk.message || chunk.content || "An error occurred";
|
||||
console.error("Stream error:", errorMessage);
|
||||
if (isRegionBlockedError(chunk)) {
|
||||
deps.setIsRegionBlockedModalOpen(true);
|
||||
}
|
||||
deps.setIsStreamingInitiated(false);
|
||||
deps.setHasTextChunks(false);
|
||||
deps.setStreamingChunks([]);
|
||||
@@ -1,17 +1,14 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useCallback, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { useChatStream } from "../../useChatStream";
|
||||
import { usePageContext } from "../../usePageContext";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
||||
import {
|
||||
createUserMessage,
|
||||
filterAuthMessages,
|
||||
hasSentInitialPrompt,
|
||||
isToolCallArray,
|
||||
isValidMessage,
|
||||
markInitialPromptSent,
|
||||
parseToolResponse,
|
||||
removePageContext,
|
||||
} from "./helpers";
|
||||
@@ -19,45 +16,20 @@ import {
|
||||
interface Args {
|
||||
sessionId: string | null;
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
initialPrompt?: string;
|
||||
}
|
||||
|
||||
export function useChatContainer({
|
||||
sessionId,
|
||||
initialMessages,
|
||||
initialPrompt,
|
||||
}: Args) {
|
||||
export function useChatContainer({ sessionId, initialMessages }: Args) {
|
||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||
const [hasTextChunks, setHasTextChunks] = useState(false);
|
||||
const [isStreamingInitiated, setIsStreamingInitiated] = useState(false);
|
||||
const [isRegionBlockedModalOpen, setIsRegionBlockedModalOpen] =
|
||||
useState(false);
|
||||
const hasResponseRef = useRef(false);
|
||||
const streamingChunksRef = useRef<string[]>([]);
|
||||
const previousSessionIdRef = useRef<string | null>(null);
|
||||
const {
|
||||
error,
|
||||
sendMessage: sendStreamMessage,
|
||||
stopStreaming,
|
||||
} = useChatStream();
|
||||
const { error, sendMessage: sendStreamMessage } = useChatStream();
|
||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||
|
||||
useEffect(() => {
|
||||
if (sessionId !== previousSessionIdRef.current) {
|
||||
stopStreaming(previousSessionIdRef.current ?? undefined, true);
|
||||
previousSessionIdRef.current = sessionId;
|
||||
setMessages([]);
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
hasResponseRef.current = false;
|
||||
}
|
||||
}, [sessionId, stopStreaming]);
|
||||
|
||||
const allMessages = useMemo(() => {
|
||||
const processedInitialMessages: ChatMessageData[] = [];
|
||||
// Map to track tool calls by their ID so we can look up tool names for tool responses
|
||||
const toolCallMap = new Map<string, string>();
|
||||
|
||||
for (const msg of initialMessages) {
|
||||
@@ -73,9 +45,13 @@ export function useChatContainer({
|
||||
? new Date(msg.timestamp as string)
|
||||
: undefined;
|
||||
|
||||
// Remove page context from user messages when loading existing sessions
|
||||
if (role === "user") {
|
||||
content = removePageContext(content);
|
||||
if (!content.trim()) continue;
|
||||
// Skip user messages that become empty after removing page context
|
||||
if (!content.trim()) {
|
||||
continue;
|
||||
}
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
role: "user",
|
||||
@@ -85,15 +61,19 @@ export function useChatContainer({
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle assistant messages first (before tool messages) to build tool call map
|
||||
if (role === "assistant") {
|
||||
// Strip <thinking> tags from content
|
||||
content = content
|
||||
.replace(/<thinking>[\s\S]*?<\/thinking>/gi, "")
|
||||
.trim();
|
||||
|
||||
// If assistant has tool calls, create tool_call messages for each
|
||||
if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) {
|
||||
for (const toolCall of toolCalls) {
|
||||
const toolName = toolCall.function.name;
|
||||
const toolId = toolCall.id;
|
||||
// Store tool name for later lookup
|
||||
toolCallMap.set(toolId, toolName);
|
||||
|
||||
try {
|
||||
@@ -116,6 +96,7 @@ export function useChatContainer({
|
||||
});
|
||||
}
|
||||
}
|
||||
// Only add assistant message if there's content after stripping thinking tags
|
||||
if (content.trim()) {
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
@@ -125,6 +106,7 @@ export function useChatContainer({
|
||||
});
|
||||
}
|
||||
} else if (content.trim()) {
|
||||
// Assistant message without tool calls, but with content
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
@@ -135,6 +117,7 @@ export function useChatContainer({
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle tool messages - look up tool name from tool call map
|
||||
if (role === "tool") {
|
||||
const toolCallId = (msg.tool_call_id as string) || "";
|
||||
const toolName = toolCallMap.get(toolCallId) || "unknown";
|
||||
@@ -150,6 +133,7 @@ export function useChatContainer({
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle other message types (system, etc.)
|
||||
if (content.trim()) {
|
||||
processedInitialMessages.push({
|
||||
type: "message",
|
||||
@@ -170,10 +154,9 @@ export function useChatContainer({
|
||||
context?: { url: string; content: string },
|
||||
) {
|
||||
if (!sessionId) {
|
||||
console.error("[useChatContainer] Cannot send message: no session ID");
|
||||
console.error("Cannot send message: no session ID");
|
||||
return;
|
||||
}
|
||||
setIsRegionBlockedModalOpen(false);
|
||||
if (isUserMessage) {
|
||||
const userMessage = createUserMessage(content);
|
||||
setMessages((prev) => [...filterAuthMessages(prev), userMessage]);
|
||||
@@ -184,19 +167,14 @@ export function useChatContainer({
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(true);
|
||||
hasResponseRef.current = false;
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
hasResponseRef,
|
||||
setMessages,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
});
|
||||
|
||||
try {
|
||||
await sendStreamMessage(
|
||||
sessionId,
|
||||
@@ -206,12 +184,8 @@ export function useChatContainer({
|
||||
context,
|
||||
);
|
||||
} catch (err) {
|
||||
console.error("[useChatContainer] Failed to send message:", err);
|
||||
console.error("Failed to send message:", err);
|
||||
setIsStreamingInitiated(false);
|
||||
|
||||
// Don't show error toast for AbortError (expected during cleanup)
|
||||
if (err instanceof Error && err.name === "AbortError") return;
|
||||
|
||||
const errorMessage =
|
||||
err instanceof Error ? err.message : "Failed to send message";
|
||||
toast.error("Failed to send message", {
|
||||
@@ -222,63 +196,11 @@ export function useChatContainer({
|
||||
[sessionId, sendStreamMessage],
|
||||
);
|
||||
|
||||
const handleStopStreaming = useCallback(() => {
|
||||
stopStreaming();
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
}, [stopStreaming]);
|
||||
|
||||
const { capturePageContext } = usePageContext();
|
||||
|
||||
// Send initial prompt if provided (for new sessions from homepage)
|
||||
useEffect(
|
||||
function handleInitialPrompt() {
|
||||
if (!initialPrompt || !sessionId) return;
|
||||
if (initialMessages.length > 0) return;
|
||||
if (hasSentInitialPrompt(sessionId)) return;
|
||||
|
||||
markInitialPromptSent(sessionId);
|
||||
const context = capturePageContext();
|
||||
sendMessage(initialPrompt, true, context);
|
||||
},
|
||||
[
|
||||
initialPrompt,
|
||||
sessionId,
|
||||
initialMessages.length,
|
||||
sendMessage,
|
||||
capturePageContext,
|
||||
],
|
||||
);
|
||||
|
||||
async function sendMessageWithContext(
|
||||
content: string,
|
||||
isUserMessage: boolean = true,
|
||||
) {
|
||||
const context = capturePageContext();
|
||||
await sendMessage(content, isUserMessage, context);
|
||||
}
|
||||
|
||||
function handleRegionModalOpenChange(open: boolean) {
|
||||
setIsRegionBlockedModalOpen(open);
|
||||
}
|
||||
|
||||
function handleRegionModalClose() {
|
||||
setIsRegionBlockedModalOpen(false);
|
||||
}
|
||||
|
||||
return {
|
||||
messages: allMessages,
|
||||
streamingChunks,
|
||||
isStreaming,
|
||||
error,
|
||||
isRegionBlockedModalOpen,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sendMessageWithContext,
|
||||
handleRegionModalOpenChange,
|
||||
handleRegionModalClose,
|
||||
sendMessage,
|
||||
stopStreaming: handleStopStreaming,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ArrowUpIcon } from "@phosphor-icons/react";
|
||||
import { useChatInput } from "./useChatInput";
|
||||
|
||||
export interface ChatInputProps {
|
||||
onSend: (message: string) => void;
|
||||
disabled?: boolean;
|
||||
placeholder?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
placeholder = "Type your message...",
|
||||
className,
|
||||
}: ChatInputProps) {
|
||||
const inputId = "chat-input";
|
||||
const { value, setValue, handleKeyDown, handleSend } = useChatInput({
|
||||
onSend,
|
||||
disabled,
|
||||
maxRows: 5,
|
||||
inputId,
|
||||
});
|
||||
|
||||
return (
|
||||
<div className={cn("relative flex-1", className)}>
|
||||
<Input
|
||||
id={inputId}
|
||||
label="Chat message input"
|
||||
hideLabel
|
||||
type="textarea"
|
||||
value={value}
|
||||
onChange={(e) => setValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={placeholder}
|
||||
disabled={disabled}
|
||||
rows={1}
|
||||
wrapperClassName="mb-0 relative"
|
||||
className="pr-12"
|
||||
/>
|
||||
<span id="chat-input-hint" className="sr-only">
|
||||
Press Enter to send, Shift+Enter for new line
|
||||
</span>
|
||||
|
||||
<button
|
||||
onClick={handleSend}
|
||||
disabled={disabled || !value.trim()}
|
||||
className={cn(
|
||||
"absolute right-3 top-1/2 flex h-8 w-8 -translate-y-1/2 items-center justify-center rounded-full",
|
||||
"border border-zinc-800 bg-zinc-800 text-white",
|
||||
"hover:border-zinc-900 hover:bg-zinc-900",
|
||||
"disabled:border-zinc-200 disabled:bg-zinc-200 disabled:text-white disabled:opacity-50",
|
||||
"transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-neutral-950",
|
||||
"disabled:pointer-events-none",
|
||||
)}
|
||||
aria-label="Send message"
|
||||
>
|
||||
<ArrowUpIcon className="h-3 w-3" weight="bold" />
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
import { KeyboardEvent, useCallback, useEffect, useState } from "react";
|
||||
|
||||
interface UseChatInputArgs {
|
||||
onSend: (message: string) => void;
|
||||
disabled?: boolean;
|
||||
maxRows?: number;
|
||||
inputId?: string;
|
||||
}
|
||||
|
||||
export function useChatInput({
|
||||
onSend,
|
||||
disabled = false,
|
||||
maxRows = 5,
|
||||
inputId = "chat-input",
|
||||
}: UseChatInputArgs) {
|
||||
const [value, setValue] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
||||
if (!textarea) return;
|
||||
textarea.style.height = "auto";
|
||||
const lineHeight = parseInt(
|
||||
window.getComputedStyle(textarea).lineHeight,
|
||||
10,
|
||||
);
|
||||
const maxHeight = lineHeight * maxRows;
|
||||
const newHeight = Math.min(textarea.scrollHeight, maxHeight);
|
||||
textarea.style.height = `${newHeight}px`;
|
||||
textarea.style.overflowY =
|
||||
textarea.scrollHeight > maxHeight ? "auto" : "hidden";
|
||||
}, [value, maxRows, inputId]);
|
||||
|
||||
const handleSend = useCallback(() => {
|
||||
if (disabled || !value.trim()) return;
|
||||
onSend(value.trim());
|
||||
setValue("");
|
||||
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
||||
if (textarea) {
|
||||
textarea.style.height = "auto";
|
||||
}
|
||||
}, [value, onSend, disabled, inputId]);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(event: KeyboardEvent<HTMLInputElement | HTMLTextAreaElement>) => {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
handleSend();
|
||||
}
|
||||
// Shift+Enter allows default behavior (new line) - no need to handle explicitly
|
||||
},
|
||||
[handleSend],
|
||||
);
|
||||
|
||||
return {
|
||||
value,
|
||||
setValue,
|
||||
handleKeyDown,
|
||||
handleSend,
|
||||
};
|
||||
}
|
||||
@@ -1,65 +1,48 @@
|
||||
"use client";
|
||||
|
||||
import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/store";
|
||||
import Avatar, {
|
||||
AvatarFallback,
|
||||
AvatarImage,
|
||||
} from "@/components/atoms/Avatar/Avatar";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
ArrowsClockwiseIcon,
|
||||
ArrowClockwise,
|
||||
CheckCircleIcon,
|
||||
CheckIcon,
|
||||
CopyIcon,
|
||||
RobotIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useCallback, useState } from "react";
|
||||
import { getToolActionPhrase } from "../../helpers";
|
||||
import { AgentCarouselMessage } from "../AgentCarouselMessage/AgentCarouselMessage";
|
||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||
import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
|
||||
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
|
||||
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
|
||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||
import { MessageBubble } from "../MessageBubble/MessageBubble";
|
||||
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
|
||||
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
|
||||
import { ToolResponseMessage } from "../ToolResponseMessage/ToolResponseMessage";
|
||||
import { UserChatBubble } from "../UserChatBubble/UserChatBubble";
|
||||
import { useChatMessage, type ChatMessageData } from "./useChatMessage";
|
||||
|
||||
function stripInternalReasoning(content: string): string {
|
||||
const cleaned = content.replace(
|
||||
/<internal_reasoning>[\s\S]*?<\/internal_reasoning>/gi,
|
||||
"",
|
||||
);
|
||||
return cleaned.replace(/\n{3,}/g, "\n\n").trim();
|
||||
}
|
||||
|
||||
function getDisplayContent(message: ChatMessageData, isUser: boolean): string {
|
||||
if (message.type !== "message") return "";
|
||||
if (isUser) return message.content;
|
||||
return stripInternalReasoning(message.content);
|
||||
}
|
||||
|
||||
export interface ChatMessageProps {
|
||||
message: ChatMessageData;
|
||||
messages?: ChatMessageData[];
|
||||
index?: number;
|
||||
isStreaming?: boolean;
|
||||
className?: string;
|
||||
onDismissLogin?: () => void;
|
||||
onDismissCredentials?: () => void;
|
||||
onSendMessage?: (content: string, isUserMessage?: boolean) => void;
|
||||
agentOutput?: ChatMessageData;
|
||||
isFinalMessage?: boolean;
|
||||
}
|
||||
|
||||
export function ChatMessage({
|
||||
message,
|
||||
messages = [],
|
||||
index = -1,
|
||||
isStreaming = false,
|
||||
className,
|
||||
onDismissCredentials,
|
||||
onSendMessage,
|
||||
agentOutput,
|
||||
isFinalMessage = true,
|
||||
}: ChatMessageProps) {
|
||||
const { user } = useSupabase();
|
||||
const router = useRouter();
|
||||
@@ -71,7 +54,14 @@ export function ChatMessage({
|
||||
isLoginNeeded,
|
||||
isCredentialsNeeded,
|
||||
} = useChatMessage(message);
|
||||
const displayContent = getDisplayContent(message, isUser);
|
||||
|
||||
const { data: profile } = useGetV2GetUserProfile({
|
||||
query: {
|
||||
select: (res) => (res.status === 200 ? res.data : null),
|
||||
enabled: isUser && !!user,
|
||||
queryKey: ["/api/store/profile", user?.id],
|
||||
},
|
||||
});
|
||||
|
||||
const handleAllCredentialsComplete = useCallback(
|
||||
function handleAllCredentialsComplete() {
|
||||
@@ -97,25 +87,17 @@ export function ChatMessage({
|
||||
}
|
||||
}
|
||||
|
||||
const handleCopy = useCallback(
|
||||
async function handleCopy() {
|
||||
if (message.type !== "message") return;
|
||||
if (!displayContent) return;
|
||||
const handleCopy = useCallback(async () => {
|
||||
if (message.type !== "message") return;
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(displayContent);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
} catch (error) {
|
||||
console.error("Failed to copy:", error);
|
||||
}
|
||||
},
|
||||
[displayContent, message],
|
||||
);
|
||||
|
||||
function isLongResponse(content: string): boolean {
|
||||
return content.split("\n").length > 5;
|
||||
}
|
||||
try {
|
||||
await navigator.clipboard.writeText(message.content);
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
} catch (error) {
|
||||
console.error("Failed to copy:", error);
|
||||
}
|
||||
}, [message]);
|
||||
|
||||
const handleTryAgain = useCallback(() => {
|
||||
if (message.type !== "message" || !onSendMessage) return;
|
||||
@@ -187,45 +169,9 @@ export function ChatMessage({
|
||||
|
||||
// Render tool call messages
|
||||
if (isToolCall && message.type === "tool_call") {
|
||||
// Check if this tool call is currently streaming
|
||||
// A tool call is streaming if:
|
||||
// 1. isStreaming is true
|
||||
// 2. This is the last tool_call message
|
||||
// 3. There's no tool_response for this tool call yet
|
||||
const isToolCallStreaming =
|
||||
isStreaming &&
|
||||
index >= 0 &&
|
||||
(() => {
|
||||
// Find the last tool_call index
|
||||
let lastToolCallIndex = -1;
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].type === "tool_call") {
|
||||
lastToolCallIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Check if this is the last tool_call and there's no response yet
|
||||
if (index === lastToolCallIndex) {
|
||||
// Check if there's a tool_response for this tool call
|
||||
const hasResponse = messages
|
||||
.slice(index + 1)
|
||||
.some(
|
||||
(msg) =>
|
||||
msg.type === "tool_response" && msg.toolId === message.toolId,
|
||||
);
|
||||
return !hasResponse;
|
||||
}
|
||||
return false;
|
||||
})();
|
||||
|
||||
return (
|
||||
<div className={cn("px-4 py-2", className)}>
|
||||
<ToolCallMessage
|
||||
toolId={message.toolId}
|
||||
toolName={message.toolName}
|
||||
arguments={message.arguments}
|
||||
isStreaming={isToolCallStreaming}
|
||||
/>
|
||||
<ToolCallMessage toolName={message.toolName} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -272,11 +218,27 @@ export function ChatMessage({
|
||||
|
||||
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
|
||||
if (isToolResponse && message.type === "tool_response") {
|
||||
// Check if this is an agent_output that should be rendered inside assistant message
|
||||
if (message.result) {
|
||||
let parsedResult: Record<string, unknown> | null = null;
|
||||
try {
|
||||
parsedResult =
|
||||
typeof message.result === "string"
|
||||
? JSON.parse(message.result)
|
||||
: (message.result as Record<string, unknown>);
|
||||
} catch {
|
||||
parsedResult = null;
|
||||
}
|
||||
if (parsedResult?.type === "agent_output") {
|
||||
// Skip rendering - this will be rendered inside the assistant message
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("px-4 py-2", className)}>
|
||||
<ToolResponseMessage
|
||||
toolId={message.toolId}
|
||||
toolName={message.toolName}
|
||||
toolName={getToolActionPhrase(message.toolName)}
|
||||
result={message.result}
|
||||
/>
|
||||
</div>
|
||||
@@ -294,33 +256,40 @@ export function ChatMessage({
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full max-w-3xl gap-3">
|
||||
{!isUser && (
|
||||
<div className="flex-shrink-0">
|
||||
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-500">
|
||||
<RobotIcon className="h-4 w-4 text-indigo-50" />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
className={cn(
|
||||
"flex min-w-0 flex-1 flex-col",
|
||||
isUser && "items-end",
|
||||
)}
|
||||
>
|
||||
{isUser ? (
|
||||
<UserChatBubble>
|
||||
<MarkdownContent content={displayContent} />
|
||||
</UserChatBubble>
|
||||
) : (
|
||||
<AIChatBubble>
|
||||
<MarkdownContent content={displayContent} />
|
||||
{agentOutput && agentOutput.type === "tool_response" && (
|
||||
<MessageBubble variant={isUser ? "user" : "assistant"}>
|
||||
<MarkdownContent content={message.content} />
|
||||
{agentOutput &&
|
||||
agentOutput.type === "tool_response" &&
|
||||
!isUser && (
|
||||
<div className="mt-4">
|
||||
<ToolResponseMessage
|
||||
toolId={agentOutput.toolId}
|
||||
toolName={agentOutput.toolName || "Agent Output"}
|
||||
toolName={
|
||||
agentOutput.toolName
|
||||
? getToolActionPhrase(agentOutput.toolName)
|
||||
: "Agent Output"
|
||||
}
|
||||
result={agentOutput.result}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</AIChatBubble>
|
||||
)}
|
||||
</MessageBubble>
|
||||
<div
|
||||
className={cn(
|
||||
"flex gap-0",
|
||||
"mt-1 flex gap-1",
|
||||
isUser ? "justify-end" : "justify-start",
|
||||
)}
|
||||
>
|
||||
@@ -331,25 +300,37 @@ export function ChatMessage({
|
||||
onClick={handleTryAgain}
|
||||
aria-label="Try again"
|
||||
>
|
||||
<ArrowsClockwiseIcon className="size-4 text-zinc-600" />
|
||||
</Button>
|
||||
)}
|
||||
{!isUser && isFinalMessage && isLongResponse(displayContent) && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={handleCopy}
|
||||
aria-label="Copy message"
|
||||
>
|
||||
{copied ? (
|
||||
<CheckIcon className="size-4 text-green-600" />
|
||||
) : (
|
||||
<CopyIcon className="size-4 text-zinc-600" />
|
||||
)}
|
||||
<ArrowClockwise className="size-3 text-neutral-500" />
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={handleCopy}
|
||||
aria-label="Copy message"
|
||||
>
|
||||
{copied ? (
|
||||
<CheckIcon className="size-3 text-green-600" />
|
||||
) : (
|
||||
<CopyIcon className="size-3 text-neutral-500" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isUser && (
|
||||
<div className="flex-shrink-0">
|
||||
<Avatar className="h-7 w-7">
|
||||
<AvatarImage
|
||||
src={profile?.avatar_url ?? ""}
|
||||
alt={profile?.username ?? "User"}
|
||||
/>
|
||||
<AvatarFallback className="rounded-lg bg-neutral-200 text-neutral-600">
|
||||
{profile?.username?.charAt(0)?.toUpperCase() || "U"}
|
||||
</AvatarFallback>
|
||||
</Avatar>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@@ -13,9 +13,10 @@ export function MessageBubble({
|
||||
className,
|
||||
}: MessageBubbleProps) {
|
||||
const userTheme = {
|
||||
bg: "bg-purple-100",
|
||||
border: "border-purple-100",
|
||||
text: "text-slate-900",
|
||||
bg: "bg-slate-900",
|
||||
border: "border-slate-800",
|
||||
gradient: "from-slate-900/30 via-slate-800/20 to-transparent",
|
||||
text: "text-slate-50",
|
||||
};
|
||||
|
||||
const assistantTheme = {
|
||||
@@ -39,7 +40,9 @@ export function MessageBubble({
|
||||
)}
|
||||
>
|
||||
{/* Gradient flare background */}
|
||||
<div className={cn("absolute inset-0 bg-gradient-to-br")} />
|
||||
<div
|
||||
className={cn("absolute inset-0 bg-gradient-to-br", theme.gradient)}
|
||||
/>
|
||||
<div
|
||||
className={cn(
|
||||
"relative z-10 transition-all duration-500 ease-in-out",
|
||||
@@ -0,0 +1,121 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ChatMessage } from "../ChatMessage/ChatMessage";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
import { StreamingMessage } from "../StreamingMessage/StreamingMessage";
|
||||
import { ThinkingMessage } from "../ThinkingMessage/ThinkingMessage";
|
||||
import { useMessageList } from "./useMessageList";
|
||||
|
||||
export interface MessageListProps {
|
||||
messages: ChatMessageData[];
|
||||
streamingChunks?: string[];
|
||||
isStreaming?: boolean;
|
||||
className?: string;
|
||||
onStreamComplete?: () => void;
|
||||
onSendMessage?: (content: string) => void;
|
||||
}
|
||||
|
||||
export function MessageList({
|
||||
messages,
|
||||
streamingChunks = [],
|
||||
isStreaming = false,
|
||||
className,
|
||||
onStreamComplete,
|
||||
onSendMessage,
|
||||
}: MessageListProps) {
|
||||
const { messagesEndRef, messagesContainerRef } = useMessageList({
|
||||
messageCount: messages.length,
|
||||
isStreaming,
|
||||
});
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={messagesContainerRef}
|
||||
className={cn(
|
||||
"flex-1 overflow-y-auto",
|
||||
"scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="mx-auto flex max-w-3xl flex-col py-4">
|
||||
{/* Render all persisted messages */}
|
||||
{messages.map((message, index) => {
|
||||
// Check if current message is an agent_output tool_response
|
||||
// and if previous message is an assistant message
|
||||
let agentOutput: ChatMessageData | undefined;
|
||||
|
||||
if (message.type === "tool_response" && message.result) {
|
||||
let parsedResult: Record<string, unknown> | null = null;
|
||||
try {
|
||||
parsedResult =
|
||||
typeof message.result === "string"
|
||||
? JSON.parse(message.result)
|
||||
: (message.result as Record<string, unknown>);
|
||||
} catch {
|
||||
parsedResult = null;
|
||||
}
|
||||
if (parsedResult?.type === "agent_output") {
|
||||
const prevMessage = messages[index - 1];
|
||||
if (
|
||||
prevMessage &&
|
||||
prevMessage.type === "message" &&
|
||||
prevMessage.role === "assistant"
|
||||
) {
|
||||
// This agent output will be rendered inside the previous assistant message
|
||||
// Skip rendering this message separately
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if next message is an agent_output tool_response to include in current assistant message
|
||||
if (message.type === "message" && message.role === "assistant") {
|
||||
const nextMessage = messages[index + 1];
|
||||
if (
|
||||
nextMessage &&
|
||||
nextMessage.type === "tool_response" &&
|
||||
nextMessage.result
|
||||
) {
|
||||
let parsedResult: Record<string, unknown> | null = null;
|
||||
try {
|
||||
parsedResult =
|
||||
typeof nextMessage.result === "string"
|
||||
? JSON.parse(nextMessage.result)
|
||||
: (nextMessage.result as Record<string, unknown>);
|
||||
} catch {
|
||||
parsedResult = null;
|
||||
}
|
||||
if (parsedResult?.type === "agent_output") {
|
||||
agentOutput = nextMessage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<ChatMessage
|
||||
key={index}
|
||||
message={message}
|
||||
onSendMessage={onSendMessage}
|
||||
agentOutput={agentOutput}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Render thinking message when streaming but no chunks yet */}
|
||||
{isStreaming && streamingChunks.length === 0 && <ThinkingMessage />}
|
||||
|
||||
{/* Render streaming message if active */}
|
||||
{isStreaming && streamingChunks.length > 0 && (
|
||||
<StreamingMessage
|
||||
chunks={streamingChunks}
|
||||
onComplete={onStreamComplete}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Invisible div to scroll to */}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -81,9 +81,9 @@ export function SessionsDrawer({
|
||||
</Text>
|
||||
</div>
|
||||
) : sessions.length === 0 ? (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<div className="flex items-center justify-center py-8">
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
You don't have previously started chats
|
||||
No sessions found
|
||||
</Text>
|
||||
</div>
|
||||
) : (
|
||||
@@ -1,6 +1,7 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||
import { RobotIcon } from "@phosphor-icons/react";
|
||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||
import { MessageBubble } from "../MessageBubble/MessageBubble";
|
||||
import { useStreamingMessage } from "./useStreamingMessage";
|
||||
|
||||
export interface StreamingMessageProps {
|
||||
@@ -24,10 +25,16 @@ export function StreamingMessage({
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full max-w-3xl gap-3">
|
||||
<div className="flex-shrink-0">
|
||||
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-600">
|
||||
<RobotIcon className="h-4 w-4 text-indigo-50" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex min-w-0 flex-1 flex-col">
|
||||
<AIChatBubble>
|
||||
<MessageBubble variant="assistant">
|
||||
<MarkdownContent content={displayText} />
|
||||
</AIChatBubble>
|
||||
</MessageBubble>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -1,7 +1,7 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
import { RobotIcon } from "@phosphor-icons/react";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||
import { ChatLoader } from "../ChatLoader/ChatLoader";
|
||||
import { MessageBubble } from "../MessageBubble/MessageBubble";
|
||||
|
||||
export interface ThinkingMessageProps {
|
||||
className?: string;
|
||||
@@ -34,11 +34,22 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full max-w-3xl gap-3">
|
||||
<div className="flex-shrink-0">
|
||||
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-500">
|
||||
<RobotIcon className="h-4 w-4 text-indigo-50" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex min-w-0 flex-1 flex-col">
|
||||
<AIChatBubble>
|
||||
<MessageBubble variant="assistant">
|
||||
<div className="transition-all duration-500 ease-in-out">
|
||||
{showSlowLoader ? (
|
||||
<ChatLoader />
|
||||
<div className="flex flex-col items-center gap-3 py-2">
|
||||
<div className="loader" style={{ flexShrink: 0 }} />
|
||||
<p className="text-sm text-slate-700">
|
||||
Taking a bit longer to think, wait a moment please
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<span
|
||||
className="inline-block bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-clip-text text-transparent"
|
||||
@@ -51,7 +62,7 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</AIChatBubble>
|
||||
</MessageBubble>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -0,0 +1,24 @@
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { WrenchIcon } from "@phosphor-icons/react";
|
||||
import { getToolActionPhrase } from "../../helpers";
|
||||
|
||||
export interface ToolCallMessageProps {
|
||||
toolName: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ToolCallMessage({ toolName, className }: ToolCallMessageProps) {
|
||||
return (
|
||||
<div className={cn("flex items-center justify-center gap-2", className)}>
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}...
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import "@/components/contextual/OutputRenderers";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { ToolResult } from "@/types/chat";
|
||||
import { WrenchIcon } from "@phosphor-icons/react";
|
||||
import { getToolActionPhrase } from "../../helpers";
|
||||
|
||||
export interface ToolResponseMessageProps {
|
||||
toolName: string;
|
||||
result?: ToolResult;
|
||||
success?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ToolResponseMessage({
|
||||
toolName,
|
||||
result,
|
||||
success: _success = true,
|
||||
className,
|
||||
}: ToolResponseMessageProps) {
|
||||
if (!result) {
|
||||
return (
|
||||
<div className={cn("flex items-center justify-center gap-2", className)}>
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}...
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
let parsedResult: Record<string, unknown> | null = null;
|
||||
try {
|
||||
parsedResult =
|
||||
typeof result === "string"
|
||||
? JSON.parse(result)
|
||||
: (result as Record<string, unknown>);
|
||||
} catch {
|
||||
parsedResult = null;
|
||||
}
|
||||
|
||||
if (parsedResult && typeof parsedResult === "object") {
|
||||
const responseType = parsedResult.type as string | undefined;
|
||||
|
||||
if (responseType === "agent_output") {
|
||||
const execution = parsedResult.execution as
|
||||
| {
|
||||
outputs?: Record<string, unknown[]>;
|
||||
}
|
||||
| null
|
||||
| undefined;
|
||||
const outputs = execution?.outputs || {};
|
||||
const message = parsedResult.message as string | undefined;
|
||||
|
||||
return (
|
||||
<div className={cn("space-y-4 px-4 py-2", className)}>
|
||||
<div className="flex items-center gap-2">
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}
|
||||
</Text>
|
||||
</div>
|
||||
{message && (
|
||||
<div className="rounded border p-4">
|
||||
<Text variant="small" className="text-neutral-600">
|
||||
{message}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
{Object.keys(outputs).length > 0 && (
|
||||
<div className="space-y-4">
|
||||
{Object.entries(outputs).map(([outputName, values]) =>
|
||||
values.map((value, index) => {
|
||||
const renderer = globalRegistry.getRenderer(value);
|
||||
if (renderer) {
|
||||
return (
|
||||
<OutputItem
|
||||
key={`${outputName}-${index}`}
|
||||
value={value}
|
||||
renderer={renderer}
|
||||
label={outputName}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div
|
||||
key={`${outputName}-${index}`}
|
||||
className="rounded border p-4"
|
||||
>
|
||||
<Text variant="large-medium" className="mb-2 capitalize">
|
||||
{outputName}
|
||||
</Text>
|
||||
<pre className="overflow-auto text-sm">
|
||||
{JSON.stringify(value, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
);
|
||||
}),
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (responseType === "block_output" && parsedResult.outputs) {
|
||||
const outputs = parsedResult.outputs as Record<string, unknown[]>;
|
||||
|
||||
return (
|
||||
<div className={cn("space-y-4 px-4 py-2", className)}>
|
||||
<div className="flex items-center gap-2">
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="space-y-4">
|
||||
{Object.entries(outputs).map(([outputName, values]) =>
|
||||
values.map((value, index) => {
|
||||
const renderer = globalRegistry.getRenderer(value);
|
||||
if (renderer) {
|
||||
return (
|
||||
<OutputItem
|
||||
key={`${outputName}-${index}`}
|
||||
value={value}
|
||||
renderer={renderer}
|
||||
label={outputName}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div
|
||||
key={`${outputName}-${index}`}
|
||||
className="rounded border p-4"
|
||||
>
|
||||
<Text variant="large-medium" className="mb-2 capitalize">
|
||||
{outputName}
|
||||
</Text>
|
||||
<pre className="overflow-auto text-sm">
|
||||
{JSON.stringify(value, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
);
|
||||
}),
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Handle other response types with a message field (e.g., understanding_updated)
|
||||
if (parsedResult.message && typeof parsedResult.message === "string") {
|
||||
// Format tool name from snake_case to Title Case
|
||||
const formattedToolName = toolName
|
||||
.split("_")
|
||||
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join(" ");
|
||||
|
||||
// Clean up message - remove incomplete user_name references
|
||||
let cleanedMessage = parsedResult.message;
|
||||
// Remove "Updated understanding with: user_name" pattern if user_name is just a placeholder
|
||||
cleanedMessage = cleanedMessage.replace(
|
||||
/Updated understanding with:\s*user_name\.?\s*/gi,
|
||||
"",
|
||||
);
|
||||
// Remove standalone user_name references
|
||||
cleanedMessage = cleanedMessage.replace(/\buser_name\b\.?\s*/gi, "");
|
||||
cleanedMessage = cleanedMessage.trim();
|
||||
|
||||
// Only show message if it has content after cleaning
|
||||
if (!cleanedMessage) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center gap-2 px-4 py-2",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{formattedToolName}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("space-y-2 px-4 py-2", className)}>
|
||||
<div className="flex items-center justify-center gap-2">
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{formattedToolName}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="rounded border p-4">
|
||||
<Text variant="small" className="text-neutral-600">
|
||||
{cleanedMessage}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const renderer = globalRegistry.getRenderer(result);
|
||||
if (renderer) {
|
||||
return (
|
||||
<div className={cn("px-4 py-2", className)}>
|
||||
<div className="mb-2 flex items-center gap-2">
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}
|
||||
</Text>
|
||||
</div>
|
||||
<OutputItem value={result} renderer={renderer} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("flex items-center justify-center gap-2", className)}>
|
||||
<WrenchIcon
|
||||
size={14}
|
||||
weight="bold"
|
||||
className="flex-shrink-0 text-neutral-500"
|
||||
/>
|
||||
<Text variant="small" className="text-neutral-500">
|
||||
{getToolActionPhrase(toolName)}...
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
/**
|
||||
* Maps internal tool names to user-friendly display names with emojis.
|
||||
* @deprecated Use getToolActionPhrase or getToolCompletionPhrase for status messages
|
||||
*
|
||||
* @param toolName - The internal tool name from the backend
|
||||
* @returns A user-friendly display name with an emoji prefix
|
||||
*/
|
||||
export function getToolDisplayName(toolName: string): string {
|
||||
const toolDisplayNames: Record<string, string> = {
|
||||
find_agent: "🔍 Search Marketplace",
|
||||
get_agent_details: "📋 Get Agent Details",
|
||||
check_credentials: "🔑 Check Credentials",
|
||||
setup_agent: "⚙️ Setup Agent",
|
||||
run_agent: "▶️ Run Agent",
|
||||
get_required_setup_info: "📝 Get Setup Requirements",
|
||||
};
|
||||
return toolDisplayNames[toolName] || toolName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps internal tool names to human-friendly action phrases (present continuous).
|
||||
* Used for tool call messages to indicate what action is currently happening.
|
||||
*
|
||||
* @param toolName - The internal tool name from the backend
|
||||
* @returns A human-friendly action phrase in present continuous tense
|
||||
*/
|
||||
export function getToolActionPhrase(toolName: string): string {
|
||||
const toolActionPhrases: Record<string, string> = {
|
||||
find_agent: "Looking for agents in the marketplace",
|
||||
agent_carousel: "Looking for agents in the marketplace",
|
||||
get_agent_details: "Learning about the agent",
|
||||
check_credentials: "Checking your credentials",
|
||||
setup_agent: "Setting up the agent",
|
||||
execution_started: "Running the agent",
|
||||
run_agent: "Running the agent",
|
||||
get_required_setup_info: "Getting setup requirements",
|
||||
schedule_agent: "Scheduling the agent to run",
|
||||
};
|
||||
|
||||
// Return mapped phrase or generate human-friendly fallback
|
||||
return toolActionPhrases[toolName] || toolName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps internal tool names to human-friendly completion phrases (past tense).
|
||||
* Used for tool response messages to indicate what action was completed.
|
||||
*
|
||||
* @param toolName - The internal tool name from the backend
|
||||
* @returns A human-friendly completion phrase in past tense
|
||||
*/
|
||||
export function getToolCompletionPhrase(toolName: string): string {
|
||||
const toolCompletionPhrases: Record<string, string> = {
|
||||
find_agent: "Finished searching the marketplace",
|
||||
get_agent_details: "Got agent details",
|
||||
check_credentials: "Checked credentials",
|
||||
setup_agent: "Agent setup complete",
|
||||
run_agent: "Agent execution started",
|
||||
get_required_setup_info: "Got setup requirements",
|
||||
};
|
||||
|
||||
// Return mapped phrase or generate human-friendly fallback
|
||||
return (
|
||||
toolCompletionPhrases[toolName] ||
|
||||
`Finished ${toolName.replace(/_/g, " ").replace("...", "")}`
|
||||
);
|
||||
}
|
||||
@@ -1,20 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
import { useChatStream } from "./useChatStream";
|
||||
|
||||
interface UseChatArgs {
|
||||
urlSessionId?: string | null;
|
||||
}
|
||||
|
||||
export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
export function useChat() {
|
||||
const hasCreatedSessionRef = useRef(false);
|
||||
const hasClaimedSessionRef = useRef(false);
|
||||
const { user } = useSupabase();
|
||||
const { sendMessage: sendStreamMessage } = useChatStream();
|
||||
const [showLoader, setShowLoader] = useState(false);
|
||||
|
||||
const {
|
||||
session,
|
||||
sessionId: sessionIdFromHook,
|
||||
@@ -22,16 +19,27 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
isLoading,
|
||||
isCreating,
|
||||
error,
|
||||
isSessionNotFound,
|
||||
createSession,
|
||||
claimSession,
|
||||
clearSession: clearSessionBase,
|
||||
loadSession,
|
||||
} = useChatSession({
|
||||
urlSessionId,
|
||||
urlSessionId: null,
|
||||
autoCreate: false,
|
||||
});
|
||||
|
||||
useEffect(
|
||||
function autoCreateSession() {
|
||||
if (!hasCreatedSessionRef.current && !isCreating && !sessionIdFromHook) {
|
||||
hasCreatedSessionRef.current = true;
|
||||
createSession().catch((_err) => {
|
||||
hasCreatedSessionRef.current = false;
|
||||
});
|
||||
}
|
||||
},
|
||||
[isCreating, sessionIdFromHook, createSession],
|
||||
);
|
||||
|
||||
useEffect(
|
||||
function autoClaimSession() {
|
||||
if (
|
||||
@@ -67,17 +75,6 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoading || isCreating) {
|
||||
const timer = setTimeout(() => {
|
||||
setShowLoader(true);
|
||||
}, 300);
|
||||
return () => clearTimeout(timer);
|
||||
} else {
|
||||
setShowLoader(false);
|
||||
}
|
||||
}, [isLoading, isCreating]);
|
||||
|
||||
useEffect(function monitorNetworkStatus() {
|
||||
function handleOnline() {
|
||||
toast.success("Connection restored", {
|
||||
@@ -102,6 +99,7 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
|
||||
function clearSession() {
|
||||
clearSessionBase();
|
||||
hasCreatedSessionRef.current = false;
|
||||
hasClaimedSessionRef.current = false;
|
||||
}
|
||||
|
||||
@@ -111,11 +109,9 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
||||
isLoading,
|
||||
isCreating,
|
||||
error,
|
||||
isSessionNotFound,
|
||||
createSession,
|
||||
clearSession,
|
||||
loadSession,
|
||||
sessionId: sessionIdFromHook,
|
||||
showLoader,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
import {
|
||||
getGetV2GetSessionQueryKey,
|
||||
getGetV2GetSessionQueryOptions,
|
||||
postV2CreateSession,
|
||||
useGetV2GetSession,
|
||||
usePatchV2SessionAssignUser,
|
||||
usePostV2CreateSession,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { isValidUUID } from "@/lib/utils";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface UseChatSessionArgs {
|
||||
urlSessionId?: string | null;
|
||||
autoCreate?: boolean;
|
||||
}
|
||||
|
||||
export function useChatSession({
|
||||
urlSessionId,
|
||||
autoCreate = false,
|
||||
}: UseChatSessionArgs = {}) {
|
||||
const queryClient = useQueryClient();
|
||||
const [sessionId, setSessionId] = useState<string | null>(null);
|
||||
const [error, setError] = useState<Error | null>(null);
|
||||
const justCreatedSessionIdRef = useRef<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (urlSessionId) {
|
||||
if (!isValidUUID(urlSessionId)) {
|
||||
console.error("Invalid session ID format:", urlSessionId);
|
||||
toast.error("Invalid session ID", {
|
||||
description:
|
||||
"The session ID in the URL is not valid. Starting a new session...",
|
||||
});
|
||||
setSessionId(null);
|
||||
storage.clean(Key.CHAT_SESSION_ID);
|
||||
return;
|
||||
}
|
||||
setSessionId(urlSessionId);
|
||||
storage.set(Key.CHAT_SESSION_ID, urlSessionId);
|
||||
} else {
|
||||
const storedSessionId = storage.get(Key.CHAT_SESSION_ID);
|
||||
if (storedSessionId) {
|
||||
if (!isValidUUID(storedSessionId)) {
|
||||
console.error("Invalid stored session ID:", storedSessionId);
|
||||
storage.clean(Key.CHAT_SESSION_ID);
|
||||
setSessionId(null);
|
||||
} else {
|
||||
setSessionId(storedSessionId);
|
||||
}
|
||||
} else if (autoCreate) {
|
||||
setSessionId(null);
|
||||
}
|
||||
}
|
||||
}, [urlSessionId, autoCreate]);
|
||||
|
||||
const {
|
||||
mutateAsync: createSessionMutation,
|
||||
isPending: isCreating,
|
||||
error: createError,
|
||||
} = usePostV2CreateSession();
|
||||
|
||||
const {
|
||||
data: sessionData,
|
||||
isLoading: isLoadingSession,
|
||||
error: loadError,
|
||||
refetch,
|
||||
} = useGetV2GetSession(sessionId || "", {
|
||||
query: {
|
||||
enabled: !!sessionId,
|
||||
select: okData,
|
||||
staleTime: Infinity, // Never mark as stale
|
||||
refetchOnMount: false, // Don't refetch on component mount
|
||||
refetchOnWindowFocus: false, // Don't refetch when window regains focus
|
||||
refetchOnReconnect: false, // Don't refetch when network reconnects
|
||||
retry: 1,
|
||||
},
|
||||
});
|
||||
|
||||
const { mutateAsync: claimSessionMutation } = usePatchV2SessionAssignUser();
|
||||
|
||||
const session = useMemo(() => {
|
||||
if (sessionData) return sessionData;
|
||||
|
||||
if (sessionId && justCreatedSessionIdRef.current === sessionId) {
|
||||
return {
|
||||
id: sessionId,
|
||||
user_id: null,
|
||||
messages: [],
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
} as SessionDetailResponse;
|
||||
}
|
||||
return null;
|
||||
}, [sessionData, sessionId]);
|
||||
|
||||
const messages = session?.messages || [];
|
||||
const isLoading = isCreating || isLoadingSession;
|
||||
|
||||
useEffect(() => {
|
||||
if (createError) {
|
||||
setError(
|
||||
createError instanceof Error
|
||||
? createError
|
||||
: new Error("Failed to create session"),
|
||||
);
|
||||
} else if (loadError) {
|
||||
setError(
|
||||
loadError instanceof Error
|
||||
? loadError
|
||||
: new Error("Failed to load session"),
|
||||
);
|
||||
} else {
|
||||
setError(null);
|
||||
}
|
||||
}, [createError, loadError]);
|
||||
|
||||
const createSession = useCallback(
|
||||
async function createSession() {
|
||||
try {
|
||||
setError(null);
|
||||
const response = await postV2CreateSession({
|
||||
body: JSON.stringify({}),
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error("Failed to create session");
|
||||
}
|
||||
const newSessionId = response.data.id;
|
||||
setSessionId(newSessionId);
|
||||
storage.set(Key.CHAT_SESSION_ID, newSessionId);
|
||||
justCreatedSessionIdRef.current = newSessionId;
|
||||
setTimeout(() => {
|
||||
if (justCreatedSessionIdRef.current === newSessionId) {
|
||||
justCreatedSessionIdRef.current = null;
|
||||
}
|
||||
}, 10000);
|
||||
return newSessionId;
|
||||
} catch (err) {
|
||||
const error =
|
||||
err instanceof Error ? err : new Error("Failed to create session");
|
||||
setError(error);
|
||||
toast.error("Failed to create chat session", {
|
||||
description: error.message,
|
||||
});
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[createSessionMutation],
|
||||
);
|
||||
|
||||
const loadSession = useCallback(
|
||||
async function loadSession(id: string) {
|
||||
try {
|
||||
setError(null);
|
||||
// Invalidate the query cache for this session to force a fresh fetch
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(id),
|
||||
});
|
||||
// Set sessionId after invalidation to ensure the hook refetches
|
||||
setSessionId(id);
|
||||
storage.set(Key.CHAT_SESSION_ID, id);
|
||||
// Force fetch with fresh data (bypass cache)
|
||||
const queryOptions = getGetV2GetSessionQueryOptions(id, {
|
||||
query: {
|
||||
staleTime: 0, // Force fresh fetch
|
||||
retry: 1,
|
||||
},
|
||||
});
|
||||
const result = await queryClient.fetchQuery(queryOptions);
|
||||
if (!result || ("status" in result && result.status !== 200)) {
|
||||
console.warn("Session not found on server, clearing local state");
|
||||
storage.clean(Key.CHAT_SESSION_ID);
|
||||
setSessionId(null);
|
||||
throw new Error("Session not found");
|
||||
}
|
||||
} catch (err) {
|
||||
const error =
|
||||
err instanceof Error ? err : new Error("Failed to load session");
|
||||
setError(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[queryClient],
|
||||
);
|
||||
|
||||
const refreshSession = useCallback(
|
||||
async function refreshSession() {
|
||||
if (!sessionId) {
|
||||
console.log("[refreshSession] Skipping - no session ID");
|
||||
return;
|
||||
}
|
||||
try {
|
||||
setError(null);
|
||||
await refetch();
|
||||
} catch (err) {
|
||||
const error =
|
||||
err instanceof Error ? err : new Error("Failed to refresh session");
|
||||
setError(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[sessionId, refetch],
|
||||
);
|
||||
|
||||
const claimSession = useCallback(
|
||||
async function claimSession(id: string) {
|
||||
try {
|
||||
setError(null);
|
||||
await claimSessionMutation({ sessionId: id });
|
||||
if (justCreatedSessionIdRef.current === id) {
|
||||
justCreatedSessionIdRef.current = null;
|
||||
}
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(id),
|
||||
});
|
||||
await refetch();
|
||||
toast.success("Session claimed successfully", {
|
||||
description: "Your chat history has been saved to your account",
|
||||
});
|
||||
} catch (err: unknown) {
|
||||
const error =
|
||||
err instanceof Error ? err : new Error("Failed to claim session");
|
||||
const is404 =
|
||||
(typeof err === "object" &&
|
||||
err !== null &&
|
||||
"status" in err &&
|
||||
err.status === 404) ||
|
||||
(typeof err === "object" &&
|
||||
err !== null &&
|
||||
"response" in err &&
|
||||
typeof err.response === "object" &&
|
||||
err.response !== null &&
|
||||
"status" in err.response &&
|
||||
err.response.status === 404);
|
||||
if (!is404) {
|
||||
setError(error);
|
||||
toast.error("Failed to claim session", {
|
||||
description: error.message || "Unable to claim session",
|
||||
});
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[claimSessionMutation, queryClient, refetch],
|
||||
);
|
||||
|
||||
const clearSession = useCallback(function clearSession() {
|
||||
setSessionId(null);
|
||||
setError(null);
|
||||
storage.clean(Key.CHAT_SESSION_ID);
|
||||
justCreatedSessionIdRef.current = null;
|
||||
}, []);
|
||||
|
||||
return {
|
||||
session,
|
||||
sessionId,
|
||||
messages,
|
||||
isLoading,
|
||||
isCreating,
|
||||
error,
|
||||
createSession,
|
||||
loadSession,
|
||||
refreshSession,
|
||||
claimSession,
|
||||
clearSession,
|
||||
};
|
||||
}
|
||||
@@ -21,8 +21,6 @@ export interface StreamChunk {
|
||||
timestamp?: string;
|
||||
content?: string;
|
||||
message?: string;
|
||||
code?: string;
|
||||
details?: Record<string, unknown>;
|
||||
tool_id?: string;
|
||||
tool_name?: string;
|
||||
arguments?: ToolArguments;
|
||||
@@ -140,18 +138,8 @@ function normalizeStreamChunk(
|
||||
return { type: "stream_end" };
|
||||
case "start":
|
||||
case "text-start":
|
||||
return null;
|
||||
case "tool-input-start":
|
||||
const toolInputStart = chunk as Extract<
|
||||
VercelStreamChunk,
|
||||
{ type: "tool-input-start" }
|
||||
>;
|
||||
return {
|
||||
type: "tool_call_start",
|
||||
tool_id: toolInputStart.toolCallId,
|
||||
tool_name: toolInputStart.toolName,
|
||||
arguments: {},
|
||||
};
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,103 +149,22 @@ export function useChatStream() {
|
||||
const retryCountRef = useRef<number>(0);
|
||||
const retryTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const currentSessionIdRef = useRef<string | null>(null);
|
||||
const requestStartTimeRef = useRef<number | null>(null);
|
||||
|
||||
const stopStreaming = useCallback(
|
||||
(sessionId?: string, force: boolean = false) => {
|
||||
console.log("[useChatStream] stopStreaming called", {
|
||||
hasAbortController: !!abortControllerRef.current,
|
||||
isAborted: abortControllerRef.current?.signal.aborted,
|
||||
currentSessionId: currentSessionIdRef.current,
|
||||
requestedSessionId: sessionId,
|
||||
requestStartTime: requestStartTimeRef.current,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null,
|
||||
force,
|
||||
stack: new Error().stack,
|
||||
});
|
||||
|
||||
if (
|
||||
sessionId &&
|
||||
currentSessionIdRef.current &&
|
||||
currentSessionIdRef.current !== sessionId
|
||||
) {
|
||||
console.log(
|
||||
"[useChatStream] Session changed, aborting previous stream",
|
||||
{
|
||||
oldSessionId: currentSessionIdRef.current,
|
||||
newSessionId: sessionId,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
const controller = abortControllerRef.current;
|
||||
if (controller) {
|
||||
const timeSinceStart = requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null;
|
||||
|
||||
if (!force && timeSinceStart !== null && timeSinceStart < 100) {
|
||||
console.log(
|
||||
"[useChatStream] Request just started (<100ms), skipping abort to prevent race condition",
|
||||
{
|
||||
timeSinceStart,
|
||||
},
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const signal = controller.signal;
|
||||
|
||||
if (
|
||||
signal &&
|
||||
typeof signal.aborted === "boolean" &&
|
||||
!signal.aborted
|
||||
) {
|
||||
console.log("[useChatStream] Aborting stream");
|
||||
controller.abort();
|
||||
} else {
|
||||
console.log(
|
||||
"[useChatStream] Stream already aborted or signal invalid",
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === "AbortError") {
|
||||
console.log(
|
||||
"[useChatStream] AbortError caught (expected during cleanup)",
|
||||
);
|
||||
} else {
|
||||
console.warn("[useChatStream] Error aborting stream:", error);
|
||||
}
|
||||
} finally {
|
||||
abortControllerRef.current = null;
|
||||
requestStartTimeRef.current = null;
|
||||
}
|
||||
}
|
||||
if (retryTimeoutRef.current) {
|
||||
clearTimeout(retryTimeoutRef.current);
|
||||
retryTimeoutRef.current = null;
|
||||
}
|
||||
setIsStreaming(false);
|
||||
},
|
||||
[],
|
||||
);
|
||||
const stopStreaming = useCallback(() => {
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
if (retryTimeoutRef.current) {
|
||||
clearTimeout(retryTimeoutRef.current);
|
||||
retryTimeoutRef.current = null;
|
||||
}
|
||||
setIsStreaming(false);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
console.log("[useChatStream] Component mounted");
|
||||
return () => {
|
||||
const sessionIdAtUnmount = currentSessionIdRef.current;
|
||||
console.log(
|
||||
"[useChatStream] Component unmounting, calling stopStreaming",
|
||||
{
|
||||
sessionIdAtUnmount,
|
||||
},
|
||||
);
|
||||
stopStreaming(undefined, false);
|
||||
currentSessionIdRef.current = null;
|
||||
stopStreaming();
|
||||
};
|
||||
}, [stopStreaming]);
|
||||
|
||||
@@ -270,32 +177,12 @@ export function useChatStream() {
|
||||
context?: { url: string; content: string },
|
||||
isRetry: boolean = false,
|
||||
) => {
|
||||
console.log("[useChatStream] sendMessage called", {
|
||||
sessionId,
|
||||
message: message.substring(0, 50),
|
||||
isUserMessage,
|
||||
isRetry,
|
||||
stack: new Error().stack,
|
||||
});
|
||||
|
||||
const previousSessionId = currentSessionIdRef.current;
|
||||
stopStreaming(sessionId, true);
|
||||
currentSessionIdRef.current = sessionId;
|
||||
stopStreaming();
|
||||
|
||||
const abortController = new AbortController();
|
||||
abortControllerRef.current = abortController;
|
||||
requestStartTimeRef.current = Date.now();
|
||||
console.log("[useChatStream] Created new AbortController", {
|
||||
sessionId,
|
||||
previousSessionId,
|
||||
requestStartTime: requestStartTimeRef.current,
|
||||
});
|
||||
|
||||
if (abortController.signal.aborted) {
|
||||
console.warn(
|
||||
"[useChatStream] AbortController was aborted before request started",
|
||||
);
|
||||
requestStartTimeRef.current = null;
|
||||
return Promise.reject(new Error("Request aborted"));
|
||||
}
|
||||
|
||||
@@ -323,34 +210,18 @@ export function useChatStream() {
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
console.info("[useChatStream] Stream response", {
|
||||
sessionId,
|
||||
status: response.status,
|
||||
ok: response.ok,
|
||||
contentType: response.headers.get("content-type"),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
console.warn("[useChatStream] Stream response error", {
|
||||
sessionId,
|
||||
status: response.status,
|
||||
errorText,
|
||||
});
|
||||
throw new Error(errorText || `HTTP ${response.status}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
console.warn("[useChatStream] Response body is null", { sessionId });
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
let receivedChunkCount = 0;
|
||||
let firstChunkAt: number | null = null;
|
||||
let loggedLineCount = 0;
|
||||
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
let didDispatchStreamEnd = false;
|
||||
@@ -374,13 +245,6 @@ export function useChatStream() {
|
||||
|
||||
if (done) {
|
||||
cleanup();
|
||||
console.info("[useChatStream] Stream closed", {
|
||||
sessionId,
|
||||
receivedChunkCount,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null,
|
||||
});
|
||||
dispatchStreamEnd();
|
||||
retryCountRef.current = 0;
|
||||
stopStreaming();
|
||||
@@ -395,23 +259,8 @@ export function useChatStream() {
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const data = line.slice(6);
|
||||
if (loggedLineCount < 3) {
|
||||
console.info("[useChatStream] Raw stream line", {
|
||||
sessionId,
|
||||
data:
|
||||
data.length > 300 ? `${data.slice(0, 300)}...` : data,
|
||||
});
|
||||
loggedLineCount += 1;
|
||||
}
|
||||
if (data === "[DONE]") {
|
||||
cleanup();
|
||||
console.info("[useChatStream] Stream done marker", {
|
||||
sessionId,
|
||||
receivedChunkCount,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null,
|
||||
});
|
||||
dispatchStreamEnd();
|
||||
retryCountRef.current = 0;
|
||||
stopStreaming();
|
||||
@@ -428,18 +277,6 @@ export function useChatStream() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!firstChunkAt) {
|
||||
firstChunkAt = Date.now();
|
||||
console.info("[useChatStream] First stream chunk", {
|
||||
sessionId,
|
||||
chunkType: chunk.type,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? firstChunkAt - requestStartTimeRef.current
|
||||
: null,
|
||||
});
|
||||
}
|
||||
receivedChunkCount += 1;
|
||||
|
||||
// Call the chunk handler
|
||||
onChunk(chunk);
|
||||
|
||||
@@ -447,13 +284,6 @@ export function useChatStream() {
|
||||
if (chunk.type === "stream_end") {
|
||||
didDispatchStreamEnd = true;
|
||||
cleanup();
|
||||
console.info("[useChatStream] Stream end chunk", {
|
||||
sessionId,
|
||||
receivedChunkCount,
|
||||
timeSinceStart: requestStartTimeRef.current
|
||||
? Date.now() - requestStartTimeRef.current
|
||||
: null,
|
||||
});
|
||||
retryCountRef.current = 0;
|
||||
stopStreaming();
|
||||
resolve();
|
||||
@@ -477,9 +307,6 @@ export function useChatStream() {
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
cleanup();
|
||||
dispatchStreamEnd();
|
||||
stopStreaming();
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -525,10 +352,6 @@ export function useChatStream() {
|
||||
readStream();
|
||||
});
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
setIsStreaming(false);
|
||||
return Promise.resolve();
|
||||
}
|
||||
const streamError =
|
||||
err instanceof Error ? err : new Error("Failed to start stream");
|
||||
setError(streamError);
|
||||
27
autogpt_platform/frontend/src/app/(platform)/chat/page.tsx
Normal file
27
autogpt_platform/frontend/src/app/(platform)/chat/page.tsx
Normal file
@@ -0,0 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { Chat } from "./components/Chat/Chat";
|
||||
|
||||
export default function ChatPage() {
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const router = useRouter();
|
||||
|
||||
useEffect(() => {
|
||||
if (isChatEnabled === false) {
|
||||
router.push("/marketplace");
|
||||
}
|
||||
}, [isChatEnabled, router]);
|
||||
|
||||
if (isChatEnabled === null || isChatEnabled === false) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
<Chat className="flex-1" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
||||
import type { ReactNode } from "react";
|
||||
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
|
||||
import { LoadingState } from "./components/LoadingState/LoadingState";
|
||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
||||
import { useCopilotShell } from "./useCopilotShell";
|
||||
|
||||
interface Props {
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export function CopilotShell({ children }: Props) {
|
||||
const {
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
isLoading,
|
||||
isLoggedIn,
|
||||
hasActiveSession,
|
||||
sessions,
|
||||
currentSessionId,
|
||||
handleSelectSession,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
handleNewChat,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
fetchNextPage,
|
||||
isReadyToShowContent,
|
||||
} = useCopilotShell();
|
||||
|
||||
if (!isLoggedIn) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<LoadingSpinner size="large" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex overflow-hidden bg-[#EFEFF0]"
|
||||
style={{ height: `calc(100vh - ${NAVBAR_HEIGHT_PX}px)` }}
|
||||
>
|
||||
{!isMobile && (
|
||||
<DesktopSidebar
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={handleSelectSession}
|
||||
onFetchNextPage={fetchNextPage}
|
||||
onNewChat={handleNewChat}
|
||||
hasActiveSession={Boolean(hasActiveSession)}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="relative flex min-h-0 flex-1 flex-col">
|
||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||
<div className="flex min-h-0 flex-1 flex-col">
|
||||
{isReadyToShowContent ? children : <LoadingState />}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isMobile && (
|
||||
<MobileDrawer
|
||||
isOpen={isDrawerOpen}
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={handleSelectSession}
|
||||
onFetchNextPage={fetchNextPage}
|
||||
onNewChat={handleNewChat}
|
||||
onClose={handleCloseDrawer}
|
||||
onOpenChange={handleDrawerOpenChange}
|
||||
hasActiveSession={Boolean(hasActiveSession)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Plus } from "@phosphor-icons/react";
|
||||
import { SessionsList } from "../SessionsList/SessionsList";
|
||||
|
||||
interface Props {
|
||||
sessions: SessionSummaryResponse[];
|
||||
currentSessionId: string | null;
|
||||
isLoading: boolean;
|
||||
hasNextPage: boolean;
|
||||
isFetchingNextPage: boolean;
|
||||
onSelectSession: (sessionId: string) => void;
|
||||
onFetchNextPage: () => void;
|
||||
onNewChat: () => void;
|
||||
hasActiveSession: boolean;
|
||||
}
|
||||
|
||||
export function DesktopSidebar({
|
||||
sessions,
|
||||
currentSessionId,
|
||||
isLoading,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
onSelectSession,
|
||||
onFetchNextPage,
|
||||
onNewChat,
|
||||
hasActiveSession,
|
||||
}: Props) {
|
||||
return (
|
||||
<aside className="flex h-full w-80 flex-col border-r border-zinc-100 bg-zinc-50">
|
||||
<div className="shrink-0 px-6 py-4">
|
||||
<Text variant="h3" size="body-medium">
|
||||
Your chats
|
||||
</Text>
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
|
||||
scrollbarStyles,
|
||||
)}
|
||||
>
|
||||
<SessionsList
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={onSelectSession}
|
||||
onFetchNextPage={onFetchNextPage}
|
||||
/>
|
||||
</div>
|
||||
{hasActiveSession && (
|
||||
<div className="shrink-0 bg-zinc-50 p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<Plus width="1rem" height="1rem" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</aside>
|
||||
);
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
||||
|
||||
export function LoadingState() {
|
||||
return (
|
||||
<div className="flex flex-1 items-center justify-center">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
<ChatLoader />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
Loading your chats...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { PlusIcon, X } from "@phosphor-icons/react";
|
||||
import { Drawer } from "vaul";
|
||||
import { SessionsList } from "../SessionsList/SessionsList";
|
||||
|
||||
interface Props {
|
||||
isOpen: boolean;
|
||||
sessions: SessionSummaryResponse[];
|
||||
currentSessionId: string | null;
|
||||
isLoading: boolean;
|
||||
hasNextPage: boolean;
|
||||
isFetchingNextPage: boolean;
|
||||
onSelectSession: (sessionId: string) => void;
|
||||
onFetchNextPage: () => void;
|
||||
onNewChat: () => void;
|
||||
onClose: () => void;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
hasActiveSession: boolean;
|
||||
}
|
||||
|
||||
export function MobileDrawer({
|
||||
isOpen,
|
||||
sessions,
|
||||
currentSessionId,
|
||||
isLoading,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
onSelectSession,
|
||||
onFetchNextPage,
|
||||
onNewChat,
|
||||
onClose,
|
||||
onOpenChange,
|
||||
hasActiveSession,
|
||||
}: Props) {
|
||||
return (
|
||||
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
|
||||
<Drawer.Portal>
|
||||
<Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
|
||||
<Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
|
||||
<div className="shrink-0 border-b border-zinc-200 p-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<Drawer.Title className="text-lg font-semibold text-zinc-800">
|
||||
Your chats
|
||||
</Drawer.Title>
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Close sessions"
|
||||
onClick={onClose}
|
||||
>
|
||||
<X width="1.25rem" height="1.25rem" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className={cn(
|
||||
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
|
||||
scrollbarStyles,
|
||||
)}
|
||||
>
|
||||
<SessionsList
|
||||
sessions={sessions}
|
||||
currentSessionId={currentSessionId}
|
||||
isLoading={isLoading}
|
||||
hasNextPage={hasNextPage}
|
||||
isFetchingNextPage={isFetchingNextPage}
|
||||
onSelectSession={onSelectSession}
|
||||
onFetchNextPage={onFetchNextPage}
|
||||
/>
|
||||
</div>
|
||||
{hasActiveSession && (
|
||||
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||
<Button
|
||||
variant="primary"
|
||||
size="small"
|
||||
onClick={onNewChat}
|
||||
className="w-full"
|
||||
leftIcon={<PlusIcon width="1rem" height="1rem" />}
|
||||
>
|
||||
New Chat
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</Drawer.Content>
|
||||
</Drawer.Portal>
|
||||
</Drawer.Root>
|
||||
);
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
import { useState } from "react";
|
||||
|
||||
export function useMobileDrawer() {
|
||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
||||
|
||||
function handleOpenDrawer() {
|
||||
setIsDrawerOpen(true);
|
||||
}
|
||||
|
||||
function handleCloseDrawer() {
|
||||
setIsDrawerOpen(false);
|
||||
}
|
||||
|
||||
function handleDrawerOpenChange(open: boolean) {
|
||||
setIsDrawerOpen(open);
|
||||
}
|
||||
|
||||
return {
|
||||
isDrawerOpen,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
};
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
||||
import { ListIcon } from "@phosphor-icons/react";
|
||||
|
||||
interface Props {
|
||||
onOpenDrawer: () => void;
|
||||
}
|
||||
|
||||
export function MobileHeader({ onOpenDrawer }: Props) {
|
||||
return (
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Open sessions"
|
||||
onClick={onOpenDrawer}
|
||||
className="fixed z-50 bg-white shadow-md"
|
||||
style={{ left: "1rem", top: `${NAVBAR_HEIGHT_PX + 20}px` }}
|
||||
>
|
||||
<ListIcon width="1.25rem" height="1.25rem" />
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { InfiniteList } from "@/components/molecules/InfiniteList/InfiniteList";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { getSessionTitle } from "../../helpers";
|
||||
|
||||
interface Props {
|
||||
sessions: SessionSummaryResponse[];
|
||||
currentSessionId: string | null;
|
||||
isLoading: boolean;
|
||||
hasNextPage: boolean;
|
||||
isFetchingNextPage: boolean;
|
||||
onSelectSession: (sessionId: string) => void;
|
||||
onFetchNextPage: () => void;
|
||||
}
|
||||
|
||||
export function SessionsList({
|
||||
sessions,
|
||||
currentSessionId,
|
||||
isLoading,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
onSelectSession,
|
||||
onFetchNextPage,
|
||||
}: Props) {
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
{Array.from({ length: 5 }).map((_, i) => (
|
||||
<div key={i} className="rounded-lg px-3 py-2.5">
|
||||
<Skeleton className="h-5 w-full" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (sessions.length === 0) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
You don't have previous chats
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<InfiniteList
|
||||
items={sessions}
|
||||
hasMore={hasNextPage}
|
||||
isFetchingMore={isFetchingNextPage}
|
||||
onEndReached={onFetchNextPage}
|
||||
className="space-y-1"
|
||||
renderItem={(session) => {
|
||||
const isActive = session.id === currentSessionId;
|
||||
return (
|
||||
<button
|
||||
onClick={() => onSelectSession(session.id)}
|
||||
className={cn(
|
||||
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
||||
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
|
||||
)}
|
||||
>
|
||||
<Text
|
||||
variant="body"
|
||||
className={cn(
|
||||
"font-normal",
|
||||
isActive ? "text-zinc-600" : "text-zinc-800",
|
||||
)}
|
||||
>
|
||||
{getSessionTitle(session)}
|
||||
</Text>
|
||||
</button>
|
||||
);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
const PAGE_SIZE = 50;
|
||||
|
||||
export interface UseSessionsPaginationArgs {
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
||||
const [offset, setOffset] = useState(0);
|
||||
const [accumulatedSessions, setAccumulatedSessions] = useState<
|
||||
SessionSummaryResponse[]
|
||||
>([]);
|
||||
const [totalCount, setTotalCount] = useState<number | null>(null);
|
||||
|
||||
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
|
||||
{ limit: PAGE_SIZE, offset },
|
||||
{
|
||||
query: {
|
||||
enabled: enabled && offset >= 0,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const responseData = okData(data);
|
||||
if (responseData) {
|
||||
const newSessions = responseData.sessions;
|
||||
const total = responseData.total;
|
||||
setTotalCount(total);
|
||||
|
||||
if (offset === 0) {
|
||||
setAccumulatedSessions(newSessions);
|
||||
} else {
|
||||
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
|
||||
}
|
||||
} else if (!enabled) {
|
||||
setAccumulatedSessions([]);
|
||||
setTotalCount(null);
|
||||
}
|
||||
}, [data, offset, enabled]);
|
||||
|
||||
const hasNextPage = useMemo(() => {
|
||||
if (totalCount === null) return false;
|
||||
return accumulatedSessions.length < totalCount;
|
||||
}, [accumulatedSessions.length, totalCount]);
|
||||
|
||||
const areAllSessionsLoaded = useMemo(() => {
|
||||
if (totalCount === null) return false;
|
||||
return (
|
||||
accumulatedSessions.length >= totalCount && !isFetching && !isLoading
|
||||
);
|
||||
}, [accumulatedSessions.length, totalCount, isFetching, isLoading]);
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
hasNextPage &&
|
||||
!isFetching &&
|
||||
!isLoading &&
|
||||
!isError &&
|
||||
totalCount !== null
|
||||
) {
|
||||
setOffset((prev) => prev + PAGE_SIZE);
|
||||
}
|
||||
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
|
||||
|
||||
function fetchNextPage() {
|
||||
if (hasNextPage && !isFetching) {
|
||||
setOffset((prev) => prev + PAGE_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
function reset() {
|
||||
setOffset(0);
|
||||
setAccumulatedSessions([]);
|
||||
setTotalCount(null);
|
||||
}
|
||||
|
||||
return {
|
||||
sessions: accumulatedSessions,
|
||||
isLoading,
|
||||
isFetching,
|
||||
hasNextPage,
|
||||
areAllSessionsLoaded,
|
||||
totalCount,
|
||||
fetchNextPage,
|
||||
reset,
|
||||
};
|
||||
}
|
||||
@@ -1,165 +0,0 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||
import { format, formatDistanceToNow, isToday } from "date-fns";
|
||||
|
||||
export function convertSessionDetailToSummary(
|
||||
session: SessionDetailResponse,
|
||||
): SessionSummaryResponse {
|
||||
return {
|
||||
id: session.id,
|
||||
created_at: session.created_at,
|
||||
updated_at: session.updated_at,
|
||||
title: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
export function filterVisibleSessions(
|
||||
sessions: SessionSummaryResponse[],
|
||||
): SessionSummaryResponse[] {
|
||||
return sessions.filter(
|
||||
(session) => session.updated_at !== session.created_at,
|
||||
);
|
||||
}
|
||||
|
||||
export function getSessionTitle(session: SessionSummaryResponse): string {
|
||||
if (session.title) return session.title;
|
||||
const isNewSession = session.updated_at === session.created_at;
|
||||
if (isNewSession) {
|
||||
const createdDate = new Date(session.created_at);
|
||||
if (isToday(createdDate)) {
|
||||
return "Today";
|
||||
}
|
||||
return format(createdDate, "MMM d, yyyy");
|
||||
}
|
||||
return "Untitled Chat";
|
||||
}
|
||||
|
||||
export function getSessionUpdatedLabel(
|
||||
session: SessionSummaryResponse,
|
||||
): string {
|
||||
if (!session.updated_at) return "";
|
||||
return formatDistanceToNow(new Date(session.updated_at), { addSuffix: true });
|
||||
}
|
||||
|
||||
export function mergeCurrentSessionIntoList(
|
||||
accumulatedSessions: SessionSummaryResponse[],
|
||||
currentSessionId: string | null,
|
||||
currentSessionData: SessionDetailResponse | null | undefined,
|
||||
): SessionSummaryResponse[] {
|
||||
const filteredSessions: SessionSummaryResponse[] = [];
|
||||
|
||||
if (accumulatedSessions.length > 0) {
|
||||
const visibleSessions = filterVisibleSessions(accumulatedSessions);
|
||||
|
||||
if (currentSessionId) {
|
||||
const currentInAll = accumulatedSessions.find(
|
||||
(s) => s.id === currentSessionId,
|
||||
);
|
||||
if (currentInAll) {
|
||||
const isInVisible = visibleSessions.some(
|
||||
(s) => s.id === currentSessionId,
|
||||
);
|
||||
if (!isInVisible) {
|
||||
filteredSessions.push(currentInAll);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filteredSessions.push(...visibleSessions);
|
||||
}
|
||||
|
||||
if (currentSessionId && currentSessionData) {
|
||||
const isCurrentInList = filteredSessions.some(
|
||||
(s) => s.id === currentSessionId,
|
||||
);
|
||||
if (!isCurrentInList) {
|
||||
const summarySession = convertSessionDetailToSummary(currentSessionData);
|
||||
filteredSessions.unshift(summarySession);
|
||||
}
|
||||
}
|
||||
|
||||
return filteredSessions;
|
||||
}
|
||||
|
||||
export function getCurrentSessionId(
|
||||
searchParams: URLSearchParams,
|
||||
): string | null {
|
||||
return searchParams.get("sessionId");
|
||||
}
|
||||
|
||||
export function shouldAutoSelectSession(
|
||||
areAllSessionsLoaded: boolean,
|
||||
hasAutoSelectedSession: boolean,
|
||||
paramSessionId: string | null,
|
||||
visibleSessions: SessionSummaryResponse[],
|
||||
accumulatedSessions: SessionSummaryResponse[],
|
||||
isLoading: boolean,
|
||||
totalCount: number | null,
|
||||
): {
|
||||
shouldSelect: boolean;
|
||||
sessionIdToSelect: string | null;
|
||||
shouldCreate: boolean;
|
||||
} {
|
||||
if (!areAllSessionsLoaded || hasAutoSelectedSession) {
|
||||
return {
|
||||
shouldSelect: false,
|
||||
sessionIdToSelect: null,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (paramSessionId) {
|
||||
return {
|
||||
shouldSelect: false,
|
||||
sessionIdToSelect: null,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (visibleSessions.length > 0) {
|
||||
return {
|
||||
shouldSelect: true,
|
||||
sessionIdToSelect: visibleSessions[0].id,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
|
||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
|
||||
}
|
||||
|
||||
if (totalCount === 0) {
|
||||
return {
|
||||
shouldSelect: false,
|
||||
sessionIdToSelect: null,
|
||||
shouldCreate: false,
|
||||
};
|
||||
}
|
||||
|
||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
|
||||
}
|
||||
|
||||
export function checkReadyToShowContent(
|
||||
areAllSessionsLoaded: boolean,
|
||||
paramSessionId: string | null,
|
||||
accumulatedSessions: SessionSummaryResponse[],
|
||||
isCurrentSessionLoading: boolean,
|
||||
currentSessionData: SessionDetailResponse | null | undefined,
|
||||
hasAutoSelectedSession: boolean,
|
||||
): boolean {
|
||||
if (!areAllSessionsLoaded) return false;
|
||||
|
||||
if (paramSessionId) {
|
||||
const sessionFound = accumulatedSessions.some(
|
||||
(s) => s.id === paramSessionId,
|
||||
);
|
||||
return (
|
||||
sessionFound ||
|
||||
(!isCurrentSessionLoading &&
|
||||
currentSessionData !== undefined &&
|
||||
currentSessionData !== null)
|
||||
);
|
||||
}
|
||||
|
||||
return hasAutoSelectedSession;
|
||||
}
|
||||
@@ -1,170 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
getGetV2ListSessionsQueryKey,
|
||||
useGetV2GetSession,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
||||
import {
|
||||
checkReadyToShowContent,
|
||||
filterVisibleSessions,
|
||||
getCurrentSessionId,
|
||||
mergeCurrentSessionIntoList,
|
||||
} from "./helpers";
|
||||
|
||||
export function useCopilotShell() {
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
const searchParams = useSearchParams();
|
||||
const queryClient = useQueryClient();
|
||||
const breakpoint = useBreakpoint();
|
||||
const { isLoggedIn } = useSupabase();
|
||||
const isMobile =
|
||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||
|
||||
const isOnHomepage = pathname === "/copilot";
|
||||
const paramSessionId = searchParams.get("sessionId");
|
||||
|
||||
const {
|
||||
isDrawerOpen,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
} = useMobileDrawer();
|
||||
|
||||
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
|
||||
|
||||
const {
|
||||
sessions: accumulatedSessions,
|
||||
isLoading: isSessionsLoading,
|
||||
isFetching: isSessionsFetching,
|
||||
hasNextPage,
|
||||
areAllSessionsLoaded,
|
||||
fetchNextPage,
|
||||
reset: resetPagination,
|
||||
} = useSessionsPagination({
|
||||
enabled: paginationEnabled,
|
||||
});
|
||||
|
||||
const currentSessionId = getCurrentSessionId(searchParams);
|
||||
|
||||
const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
|
||||
useGetV2GetSession(currentSessionId || "", {
|
||||
query: {
|
||||
enabled: !!currentSessionId,
|
||||
select: okData,
|
||||
},
|
||||
});
|
||||
|
||||
const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
|
||||
const hasAutoSelectedRef = useRef(false);
|
||||
|
||||
// Mark as auto-selected when sessionId is in URL
|
||||
useEffect(() => {
|
||||
if (paramSessionId && !hasAutoSelectedRef.current) {
|
||||
hasAutoSelectedRef.current = true;
|
||||
setHasAutoSelectedSession(true);
|
||||
}
|
||||
}, [paramSessionId]);
|
||||
|
||||
// On homepage without sessionId, mark as ready immediately
|
||||
useEffect(() => {
|
||||
if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
|
||||
hasAutoSelectedRef.current = true;
|
||||
setHasAutoSelectedSession(true);
|
||||
}
|
||||
}, [isOnHomepage, paramSessionId]);
|
||||
|
||||
// Invalidate sessions list when navigating to homepage (to show newly created sessions)
|
||||
useEffect(() => {
|
||||
if (isOnHomepage && !paramSessionId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
}
|
||||
}, [isOnHomepage, paramSessionId, queryClient]);
|
||||
|
||||
// Reset pagination when query becomes disabled
|
||||
const prevPaginationEnabledRef = useRef(paginationEnabled);
|
||||
useEffect(() => {
|
||||
if (prevPaginationEnabledRef.current && !paginationEnabled) {
|
||||
resetPagination();
|
||||
resetAutoSelect();
|
||||
}
|
||||
prevPaginationEnabledRef.current = paginationEnabled;
|
||||
}, [paginationEnabled, resetPagination]);
|
||||
|
||||
const sessions = mergeCurrentSessionIntoList(
|
||||
accumulatedSessions,
|
||||
currentSessionId,
|
||||
currentSessionData,
|
||||
);
|
||||
|
||||
const visibleSessions = filterVisibleSessions(sessions);
|
||||
|
||||
const sidebarSelectedSessionId =
|
||||
isOnHomepage && !paramSessionId ? null : currentSessionId;
|
||||
|
||||
const isReadyToShowContent = isOnHomepage
|
||||
? true
|
||||
: checkReadyToShowContent(
|
||||
areAllSessionsLoaded,
|
||||
paramSessionId,
|
||||
accumulatedSessions,
|
||||
isCurrentSessionLoading,
|
||||
currentSessionData,
|
||||
hasAutoSelectedSession,
|
||||
);
|
||||
|
||||
function handleSelectSession(sessionId: string) {
|
||||
// Navigate using replaceState to avoid full page reload
|
||||
window.history.replaceState(null, "", `/copilot?sessionId=${sessionId}`);
|
||||
// Force a re-render by updating the URL through router
|
||||
router.replace(`/copilot?sessionId=${sessionId}`);
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleNewChat() {
|
||||
resetAutoSelect();
|
||||
resetPagination();
|
||||
// Invalidate and refetch sessions list to ensure newly created sessions appear
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
});
|
||||
window.history.replaceState(null, "", "/copilot");
|
||||
router.replace("/copilot");
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function resetAutoSelect() {
|
||||
hasAutoSelectedRef.current = false;
|
||||
setHasAutoSelectedSession(false);
|
||||
}
|
||||
|
||||
return {
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
isLoggedIn,
|
||||
hasActiveSession:
|
||||
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
|
||||
isLoading: isSessionsLoading || !areAllSessionsLoaded,
|
||||
sessions: visibleSessions,
|
||||
currentSessionId: sidebarSelectedSessionId,
|
||||
handleSelectSession,
|
||||
handleOpenDrawer,
|
||||
handleCloseDrawer,
|
||||
handleDrawerOpenChange,
|
||||
handleNewChat,
|
||||
hasNextPage,
|
||||
isFetchingNextPage: isSessionsFetching,
|
||||
fetchNextPage,
|
||||
isReadyToShowContent,
|
||||
};
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
import type { User } from "@supabase/supabase-js";
|
||||
|
||||
export function getGreetingName(user?: User | null): string {
|
||||
if (!user) return "there";
|
||||
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
||||
const fullName = metadata?.full_name;
|
||||
const name = metadata?.name;
|
||||
if (typeof fullName === "string" && fullName.trim()) {
|
||||
return fullName.split(" ")[0];
|
||||
}
|
||||
if (typeof name === "string" && name.trim()) {
|
||||
return name.split(" ")[0];
|
||||
}
|
||||
if (user.email) {
|
||||
return user.email.split("@")[0];
|
||||
}
|
||||
return "there";
|
||||
}
|
||||
|
||||
export function buildCopilotChatUrl(prompt: string): string {
|
||||
const trimmed = prompt.trim();
|
||||
if (!trimmed) return "/copilot/chat";
|
||||
const encoded = encodeURIComponent(trimmed);
|
||||
return `/copilot/chat?prompt=${encoded}`;
|
||||
}
|
||||
|
||||
export function getQuickActions(): string[] {
|
||||
return [
|
||||
"Show me what I can automate",
|
||||
"Design a custom workflow",
|
||||
"Help me with content creation",
|
||||
];
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
import type { ReactNode } from "react";
|
||||
import { CopilotShell } from "./components/CopilotShell/CopilotShell";
|
||||
|
||||
export default function CopilotLayout({ children }: { children: ReactNode }) {
|
||||
return <CopilotShell>{children}</CopilotShell>;
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { postV2CreateSession } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import {
|
||||
Flag,
|
||||
type FlagValues,
|
||||
useGetFlag,
|
||||
} from "@/services/feature-flags/use-get-flag";
|
||||
import { useFlags } from "launchdarkly-react-client-sdk";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import { getGreetingName, getQuickActions } from "./helpers";
|
||||
|
||||
type PageState =
|
||||
| { type: "welcome" }
|
||||
| { type: "creating"; prompt: string }
|
||||
| { type: "chat"; sessionId: string; initialPrompt?: string };
|
||||
|
||||
export default function CopilotPage() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
||||
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const flags = useFlags<FlagValues>();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
||||
const isFlagReady =
|
||||
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
||||
|
||||
const [pageState, setPageState] = useState<PageState>({ type: "welcome" });
|
||||
const initialPromptRef = useRef<Map<string, string>>(new Map());
|
||||
|
||||
const urlSessionId = searchParams.get("sessionId");
|
||||
|
||||
// Sync with URL sessionId (preserve initialPrompt from ref)
|
||||
useEffect(
|
||||
function syncSessionFromUrl() {
|
||||
if (urlSessionId) {
|
||||
// If we're already in chat state with this sessionId, don't overwrite
|
||||
if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
|
||||
return;
|
||||
}
|
||||
// Get initialPrompt from ref or current state
|
||||
const storedInitialPrompt = initialPromptRef.current.get(urlSessionId);
|
||||
const currentInitialPrompt =
|
||||
storedInitialPrompt ||
|
||||
(pageState.type === "creating"
|
||||
? pageState.prompt
|
||||
: pageState.type === "chat"
|
||||
? pageState.initialPrompt
|
||||
: undefined);
|
||||
if (currentInitialPrompt) {
|
||||
initialPromptRef.current.set(urlSessionId, currentInitialPrompt);
|
||||
}
|
||||
setPageState({
|
||||
type: "chat",
|
||||
sessionId: urlSessionId,
|
||||
initialPrompt: currentInitialPrompt,
|
||||
});
|
||||
} else if (pageState.type === "chat") {
|
||||
setPageState({ type: "welcome" });
|
||||
}
|
||||
},
|
||||
[urlSessionId],
|
||||
);
|
||||
|
||||
useEffect(
|
||||
function ensureAccess() {
|
||||
if (!isFlagReady) return;
|
||||
if (isChatEnabled === false) {
|
||||
router.replace(homepageRoute);
|
||||
}
|
||||
},
|
||||
[homepageRoute, isChatEnabled, isFlagReady, router],
|
||||
);
|
||||
|
||||
const greetingName = useMemo(
|
||||
function getName() {
|
||||
return getGreetingName(user);
|
||||
},
|
||||
[user],
|
||||
);
|
||||
|
||||
const quickActions = useMemo(function getActions() {
|
||||
return getQuickActions();
|
||||
}, []);
|
||||
|
||||
async function startChatWithPrompt(prompt: string) {
|
||||
if (!prompt?.trim()) return;
|
||||
if (pageState.type === "creating") return;
|
||||
|
||||
const trimmedPrompt = prompt.trim();
|
||||
setPageState({ type: "creating", prompt: trimmedPrompt });
|
||||
|
||||
try {
|
||||
// Create session
|
||||
const sessionResponse = await postV2CreateSession({
|
||||
body: JSON.stringify({}),
|
||||
});
|
||||
|
||||
if (sessionResponse.status !== 200 || !sessionResponse.data?.id) {
|
||||
throw new Error("Failed to create session");
|
||||
}
|
||||
|
||||
const sessionId = sessionResponse.data.id;
|
||||
|
||||
// Store initialPrompt in ref so it persists across re-renders
|
||||
initialPromptRef.current.set(sessionId, trimmedPrompt);
|
||||
|
||||
// Update URL and show Chat with initial prompt
|
||||
// Chat will handle sending the message and streaming
|
||||
window.history.replaceState(null, "", `/copilot?sessionId=${sessionId}`);
|
||||
setPageState({ type: "chat", sessionId, initialPrompt: trimmedPrompt });
|
||||
} catch (error) {
|
||||
console.error("[CopilotPage] Failed to start chat:", error);
|
||||
setPageState({ type: "welcome" });
|
||||
}
|
||||
}
|
||||
|
||||
function handleQuickAction(action: string) {
|
||||
startChatWithPrompt(action);
|
||||
}
|
||||
|
||||
function handleSessionNotFound() {
|
||||
router.replace("/copilot");
|
||||
}
|
||||
|
||||
if (!isFlagReady || isChatEnabled === false || !isLoggedIn) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Show Chat when we have an active session
|
||||
if (pageState.type === "chat") {
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
<Chat
|
||||
key={pageState.sessionId ?? "welcome"}
|
||||
className="flex-1"
|
||||
urlSessionId={pageState.sessionId}
|
||||
initialPrompt={pageState.initialPrompt}
|
||||
onSessionNotFound={handleSessionNotFound}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show loading state while creating session and sending first message
|
||||
if (pageState.type === "creating") {
|
||||
return (
|
||||
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9] px-6 py-10">
|
||||
<LoadingSpinner size="large" />
|
||||
<Text variant="body" className="mt-4 text-zinc-500">
|
||||
Starting your chat...
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show Welcome screen
|
||||
const isLoading = isUserLoading;
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
||||
<div className="w-full text-center">
|
||||
{isLoading ? (
|
||||
<div className="mx-auto max-w-2xl">
|
||||
<Skeleton className="mx-auto mb-3 h-8 w-64" />
|
||||
<Skeleton className="mx-auto mb-8 h-6 w-80" />
|
||||
<div className="mb-8">
|
||||
<Skeleton className="mx-auto h-14 w-full rounded-lg" />
|
||||
</div>
|
||||
<div className="flex flex-wrap items-center justify-center gap-3">
|
||||
{Array.from({ length: 4 }).map((_, i) => (
|
||||
<Skeleton key={i} className="h-9 w-48 rounded-md" />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="mx-auto max-w-2xl">
|
||||
<Text
|
||||
variant="h3"
|
||||
className="mb-3 !text-[1.375rem] text-zinc-700"
|
||||
>
|
||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||
</Text>
|
||||
<Text variant="h3" className="mb-8 !font-normal">
|
||||
What do you want to automate?
|
||||
</Text>
|
||||
|
||||
<div className="mb-6">
|
||||
<ChatInput
|
||||
onSend={startChatWithPrompt}
|
||||
placeholder='You can search or just ask - e.g. "create a blog post outline"'
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{quickActions.map((action) => (
|
||||
<Button
|
||||
key={action}
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => handleQuickAction(action)}
|
||||
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
|
||||
>
|
||||
{action}
|
||||
</Button>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,8 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { Suspense } from "react";
|
||||
import { getErrorDetails } from "./helpers";
|
||||
@@ -11,8 +9,6 @@ function ErrorPageContent() {
|
||||
const searchParams = useSearchParams();
|
||||
const errorMessage = searchParams.get("message");
|
||||
const errorDetails = getErrorDetails(errorMessage);
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
function handleRetry() {
|
||||
// Auth-related errors should redirect to login
|
||||
@@ -29,8 +25,8 @@ function ErrorPageContent() {
|
||||
window.location.reload();
|
||||
}, 2000);
|
||||
} else {
|
||||
// For server/network errors, go to home
|
||||
window.location.href = homepageRoute;
|
||||
// For server/network errors, go to marketplace
|
||||
window.location.href = "/marketplace";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,10 @@ import {
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { ScheduleAgentModal } from "../ScheduleAgentModal/ScheduleAgentModal";
|
||||
import {
|
||||
AIAgentSafetyPopup,
|
||||
useAIAgentSafetyPopup,
|
||||
} from "./components/AIAgentSafetyPopup/AIAgentSafetyPopup";
|
||||
import { ModalHeader } from "./components/ModalHeader/ModalHeader";
|
||||
import { ModalRunSection } from "./components/ModalRunSection/ModalRunSection";
|
||||
import { RunActions } from "./components/RunActions/RunActions";
|
||||
@@ -83,8 +87,18 @@ export function RunAgentModal({
|
||||
|
||||
const [isScheduleModalOpen, setIsScheduleModalOpen] = useState(false);
|
||||
const [hasOverflow, setHasOverflow] = useState(false);
|
||||
const [isSafetyPopupOpen, setIsSafetyPopupOpen] = useState(false);
|
||||
const [pendingRunAction, setPendingRunAction] = useState<(() => void) | null>(
|
||||
null,
|
||||
);
|
||||
const contentRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const { shouldShowPopup, dismissPopup } = useAIAgentSafetyPopup(
|
||||
agent.id,
|
||||
agent.has_sensitive_action,
|
||||
agent.has_human_in_the_loop,
|
||||
);
|
||||
|
||||
const hasAnySetupFields =
|
||||
Object.keys(agentInputFields || {}).length > 0 ||
|
||||
Object.keys(agentCredentialsInputFields || {}).length > 0;
|
||||
@@ -165,6 +179,24 @@ export function RunAgentModal({
|
||||
onScheduleCreated?.(schedule);
|
||||
}
|
||||
|
||||
function handleRunWithSafetyCheck() {
|
||||
if (shouldShowPopup) {
|
||||
setPendingRunAction(() => handleRun);
|
||||
setIsSafetyPopupOpen(true);
|
||||
} else {
|
||||
handleRun();
|
||||
}
|
||||
}
|
||||
|
||||
function handleSafetyPopupAcknowledge() {
|
||||
setIsSafetyPopupOpen(false);
|
||||
dismissPopup();
|
||||
if (pendingRunAction) {
|
||||
pendingRunAction();
|
||||
setPendingRunAction(null);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Dialog
|
||||
@@ -180,7 +212,7 @@ export function RunAgentModal({
|
||||
|
||||
{/* Content */}
|
||||
{hasAnySetupFields ? (
|
||||
<div className="mt-4 pb-10">
|
||||
<div className="mt-10 pb-32">
|
||||
<RunAgentModalContextProvider
|
||||
value={{
|
||||
agent,
|
||||
@@ -248,7 +280,7 @@ export function RunAgentModal({
|
||||
)}
|
||||
<RunActions
|
||||
defaultRunType={defaultRunType}
|
||||
onRun={handleRun}
|
||||
onRun={handleRunWithSafetyCheck}
|
||||
isExecuting={isExecuting}
|
||||
isSettingUpTrigger={isSettingUpTrigger}
|
||||
isRunReady={allRequiredInputsAreSet}
|
||||
@@ -266,6 +298,12 @@ export function RunAgentModal({
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
|
||||
<AIAgentSafetyPopup
|
||||
agentId={agent.id}
|
||||
isOpen={isSafetyPopupOpen}
|
||||
onAcknowledge={handleSafetyPopupAcknowledge}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { ShieldCheckIcon } from "@phosphor-icons/react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
interface Props {
|
||||
agentId: string;
|
||||
onAcknowledge: () => void;
|
||||
isOpen: boolean;
|
||||
}
|
||||
|
||||
export function AIAgentSafetyPopup({ agentId, onAcknowledge, isOpen }: Props) {
|
||||
function handleAcknowledge() {
|
||||
// Add this agent to the list of agents for which popup has been shown
|
||||
const seenAgentsJson = storage.get(Key.AI_AGENT_SAFETY_POPUP_SHOWN);
|
||||
const seenAgents: string[] = seenAgentsJson
|
||||
? JSON.parse(seenAgentsJson)
|
||||
: [];
|
||||
|
||||
if (!seenAgents.includes(agentId)) {
|
||||
seenAgents.push(agentId);
|
||||
storage.set(Key.AI_AGENT_SAFETY_POPUP_SHOWN, JSON.stringify(seenAgents));
|
||||
}
|
||||
|
||||
onAcknowledge();
|
||||
}
|
||||
|
||||
if (!isOpen) return null;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
controlled={{ isOpen, set: () => {} }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col items-center p-6 text-center">
|
||||
<div className="mb-6 flex h-16 w-16 items-center justify-center rounded-full bg-blue-50">
|
||||
<ShieldCheckIcon
|
||||
weight="fill"
|
||||
size={32}
|
||||
className="text-blue-600"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Text variant="h3" className="mb-4">
|
||||
Safety Checks Enabled
|
||||
</Text>
|
||||
|
||||
<Text variant="body" className="mb-2 text-zinc-700">
|
||||
AI-generated agents may take actions that affect your data or
|
||||
external systems.
|
||||
</Text>
|
||||
|
||||
<Text variant="body" className="mb-8 text-zinc-700">
|
||||
AutoGPT includes safety checks so you'll always have the
|
||||
opportunity to review and approve sensitive actions before they
|
||||
happen.
|
||||
</Text>
|
||||
|
||||
<Button
|
||||
variant="primary"
|
||||
size="large"
|
||||
className="w-full"
|
||||
onClick={handleAcknowledge}
|
||||
>
|
||||
Got it
|
||||
</Button>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
export function useAIAgentSafetyPopup(
|
||||
agentId: string,
|
||||
hasSensitiveAction: boolean,
|
||||
hasHumanInTheLoop: boolean,
|
||||
) {
|
||||
const [shouldShowPopup, setShouldShowPopup] = useState(false);
|
||||
const [hasChecked, setHasChecked] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (hasChecked) return;
|
||||
|
||||
const seenAgentsJson = storage.get(Key.AI_AGENT_SAFETY_POPUP_SHOWN);
|
||||
const seenAgents: string[] = seenAgentsJson
|
||||
? JSON.parse(seenAgentsJson)
|
||||
: [];
|
||||
const hasSeenPopupForThisAgent = seenAgents.includes(agentId);
|
||||
const isRelevantAgent = hasSensitiveAction || hasHumanInTheLoop;
|
||||
|
||||
setShouldShowPopup(!hasSeenPopupForThisAgent && isRelevantAgent);
|
||||
setHasChecked(true);
|
||||
}, [agentId, hasSensitiveAction, hasHumanInTheLoop, hasChecked]);
|
||||
|
||||
const dismissPopup = useCallback(() => {
|
||||
setShouldShowPopup(false);
|
||||
}, []);
|
||||
|
||||
return {
|
||||
shouldShowPopup,
|
||||
dismissPopup,
|
||||
};
|
||||
}
|
||||
@@ -29,7 +29,7 @@ export function ModalHeader({ agent }: ModalHeaderProps) {
|
||||
<ShowMoreText
|
||||
previewLimit={400}
|
||||
variant="small"
|
||||
className="mb-2 mt-4 !text-zinc-700"
|
||||
className="mt-4 !text-zinc-700"
|
||||
>
|
||||
{agent.description}
|
||||
</ShowMoreText>
|
||||
@@ -40,8 +40,6 @@ export function ModalHeader({ agent }: ModalHeaderProps) {
|
||||
<Text variant="lead-semibold" className="text-blue-600">
|
||||
Tip
|
||||
</Text>
|
||||
<div className="h-px w-full bg-blue-100" />
|
||||
|
||||
<Text variant="body">
|
||||
For best results, run this agent{" "}
|
||||
{humanizeCronExpression(
|
||||
@@ -52,7 +50,7 @@ export function ModalHeader({ agent }: ModalHeaderProps) {
|
||||
) : null}
|
||||
|
||||
{agent.instructions ? (
|
||||
<div className="mt-4 flex flex-col gap-4 rounded-medium border border-purple-100 bg-[#f1ebfe80] p-4">
|
||||
<div className="flex flex-col gap-4 rounded-medium border border-purple-100 bg-[#F1EBFE/5] p-4">
|
||||
<Text variant="lead-semibold" className="text-purple-600">
|
||||
Instructions
|
||||
</Text>
|
||||
|
||||
@@ -69,7 +69,6 @@ export function SafeModeToggle({ graph, className }: Props) {
|
||||
const {
|
||||
currentHITLSafeMode,
|
||||
showHITLToggle,
|
||||
isHITLStateUndetermined,
|
||||
handleHITLToggle,
|
||||
currentSensitiveActionSafeMode,
|
||||
showSensitiveActionToggle,
|
||||
@@ -78,20 +77,13 @@ export function SafeModeToggle({ graph, className }: Props) {
|
||||
shouldShowToggle,
|
||||
} = useAgentSafeMode(graph);
|
||||
|
||||
if (!shouldShowToggle || isHITLStateUndetermined) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const showHITL = showHITLToggle && !isHITLStateUndetermined;
|
||||
const showSensitive = showSensitiveActionToggle;
|
||||
|
||||
if (!showHITL && !showSensitive) {
|
||||
if (!shouldShowToggle) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("flex gap-1", className)}>
|
||||
{showHITL && (
|
||||
{showHITLToggle && (
|
||||
<SafeModeIconButton
|
||||
isEnabled={currentHITLSafeMode}
|
||||
label="Human-in-the-loop"
|
||||
@@ -101,7 +93,7 @@ export function SafeModeToggle({ graph, className }: Props) {
|
||||
isPending={isPending}
|
||||
/>
|
||||
)}
|
||||
{showSensitive && (
|
||||
{showSensitiveActionToggle && (
|
||||
<SafeModeIconButton
|
||||
isEnabled={currentSensitiveActionSafeMode}
|
||||
label="Sensitive actions"
|
||||
|
||||
@@ -8,8 +8,6 @@ import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/
|
||||
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||
import { okData } from "@/app/api/helpers";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { isLogoutInProgress } from "@/lib/autogpt-server-api/helpers";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { updateFavoriteInQueries } from "./helpers";
|
||||
|
||||
interface Props {
|
||||
@@ -25,14 +23,10 @@ export function useLibraryAgentCard({ agent }: Props) {
|
||||
const { toast } = useToast();
|
||||
const queryClient = getQueryClient();
|
||||
const { mutateAsync: updateLibraryAgent } = usePatchV2UpdateLibraryAgent();
|
||||
const { user, isLoggedIn } = useSupabase();
|
||||
const logoutInProgress = isLogoutInProgress();
|
||||
|
||||
const { data: profile } = useGetV2GetUserProfile({
|
||||
query: {
|
||||
select: okData,
|
||||
enabled: isLoggedIn && !!user && !logoutInProgress,
|
||||
queryKey: ["/api/store/profile", user?.id],
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { environment } from "@/services/environment";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { loginFormSchema, LoginProvider } from "@/types/auth";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
@@ -22,17 +20,15 @@ export function useLoginPage() {
|
||||
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||
const isCloudEnv = environment.isCloud();
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
// Get redirect destination from 'next' query parameter
|
||||
const nextUrl = searchParams.get("next");
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoggedIn && !isLoggingIn) {
|
||||
router.push(nextUrl || homepageRoute);
|
||||
router.push(nextUrl || "/marketplace");
|
||||
}
|
||||
}, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]);
|
||||
}, [isLoggedIn, isLoggingIn, nextUrl, router]);
|
||||
|
||||
const form = useForm<z.infer<typeof loginFormSchema>>({
|
||||
resolver: zodResolver(loginFormSchema),
|
||||
@@ -102,7 +98,7 @@ export function useLoginPage() {
|
||||
} else if (result.onboarding) {
|
||||
router.replace("/onboarding");
|
||||
} else {
|
||||
router.replace(homepageRoute);
|
||||
router.replace("/marketplace");
|
||||
}
|
||||
} catch (error) {
|
||||
toast({
|
||||
|
||||
@@ -80,7 +80,6 @@ export const AgentInfo = ({
|
||||
const allVersions = storeData?.versions
|
||||
? storeData.versions
|
||||
.map((versionStr: string) => parseInt(versionStr, 10))
|
||||
.filter((versionNum: number) => !isNaN(versionNum))
|
||||
.sort((a: number, b: number) => b - a)
|
||||
.map((versionNum: number) => ({
|
||||
version: versionNum,
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
import { describe, expect, test, afterEach } from "vitest";
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
act,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import { MainAgentPage } from "../MainAgentPage";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import { getGetV2GetSpecificAgentMockHandler422 } from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { create500Handler } from "@/tests/integrations/helpers/create-500-handler";
|
||||
import {
|
||||
mockAuthenticatedUser,
|
||||
mockUnauthenticatedUser,
|
||||
resetAuthState,
|
||||
} from "@/tests/integrations/helpers/mock-supabase-auth";
|
||||
|
||||
const defaultParams = {
|
||||
creator: "test-creator",
|
||||
slug: "test-agent",
|
||||
};
|
||||
|
||||
describe("MainAgentPage", () => {
|
||||
afterEach(() => {
|
||||
resetAuthState();
|
||||
});
|
||||
|
||||
describe("rendering", () => {
|
||||
test("renders agent info with title", async () => {
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("agent-title")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders agent creator info", async () => {
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("agent-creator")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders agent description", async () => {
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("agent-description")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders breadcrumbs with marketplace link", async () => {
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByRole("link", { name: /marketplace/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders download button", async () => {
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("agent-download-button")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders similar agents section", async () => {
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Similar agents", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("auth state", () => {
|
||||
test("shows add to library button when authenticated", async () => {
|
||||
mockAuthenticatedUser();
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("agent-add-library-button"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("hides add to library button when not authenticated", async () => {
|
||||
mockUnauthenticatedUser();
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("agent-title")).toBeInTheDocument();
|
||||
});
|
||||
expect(
|
||||
screen.queryByTestId("agent-add-library-button"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders page correctly when logged out", async () => {
|
||||
mockUnauthenticatedUser();
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("agent-title")).toBeInTheDocument();
|
||||
});
|
||||
expect(screen.getByTestId("agent-download-button")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders page correctly when logged in", async () => {
|
||||
mockAuthenticatedUser();
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("agent-title")).toBeInTheDocument();
|
||||
});
|
||||
expect(
|
||||
screen.getByTestId("agent-add-library-button"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("error handling", () => {
|
||||
test("displays error when agent API returns 422", async () => {
|
||||
server.use(getGetV2GetSpecificAgentMockHandler422());
|
||||
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load agent data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
await act(async () => {});
|
||||
});
|
||||
|
||||
test("displays error when API returns 500", async () => {
|
||||
server.use(
|
||||
create500Handler("get", "*/api/store/agents/test-creator/test-agent"),
|
||||
);
|
||||
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load agent data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
await act(async () => {});
|
||||
});
|
||||
|
||||
test("retry button is visible on error", async () => {
|
||||
server.use(getGetV2GetSpecificAgentMockHandler422());
|
||||
|
||||
render(<MainAgentPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByRole("button", { name: /try again/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
await act(async () => {});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,131 +0,0 @@
|
||||
import { describe, expect, test, afterEach } from "vitest";
|
||||
import { cleanup, render, screen, waitFor } from "@/tests/integrations/test-utils";
|
||||
import { MainCreatorPage } from "../MainCreatorPage";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import {
|
||||
getGetV2GetCreatorDetailsMockHandler422,
|
||||
getGetV2ListStoreAgentsMockHandler422,
|
||||
} from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { create500Handler } from "@/tests/integrations/helpers/create-500-handler";
|
||||
import {
|
||||
mockAuthenticatedUser,
|
||||
mockUnauthenticatedUser,
|
||||
resetAuthState,
|
||||
} from "@/tests/integrations/helpers/mock-supabase-auth";
|
||||
|
||||
const defaultParams = {
|
||||
creator: "test-creator",
|
||||
};
|
||||
|
||||
describe("MainCreatorPage", () => {
|
||||
afterEach(() => {
|
||||
resetAuthState();
|
||||
});
|
||||
|
||||
describe("rendering", () => {
|
||||
test("renders creator info card", async () => {
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("creator-description")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders breadcrumbs with marketplace link", async () => {
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByRole("link", { name: /marketplace/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders about section", async () => {
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("About")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders agents by creator section", async () => {
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/Agents by/i, { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("auth state", () => {
|
||||
test("renders page correctly when logged out", async () => {
|
||||
mockUnauthenticatedUser();
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("creator-description")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders page correctly when logged in", async () => {
|
||||
mockAuthenticatedUser();
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("creator-description")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("error handling", () => {
|
||||
test("displays error when creator details API returns 422", async () => {
|
||||
server.use(getGetV2GetCreatorDetailsMockHandler422());
|
||||
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load creator data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("displays error when creator agents API returns 422", async () => {
|
||||
server.use(getGetV2ListStoreAgentsMockHandler422());
|
||||
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load creator data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("displays error when API returns 500", async () => {
|
||||
server.use(create500Handler("get", "*/api/store/creator/test-creator"));
|
||||
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load creator data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("retry button is visible on error", async () => {
|
||||
server.use(getGetV2GetCreatorDetailsMockHandler422());
|
||||
|
||||
render(<MainCreatorPage params={defaultParams} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByRole("button", { name: /try again/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,132 +0,0 @@
|
||||
import { describe, expect, test, afterEach } from "vitest";
|
||||
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
|
||||
import { MainMarkeplacePage } from "../MainMarketplacePage";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import {
|
||||
getGetV2ListStoreAgentsMockHandler422,
|
||||
getGetV2ListStoreCreatorsMockHandler422,
|
||||
} from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { create500Handler } from "@/tests/integrations/helpers/create-500-handler";
|
||||
import {
|
||||
mockAuthenticatedUser,
|
||||
mockUnauthenticatedUser,
|
||||
resetAuthState,
|
||||
} from "@/tests/integrations/helpers/mock-supabase-auth";
|
||||
|
||||
describe("MainMarketplacePage", () => {
|
||||
afterEach(() => {
|
||||
resetAuthState();
|
||||
});
|
||||
|
||||
describe("rendering", () => {
|
||||
test("renders hero section with search bar", async () => {
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
expect(
|
||||
await screen.findByText("Featured agents", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByPlaceholderText(/search/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders featured agents section", async () => {
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
expect(
|
||||
await screen.findByText("Featured agents", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders top agents section", async () => {
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
expect(
|
||||
await screen.findByText("Top Agents", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders featured creators section", async () => {
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
expect(
|
||||
await screen.findByText("Featured creators", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("auth state", () => {
|
||||
test("renders page correctly when logged out", async () => {
|
||||
mockUnauthenticatedUser();
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
expect(
|
||||
await screen.findByText("Featured agents", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Top Agents", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders page correctly when logged in", async () => {
|
||||
mockAuthenticatedUser();
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
expect(
|
||||
await screen.findByText("Featured agents", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Top Agents", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("error handling", () => {
|
||||
test("displays error when featured agents API returns 422", async () => {
|
||||
server.use(getGetV2ListStoreAgentsMockHandler422());
|
||||
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load marketplace data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("displays error when creators API returns 422", async () => {
|
||||
server.use(getGetV2ListStoreCreatorsMockHandler422());
|
||||
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load marketplace data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("displays error when API returns 500", async () => {
|
||||
server.use(create500Handler("get", "*/api/store/agents*"));
|
||||
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load marketplace data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("retry button is visible on error", async () => {
|
||||
server.use(getGetV2ListStoreAgentsMockHandler422());
|
||||
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByRole("button", { name: /try again/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,114 +0,0 @@
|
||||
import { describe, expect, test, afterEach } from "vitest";
|
||||
import { render, screen, waitFor } from "@/tests/integrations/test-utils";
|
||||
import { MainSearchResultPage } from "../MainSearchResultPage";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import {
|
||||
getGetV2ListStoreAgentsMockHandler422,
|
||||
getGetV2ListStoreCreatorsMockHandler422,
|
||||
} from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { create500Handler } from "@/tests/integrations/helpers/create-500-handler";
|
||||
import {
|
||||
mockAuthenticatedUser,
|
||||
mockUnauthenticatedUser,
|
||||
resetAuthState,
|
||||
} from "@/tests/integrations/helpers/mock-supabase-auth";
|
||||
|
||||
const defaultProps = {
|
||||
searchTerm: "test-search",
|
||||
sort: undefined as undefined,
|
||||
};
|
||||
|
||||
describe("MainSearchResultPage", () => {
|
||||
afterEach(() => {
|
||||
resetAuthState();
|
||||
});
|
||||
|
||||
describe("rendering", () => {
|
||||
test("renders search results header with search term", async () => {
|
||||
render(<MainSearchResultPage {...defaultProps} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Results for:")).toBeInTheDocument();
|
||||
});
|
||||
expect(screen.getByText("test-search")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders search bar", async () => {
|
||||
render(<MainSearchResultPage {...defaultProps} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText(/search/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("auth state", () => {
|
||||
test("renders page correctly when logged out", async () => {
|
||||
mockUnauthenticatedUser();
|
||||
render(<MainSearchResultPage {...defaultProps} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Results for:")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders page correctly when logged in", async () => {
|
||||
mockAuthenticatedUser();
|
||||
render(<MainSearchResultPage {...defaultProps} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Results for:")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("error handling", () => {
|
||||
test("displays error when agents API returns 422", async () => {
|
||||
server.use(getGetV2ListStoreAgentsMockHandler422());
|
||||
|
||||
render(<MainSearchResultPage {...defaultProps} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load marketplace data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("displays error when creators API returns 422", async () => {
|
||||
server.use(getGetV2ListStoreCreatorsMockHandler422());
|
||||
|
||||
render(<MainSearchResultPage {...defaultProps} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load marketplace data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("displays error when API returns 500", async () => {
|
||||
server.use(create500Handler("get", "*/api/store/agents*"));
|
||||
|
||||
render(<MainSearchResultPage {...defaultProps} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Failed to load marketplace data", { exact: false }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("retry button is visible on error", async () => {
|
||||
server.use(getGetV2ListStoreAgentsMockHandler422());
|
||||
|
||||
render(<MainSearchResultPage {...defaultProps} />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByRole("button", { name: /try again/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -3,14 +3,12 @@
|
||||
import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/store";
|
||||
import { ProfileInfoForm } from "@/components/__legacy__/ProfileInfoForm";
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { isLogoutInProgress } from "@/lib/autogpt-server-api/helpers";
|
||||
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { ProfileLoading } from "./ProfileLoading";
|
||||
|
||||
export default function UserProfilePage() {
|
||||
const { user } = useSupabase();
|
||||
const logoutInProgress = isLogoutInProgress();
|
||||
|
||||
const {
|
||||
data: profile,
|
||||
@@ -20,7 +18,7 @@ export default function UserProfilePage() {
|
||||
refetch,
|
||||
} = useGetV2GetUserProfile<ProfileDetails | null>({
|
||||
query: {
|
||||
enabled: !!user && !logoutInProgress,
|
||||
enabled: !!user,
|
||||
select: (res) => {
|
||||
if (res.status === 200) {
|
||||
return {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"use server";
|
||||
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { signupFormSchema } from "@/types/auth";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
@@ -12,7 +11,6 @@ export async function signup(
|
||||
password: string,
|
||||
confirmPassword: string,
|
||||
agreeToTerms: boolean,
|
||||
isChatEnabled: boolean,
|
||||
) {
|
||||
try {
|
||||
const parsed = signupFormSchema.safeParse({
|
||||
@@ -60,9 +58,7 @@ export async function signup(
|
||||
}
|
||||
|
||||
const isOnboardingEnabled = await shouldShowOnboarding();
|
||||
const next = isOnboardingEnabled
|
||||
? "/onboarding"
|
||||
: getHomepageRoute(isChatEnabled);
|
||||
const next = isOnboardingEnabled ? "/onboarding" : "/";
|
||||
|
||||
return { success: true, next };
|
||||
} catch (err) {
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { environment } from "@/services/environment";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { LoginProvider, signupFormSchema } from "@/types/auth";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
@@ -22,17 +20,15 @@ export function useSignupPage() {
|
||||
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||
const isCloudEnv = environment.isCloud();
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
// Get redirect destination from 'next' query parameter
|
||||
const nextUrl = searchParams.get("next");
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoggedIn && !isSigningUp) {
|
||||
router.push(nextUrl || homepageRoute);
|
||||
router.push(nextUrl || "/marketplace");
|
||||
}
|
||||
}, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]);
|
||||
}, [isLoggedIn, isSigningUp, nextUrl, router]);
|
||||
|
||||
const form = useForm<z.infer<typeof signupFormSchema>>({
|
||||
resolver: zodResolver(signupFormSchema),
|
||||
@@ -108,7 +104,6 @@ export function useSignupPage() {
|
||||
data.password,
|
||||
data.confirmPassword,
|
||||
data.agreeToTerms,
|
||||
isChatEnabled === true,
|
||||
);
|
||||
|
||||
setIsLoading(false);
|
||||
@@ -134,7 +129,7 @@ export function useSignupPage() {
|
||||
}
|
||||
|
||||
// Prefer the URL's next parameter, then result.next (for onboarding), then default
|
||||
const redirectTo = nextUrl || result.next || homepageRoute;
|
||||
const redirectTo = nextUrl || result.next || "/";
|
||||
router.replace(redirectTo);
|
||||
} catch (error) {
|
||||
setIsLoading(false);
|
||||
|
||||
@@ -4,12 +4,12 @@ import {
|
||||
getServerAuthToken,
|
||||
} from "@/lib/autogpt-server-api/helpers";
|
||||
|
||||
import { transformDates } from "./date-transformer";
|
||||
import { environment } from "@/services/environment";
|
||||
import {
|
||||
IMPERSONATION_HEADER_NAME,
|
||||
IMPERSONATION_STORAGE_KEY,
|
||||
} from "@/lib/constants";
|
||||
import { environment } from "@/services/environment";
|
||||
import { transformDates } from "./date-transformer";
|
||||
|
||||
const FRONTEND_BASE_URL =
|
||||
process.env.NEXT_PUBLIC_FRONTEND_BASE_URL || "http://localhost:3000";
|
||||
@@ -136,19 +136,16 @@ export const customMutator = async <
|
||||
response.statusText ||
|
||||
`HTTP ${response.status}`;
|
||||
|
||||
const isTestEnv = process.env.NODE_ENV === 'test';
|
||||
if (!isTestEnv) {
|
||||
console.error(
|
||||
`Request failed ${environment.isServerSide() ? "on server" : "on client"}`,
|
||||
{
|
||||
status: response.status,
|
||||
method,
|
||||
url: fullUrl.replace(baseUrl, ""), // Show relative URL for cleaner logs
|
||||
errorMessage,
|
||||
responseData: responseData || "No response data",
|
||||
},
|
||||
);
|
||||
}
|
||||
console.error(
|
||||
`Request failed ${environment.isServerSide() ? "on server" : "on client"}`,
|
||||
{
|
||||
status: response.status,
|
||||
method,
|
||||
url: fullUrl.replace(baseUrl, ""), // Show relative URL for cleaner logs
|
||||
errorMessage,
|
||||
responseData: responseData || "No response data",
|
||||
},
|
||||
);
|
||||
|
||||
throw new ApiError(errorMessage, response.status, responseData);
|
||||
}
|
||||
|
||||
@@ -1022,7 +1022,7 @@
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Session",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.",
|
||||
"operationId": "getV2GetSession",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -9411,6 +9411,12 @@
|
||||
],
|
||||
"title": "Reviewed Data",
|
||||
"description": "Optional edited data (ignored if approved=False)"
|
||||
},
|
||||
"auto_approve_future": {
|
||||
"type": "boolean",
|
||||
"title": "Auto Approve Future",
|
||||
"description": "If true and this review is approved, future executions of this same block (node) will be automatically approved. This only affects approved reviews.",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -9430,7 +9436,7 @@
|
||||
"type": "object",
|
||||
"required": ["reviews"],
|
||||
"title": "ReviewRequest",
|
||||
"description": "Request model for processing ALL pending reviews for an execution.\n\nThis request must include ALL pending reviews for a graph execution.\nEach review will be either approved (with optional data modifications)\nor rejected (data ignored). The execution will resume only after ALL reviews are processed."
|
||||
"description": "Request model for processing ALL pending reviews for an execution.\n\nThis request must include ALL pending reviews for a graph execution.\nEach review will be either approved (with optional data modifications)\nor rejected (data ignored). The execution will resume only after ALL reviews are processed.\n\nEach review item can individually specify whether to auto-approve future executions\nof the same block via the `auto_approve_future` field on ReviewItem."
|
||||
},
|
||||
"ReviewResponse": {
|
||||
"properties": {
|
||||
|
||||
@@ -141,6 +141,52 @@
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes shimmer {
|
||||
0% {
|
||||
background-position: -200% 0;
|
||||
}
|
||||
100% {
|
||||
background-position: 200% 0;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes l3 {
|
||||
25% {
|
||||
background-position:
|
||||
0 0,
|
||||
100% 100%,
|
||||
100% calc(100% - 5px);
|
||||
}
|
||||
50% {
|
||||
background-position:
|
||||
0 100%,
|
||||
100% 100%,
|
||||
0 calc(100% - 5px);
|
||||
}
|
||||
75% {
|
||||
background-position:
|
||||
0 100%,
|
||||
100% 0,
|
||||
100% 5px;
|
||||
}
|
||||
}
|
||||
|
||||
.loader {
|
||||
width: 80px;
|
||||
height: 70px;
|
||||
border: 5px solid rgb(241 245 249);
|
||||
padding: 0 8px;
|
||||
box-sizing: border-box;
|
||||
background:
|
||||
linear-gradient(rgb(15 23 42) 0 0) 0 0/8px 20px,
|
||||
linear-gradient(rgb(15 23 42) 0 0) 100% 0/8px 20px,
|
||||
radial-gradient(farthest-side, rgb(15 23 42) 90%, #0000) 0 5px/8px 8px
|
||||
content-box,
|
||||
transparent;
|
||||
background-repeat: no-repeat;
|
||||
animation: l3 2s infinite linear;
|
||||
}
|
||||
|
||||
input[type="number"]::-webkit-outer-spin-button,
|
||||
input[type="number"]::-webkit-inner-spin-button {
|
||||
-webkit-appearance: none;
|
||||
|
||||
@@ -1,27 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { redirect } from "next/navigation";
|
||||
|
||||
export default function Page() {
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const router = useRouter();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
||||
const isFlagReady =
|
||||
!isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";
|
||||
|
||||
useEffect(
|
||||
function redirectToHomepage() {
|
||||
if (!isFlagReady) return;
|
||||
router.replace(homepageRoute);
|
||||
},
|
||||
[homepageRoute, isFlagReady, router],
|
||||
);
|
||||
|
||||
return null;
|
||||
redirect("/marketplace");
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
// import { render, screen } from "@testing-library/react";
|
||||
// import { describe, expect, it } from "vitest";
|
||||
// import { Badge } from "./Badge";
|
||||
|
||||
// describe("Badge Component", () => {
|
||||
// it("renders badge with content", () => {
|
||||
// render(<Badge variant="success">Success</Badge>);
|
||||
|
||||
// expect(screen.getByText("Success")).toBeInTheDocument();
|
||||
// });
|
||||
|
||||
// it("applies correct variant styles", () => {
|
||||
// const { rerender } = render(<Badge variant="success">Success</Badge>);
|
||||
// let badge = screen.getByText("Success");
|
||||
// expect(badge).toHaveClass("bg-green-100", "text-green-800");
|
||||
|
||||
// rerender(<Badge variant="error">Error</Badge>);
|
||||
// badge = screen.getByText("Error");
|
||||
// expect(badge).toHaveClass("bg-red-100", "text-red-800");
|
||||
|
||||
// rerender(<Badge variant="info">Info</Badge>);
|
||||
// badge = screen.getByText("Info");
|
||||
// expect(badge).toHaveClass("bg-slate-100", "text-slate-800");
|
||||
// });
|
||||
|
||||
// it("applies custom className", () => {
|
||||
// render(
|
||||
// <Badge variant="success" className="custom-class">
|
||||
// Success
|
||||
// </Badge>,
|
||||
// );
|
||||
|
||||
// const badge = screen.getByText("Success");
|
||||
// expect(badge).toHaveClass("custom-class");
|
||||
// });
|
||||
|
||||
// it("renders as span element", () => {
|
||||
// render(<Badge variant="success">Success</Badge>);
|
||||
|
||||
// const badge = screen.getByText("Success");
|
||||
// expect(badge.tagName).toBe("SPAN");
|
||||
// });
|
||||
|
||||
// it("renders children correctly", () => {
|
||||
// render(
|
||||
// <Badge variant="success">
|
||||
// <span>Custom</span> Content
|
||||
// </Badge>,
|
||||
// );
|
||||
|
||||
// expect(screen.getByText("Custom")).toBeInTheDocument();
|
||||
// expect(screen.getByText("Content")).toBeInTheDocument();
|
||||
// });
|
||||
|
||||
// it("supports all badge variants", () => {
|
||||
// const variants = ["success", "error", "info"] as const;
|
||||
|
||||
// variants.forEach((variant) => {
|
||||
// const { unmount } = render(
|
||||
// <Badge variant={variant} data-testid={`badge-${variant}`}>
|
||||
// {variant}
|
||||
// </Badge>,
|
||||
// );
|
||||
|
||||
// expect(screen.getByTestId(`badge-${variant}`)).toBeInTheDocument();
|
||||
// unmount();
|
||||
// });
|
||||
// });
|
||||
|
||||
// it("handles long text content", () => {
|
||||
// render(
|
||||
// <Badge variant="info">
|
||||
// Very long text that should be handled properly by the component
|
||||
// </Badge>,
|
||||
// );
|
||||
|
||||
// const badge = screen.getByText(/Very long text/);
|
||||
// expect(badge).toBeInTheDocument();
|
||||
// expect(badge).toHaveClass("overflow-hidden", "text-ellipsis");
|
||||
// });
|
||||
// });
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user