mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-22 13:38:10 -05:00
Compare commits
4 Commits
dev
...
feat/agent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da9c4a4adf | ||
|
|
0ca73004e5 | ||
|
|
9a786ed8d9 | ||
|
|
0a435e2ffb |
38
.github/workflows/platform-frontend-ci.yml
vendored
38
.github/workflows/platform-frontend-ci.yml
vendored
@@ -128,7 +128,7 @@ jobs:
|
|||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
exitOnceUploaded: true
|
exitOnceUploaded: true
|
||||||
|
|
||||||
e2e_test:
|
test:
|
||||||
runs-on: big-boi
|
runs-on: big-boi
|
||||||
needs: setup
|
needs: setup
|
||||||
strategy:
|
strategy:
|
||||||
@@ -258,39 +258,3 @@ jobs:
|
|||||||
- name: Print Final Docker Compose logs
|
- name: Print Final Docker Compose logs
|
||||||
if: always()
|
if: always()
|
||||||
run: docker compose -f ../docker-compose.yml logs
|
run: docker compose -f ../docker-compose.yml logs
|
||||||
|
|
||||||
integration_test:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: setup
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
submodules: recursive
|
|
||||||
|
|
||||||
- name: Set up Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: "22.18.0"
|
|
||||||
|
|
||||||
- name: Enable corepack
|
|
||||||
run: corepack enable
|
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: ~/.pnpm-store
|
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
|
||||||
${{ runner.os }}-pnpm-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
- name: Generate API client
|
|
||||||
run: pnpm generate:api
|
|
||||||
|
|
||||||
- name: Run Integration Tests
|
|
||||||
run: pnpm test:unit
|
|
||||||
|
|||||||
26
AGENTS.md
26
AGENTS.md
@@ -16,32 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
|
|||||||
- Format Python code with `poetry run format`.
|
- Format Python code with `poetry run format`.
|
||||||
- Format frontend code using `pnpm format`.
|
- Format frontend code using `pnpm format`.
|
||||||
|
|
||||||
|
|
||||||
## Frontend guidelines:
|
|
||||||
|
|
||||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
|
||||||
|
|
||||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
|
||||||
- Add `usePageName.ts` hook for logic
|
|
||||||
- Put sub-components in local `components/` folder
|
|
||||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
|
||||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
|
||||||
- Never use `src/components/__legacy__/*`
|
|
||||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
|
||||||
- Regenerate with `pnpm generate:api`
|
|
||||||
- Pattern: `use{Method}{Version}{OperationName}`
|
|
||||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
|
||||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
|
||||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
|
||||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
|
||||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
|
||||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
|
||||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
|
||||||
- Use function declarations for components, arrow functions only for callbacks
|
|
||||||
- No barrel files or `index.ts` re-exports
|
|
||||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
|
||||||
- Avoid comments at all times unless the code is very complex
|
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||||
|
|||||||
@@ -201,7 +201,7 @@ If you get any pushback or hit complex block conditions check the new_blocks gui
|
|||||||
3. Write tests alongside the route file
|
3. Write tests alongside the route file
|
||||||
4. Run `poetry run test` to verify
|
4. Run `poetry run test` to verify
|
||||||
|
|
||||||
### Frontend guidelines:
|
**Frontend feature development:**
|
||||||
|
|
||||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||||
|
|
||||||
@@ -217,14 +217,6 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
|||||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
|
||||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
|
||||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
|
||||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
|
||||||
- Use function declarations for components, arrow functions only for callbacks
|
|
||||||
- No barrel files or `index.ts` re-exports
|
|
||||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
|
||||||
- Avoid comments at all times unless the code is very complex
|
|
||||||
|
|
||||||
### Security Implementation
|
### Security Implementation
|
||||||
|
|
||||||
|
|||||||
@@ -290,11 +290,6 @@ async def _cache_session(session: ChatSession) -> None:
|
|||||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||||
|
|
||||||
|
|
||||||
async def cache_chat_session(session: ChatSession) -> None:
|
|
||||||
"""Cache a chat session without persisting to the database."""
|
|
||||||
await _cache_session(session)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||||
"""Get a chat session from the database."""
|
"""Get a chat session from the database."""
|
||||||
prisma_session = await chat_db.get_chat_session(session_id)
|
prisma_session = await chat_db.get_chat_session(session_id)
|
||||||
|
|||||||
@@ -172,12 +172,12 @@ async def get_session(
|
|||||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
if not session:
|
if not session:
|
||||||
raise NotFoundError(f"Session {session_id} not found.")
|
raise NotFoundError(f"Session {session_id} not found")
|
||||||
|
|
||||||
messages = [message.model_dump() for message in session.messages]
|
messages = [message.model_dump() for message in session.messages]
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -222,8 +222,6 @@ async def stream_chat_post(
|
|||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
chunk_count = 0
|
|
||||||
first_chunk_type: str | None = None
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
session_id,
|
session_id,
|
||||||
request.message,
|
request.message,
|
||||||
@@ -232,26 +230,7 @@ async def stream_chat_post(
|
|||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
context=request.context,
|
context=request.context,
|
||||||
):
|
):
|
||||||
if chunk_count < 3:
|
|
||||||
logger.info(
|
|
||||||
"Chat stream chunk",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_type": str(chunk.type),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if not first_chunk_type:
|
|
||||||
first_chunk_type = str(chunk.type)
|
|
||||||
chunk_count += 1
|
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
logger.info(
|
|
||||||
"Chat stream completed",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"first_chunk_type": first_chunk_type,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# AI SDK protocol termination
|
# AI SDK protocol termination
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
@@ -296,8 +275,6 @@ async def stream_chat_get(
|
|||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
chunk_count = 0
|
|
||||||
first_chunk_type: str | None = None
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
session_id,
|
session_id,
|
||||||
message,
|
message,
|
||||||
@@ -305,26 +282,7 @@ async def stream_chat_get(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
):
|
):
|
||||||
if chunk_count < 3:
|
|
||||||
logger.info(
|
|
||||||
"Chat stream chunk",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_type": str(chunk.type),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if not first_chunk_type:
|
|
||||||
first_chunk_type = str(chunk.type)
|
|
||||||
chunk_count += 1
|
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
logger.info(
|
|
||||||
"Chat stream completed",
|
|
||||||
extra={
|
|
||||||
"session_id": session_id,
|
|
||||||
"chunk_count": chunk_count,
|
|
||||||
"first_chunk_type": first_chunk_type,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# AI SDK protocol termination
|
# AI SDK protocol termination
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from asyncio import CancelledError
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from langfuse import get_client, propagate_attributes
|
from langfuse import get_client, propagate_attributes
|
||||||
from langfuse.openai import openai # type: ignore
|
from langfuse.openai import openai # type: ignore
|
||||||
from openai import (
|
from openai import APIConnectionError, APIError, APIStatusError, RateLimitError
|
||||||
APIConnectionError,
|
|
||||||
APIError,
|
|
||||||
APIStatusError,
|
|
||||||
PermissionDeniedError,
|
|
||||||
RateLimitError,
|
|
||||||
)
|
|
||||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||||
|
|
||||||
from backend.data.understanding import (
|
from backend.data.understanding import (
|
||||||
@@ -29,7 +21,6 @@ from .model import (
|
|||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatSession,
|
ChatSession,
|
||||||
Usage,
|
Usage,
|
||||||
cache_chat_session,
|
|
||||||
get_chat_session,
|
get_chat_session,
|
||||||
update_session_title,
|
update_session_title,
|
||||||
upsert_chat_session,
|
upsert_chat_session,
|
||||||
@@ -305,10 +296,6 @@ async def stream_chat_completion(
|
|||||||
content="",
|
content="",
|
||||||
)
|
)
|
||||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||||
has_saved_assistant_message = False
|
|
||||||
has_appended_streaming_message = False
|
|
||||||
last_cache_time = 0.0
|
|
||||||
last_cache_content_len = 0
|
|
||||||
|
|
||||||
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
|
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
|
||||||
has_yielded_end = False
|
has_yielded_end = False
|
||||||
@@ -345,23 +332,6 @@ async def stream_chat_completion(
|
|||||||
assert assistant_response.content is not None
|
assert assistant_response.content is not None
|
||||||
assistant_response.content += delta
|
assistant_response.content += delta
|
||||||
has_received_text = True
|
has_received_text = True
|
||||||
if not has_appended_streaming_message:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_streaming_message = True
|
|
||||||
current_time = time.monotonic()
|
|
||||||
content_len = len(assistant_response.content)
|
|
||||||
if (
|
|
||||||
current_time - last_cache_time >= 1.0
|
|
||||||
and content_len > last_cache_content_len
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
await cache_chat_session(session)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to cache partial session {session.session_id}: {e}"
|
|
||||||
)
|
|
||||||
last_cache_time = current_time
|
|
||||||
last_cache_content_len = content_len
|
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamTextEnd):
|
elif isinstance(chunk, StreamTextEnd):
|
||||||
# Emit text-end after text completes
|
# Emit text-end after text completes
|
||||||
@@ -420,42 +390,10 @@ async def stream_chat_completion(
|
|||||||
if has_received_text and not text_streaming_ended:
|
if has_received_text and not text_streaming_ended:
|
||||||
yield StreamTextEnd(id=text_block_id)
|
yield StreamTextEnd(id=text_block_id)
|
||||||
text_streaming_ended = True
|
text_streaming_ended = True
|
||||||
|
|
||||||
# Save assistant message before yielding finish to ensure it's persisted
|
|
||||||
# even if client disconnects immediately after receiving StreamFinish
|
|
||||||
if not has_saved_assistant_message:
|
|
||||||
messages_to_save_early: list[ChatMessage] = []
|
|
||||||
if accumulated_tool_calls:
|
|
||||||
assistant_response.tool_calls = (
|
|
||||||
accumulated_tool_calls
|
|
||||||
)
|
|
||||||
if not has_appended_streaming_message and (
|
|
||||||
assistant_response.content
|
|
||||||
or assistant_response.tool_calls
|
|
||||||
):
|
|
||||||
messages_to_save_early.append(assistant_response)
|
|
||||||
messages_to_save_early.extend(tool_response_messages)
|
|
||||||
|
|
||||||
if messages_to_save_early:
|
|
||||||
session.messages.extend(messages_to_save_early)
|
|
||||||
logger.info(
|
|
||||||
f"Saving assistant message before StreamFinish: "
|
|
||||||
f"content_len={len(assistant_response.content or '')}, "
|
|
||||||
f"tool_calls={len(assistant_response.tool_calls or [])}, "
|
|
||||||
f"tool_responses={len(tool_response_messages)}"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
messages_to_save_early
|
|
||||||
or has_appended_streaming_message
|
|
||||||
):
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
has_saved_assistant_message = True
|
|
||||||
|
|
||||||
has_yielded_end = True
|
has_yielded_end = True
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamError):
|
elif isinstance(chunk, StreamError):
|
||||||
has_yielded_error = True
|
has_yielded_error = True
|
||||||
yield chunk
|
|
||||||
elif isinstance(chunk, StreamUsage):
|
elif isinstance(chunk, StreamUsage):
|
||||||
session.usage.append(
|
session.usage.append(
|
||||||
Usage(
|
Usage(
|
||||||
@@ -475,27 +413,6 @@ async def stream_chat_completion(
|
|||||||
langfuse.update_current_trace(output=str(tool_response_messages))
|
langfuse.update_current_trace(output=str(tool_response_messages))
|
||||||
langfuse.update_current_span(output=str(tool_response_messages))
|
langfuse.update_current_span(output=str(tool_response_messages))
|
||||||
|
|
||||||
except CancelledError:
|
|
||||||
if not has_saved_assistant_message:
|
|
||||||
if accumulated_tool_calls:
|
|
||||||
assistant_response.tool_calls = accumulated_tool_calls
|
|
||||||
if assistant_response.content:
|
|
||||||
assistant_response.content = (
|
|
||||||
f"{assistant_response.content}\n\n[interrupted]"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assistant_response.content = "[interrupted]"
|
|
||||||
if not has_appended_streaming_message:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
if tool_response_messages:
|
|
||||||
session.messages.extend(tool_response_messages)
|
|
||||||
try:
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to save interrupted session {session.session_id}: {e}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
||||||
|
|
||||||
@@ -517,19 +434,14 @@ async def stream_chat_completion(
|
|||||||
# Add assistant message if it has content or tool calls
|
# Add assistant message if it has content or tool calls
|
||||||
if accumulated_tool_calls:
|
if accumulated_tool_calls:
|
||||||
assistant_response.tool_calls = accumulated_tool_calls
|
assistant_response.tool_calls = accumulated_tool_calls
|
||||||
if not has_appended_streaming_message and (
|
if assistant_response.content or assistant_response.tool_calls:
|
||||||
assistant_response.content or assistant_response.tool_calls
|
|
||||||
):
|
|
||||||
messages_to_save.append(assistant_response)
|
messages_to_save.append(assistant_response)
|
||||||
|
|
||||||
# Add tool response messages after assistant message
|
# Add tool response messages after assistant message
|
||||||
messages_to_save.extend(tool_response_messages)
|
messages_to_save.extend(tool_response_messages)
|
||||||
|
|
||||||
if not has_saved_assistant_message:
|
session.messages.extend(messages_to_save)
|
||||||
if messages_to_save:
|
await upsert_chat_session(session)
|
||||||
session.messages.extend(messages_to_save)
|
|
||||||
if messages_to_save or has_appended_streaming_message:
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
|
|
||||||
if not has_yielded_error:
|
if not has_yielded_error:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
@@ -560,49 +472,38 @@ async def stream_chat_completion(
|
|||||||
return # Exit after retry to avoid double-saving in finally block
|
return # Exit after retry to avoid double-saving in finally block
|
||||||
|
|
||||||
# Normal completion path - save session and handle tool call continuation
|
# Normal completion path - save session and handle tool call continuation
|
||||||
# Only save if we haven't already saved when StreamFinish was received
|
logger.info(
|
||||||
if not has_saved_assistant_message:
|
f"Normal completion path: session={session.session_id}, "
|
||||||
|
f"current message_count={len(session.messages)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build the messages list in the correct order
|
||||||
|
messages_to_save: list[ChatMessage] = []
|
||||||
|
|
||||||
|
# Add assistant message with tool_calls if any
|
||||||
|
if accumulated_tool_calls:
|
||||||
|
assistant_response.tool_calls = accumulated_tool_calls
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Normal completion path: session={session.session_id}, "
|
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||||
f"current message_count={len(session.messages)}"
|
)
|
||||||
|
if assistant_response.content or assistant_response.tool_calls:
|
||||||
|
messages_to_save.append(assistant_response)
|
||||||
|
logger.info(
|
||||||
|
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the messages list in the correct order
|
# Add tool response messages after assistant message
|
||||||
messages_to_save: list[ChatMessage] = []
|
messages_to_save.extend(tool_response_messages)
|
||||||
|
logger.info(
|
||||||
|
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||||
|
f"total_to_save={len(messages_to_save)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Add assistant message with tool_calls if any
|
session.messages.extend(messages_to_save)
|
||||||
if accumulated_tool_calls:
|
logger.info(
|
||||||
assistant_response.tool_calls = accumulated_tool_calls
|
f"Extended session messages, new message_count={len(session.messages)}"
|
||||||
logger.info(
|
)
|
||||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
await upsert_chat_session(session)
|
||||||
)
|
|
||||||
if not has_appended_streaming_message and (
|
|
||||||
assistant_response.content or assistant_response.tool_calls
|
|
||||||
):
|
|
||||||
messages_to_save.append(assistant_response)
|
|
||||||
logger.info(
|
|
||||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add tool response messages after assistant message
|
|
||||||
messages_to_save.extend(tool_response_messages)
|
|
||||||
logger.info(
|
|
||||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
|
||||||
f"total_to_save={len(messages_to_save)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if messages_to_save:
|
|
||||||
session.messages.extend(messages_to_save)
|
|
||||||
logger.info(
|
|
||||||
f"Extended session messages, new message_count={len(session.messages)}"
|
|
||||||
)
|
|
||||||
if messages_to_save or has_appended_streaming_message:
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Assistant message already saved when StreamFinish was received, "
|
|
||||||
"skipping duplicate save"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If we did a tool call, stream the chat completion again to get the next response
|
# If we did a tool call, stream the chat completion again to get the next response
|
||||||
if has_done_tool_call:
|
if has_done_tool_call:
|
||||||
@@ -644,12 +545,6 @@ def _is_retryable_error(error: Exception) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _is_region_blocked_error(error: Exception) -> bool:
|
|
||||||
if isinstance(error, PermissionDeniedError):
|
|
||||||
return "not available in your region" in str(error).lower()
|
|
||||||
return "not available in your region" in str(error).lower()
|
|
||||||
|
|
||||||
|
|
||||||
async def _stream_chat_chunks(
|
async def _stream_chat_chunks(
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
tools: list[ChatCompletionToolParam],
|
tools: list[ChatCompletionToolParam],
|
||||||
@@ -842,18 +737,7 @@ async def _stream_chat_chunks(
|
|||||||
f"Error in stream (not retrying): {e!s}",
|
f"Error in stream (not retrying): {e!s}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
error_code = None
|
error_response = StreamError(errorText=str(e))
|
||||||
error_text = str(e)
|
|
||||||
if _is_region_blocked_error(e):
|
|
||||||
error_code = "MODEL_NOT_AVAILABLE_REGION"
|
|
||||||
error_text = (
|
|
||||||
"This model is not available in your region. "
|
|
||||||
"Please connect via VPN and try again."
|
|
||||||
)
|
|
||||||
error_response = StreamError(
|
|
||||||
errorText=error_text,
|
|
||||||
code=error_code,
|
|
||||||
)
|
|
||||||
yield error_response
|
yield error_response
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,29 +1,28 @@
|
|||||||
"""Agent generator package - Creates agents from natural language."""
|
"""Agent generator package - Creates agents from natural language."""
|
||||||
|
|
||||||
from .core import (
|
from .core import (
|
||||||
apply_agent_patch,
|
AgentGeneratorNotConfiguredError,
|
||||||
decompose_goal,
|
decompose_goal,
|
||||||
generate_agent,
|
generate_agent,
|
||||||
generate_agent_patch,
|
generate_agent_patch,
|
||||||
get_agent_as_json,
|
get_agent_as_json,
|
||||||
|
json_to_graph,
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
)
|
)
|
||||||
from .fixer import apply_all_fixes
|
from .service import health_check as check_external_service_health
|
||||||
from .utils import get_blocks_info
|
from .service import is_external_service_configured
|
||||||
from .validator import validate_agent
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Core functions
|
# Core functions
|
||||||
"decompose_goal",
|
"decompose_goal",
|
||||||
"generate_agent",
|
"generate_agent",
|
||||||
"generate_agent_patch",
|
"generate_agent_patch",
|
||||||
"apply_agent_patch",
|
|
||||||
"save_agent_to_library",
|
"save_agent_to_library",
|
||||||
"get_agent_as_json",
|
"get_agent_as_json",
|
||||||
# Fixer
|
"json_to_graph",
|
||||||
"apply_all_fixes",
|
# Exceptions
|
||||||
# Validator
|
"AgentGeneratorNotConfiguredError",
|
||||||
"validate_agent",
|
# Service
|
||||||
# Utils
|
"is_external_service_configured",
|
||||||
"get_blocks_info",
|
"check_external_service_health",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
"""OpenRouter client configuration for agent generation."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
|
|
||||||
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
|
|
||||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
|
|
||||||
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
|
|
||||||
|
|
||||||
# OpenRouter client (OpenAI-compatible API)
|
|
||||||
_client: AsyncOpenAI | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_client() -> AsyncOpenAI:
|
|
||||||
"""Get or create the OpenRouter client."""
|
|
||||||
global _client
|
|
||||||
if _client is None:
|
|
||||||
if not OPENROUTER_API_KEY:
|
|
||||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
|
||||||
_client = AsyncOpenAI(
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
api_key=OPENROUTER_API_KEY,
|
|
||||||
)
|
|
||||||
return _client
|
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
"""Core agent generation functions."""
|
"""Core agent generation functions."""
|
||||||
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -9,13 +7,35 @@ from typing import Any
|
|||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.data.graph import Graph, Link, Node, create_graph
|
from backend.data.graph import Graph, Link, Node, create_graph
|
||||||
|
|
||||||
from .client import AGENT_GENERATOR_MODEL, get_client
|
from .service import (
|
||||||
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
|
decompose_goal_external,
|
||||||
from .utils import get_block_summaries, parse_json_from_llm
|
generate_agent_external,
|
||||||
|
generate_agent_patch_external,
|
||||||
|
is_external_service_configured,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentGeneratorNotConfiguredError(Exception):
|
||||||
|
"""Raised when the external Agent Generator service is not configured."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _check_service_configured() -> None:
|
||||||
|
"""Check if the external Agent Generator service is configured.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the service is not configured.
|
||||||
|
"""
|
||||||
|
if not is_external_service_configured():
|
||||||
|
raise AgentGeneratorNotConfiguredError(
|
||||||
|
"Agent Generator service is not configured. "
|
||||||
|
"Set AGENTGENERATOR_HOST environment variable to enable agent generation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||||
"""Break down a goal into steps or return clarifying questions.
|
"""Break down a goal into steps or return clarifying questions.
|
||||||
|
|
||||||
@@ -28,40 +48,13 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
|
|||||||
- {"type": "clarifying_questions", "questions": [...]}
|
- {"type": "clarifying_questions", "questions": [...]}
|
||||||
- {"type": "instructions", "steps": [...]}
|
- {"type": "instructions", "steps": [...]}
|
||||||
Or None on error
|
Or None on error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
client = get_client()
|
_check_service_configured()
|
||||||
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
|
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||||
|
return await decompose_goal_external(description, context)
|
||||||
full_description = description
|
|
||||||
if context:
|
|
||||||
full_description = f"{description}\n\nAdditional context:\n{context}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.chat.completions.create(
|
|
||||||
model=AGENT_GENERATOR_MODEL,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": prompt},
|
|
||||||
{"role": "user", "content": full_description},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
if content is None:
|
|
||||||
logger.error("LLM returned empty content for decomposition")
|
|
||||||
return None
|
|
||||||
|
|
||||||
result = parse_json_from_llm(content)
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
logger.error(f"Failed to parse decomposition response: {content[:200]}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error decomposing goal: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
@@ -72,31 +65,14 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict or None on error
|
Agent JSON dict or None on error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
client = get_client()
|
_check_service_configured()
|
||||||
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
|
logger.info("Calling external Agent Generator service for generate_agent")
|
||||||
|
result = await generate_agent_external(instructions)
|
||||||
try:
|
if result:
|
||||||
response = await client.chat.completions.create(
|
|
||||||
model=AGENT_GENERATOR_MODEL,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": prompt},
|
|
||||||
{"role": "user", "content": json.dumps(instructions, indent=2)},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
if content is None:
|
|
||||||
logger.error("LLM returned empty content for agent generation")
|
|
||||||
return None
|
|
||||||
|
|
||||||
result = parse_json_from_llm(content)
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
logger.error(f"Failed to parse agent JSON: {content[:200]}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Ensure required fields
|
# Ensure required fields
|
||||||
if "id" not in result:
|
if "id" not in result:
|
||||||
result["id"] = str(uuid.uuid4())
|
result["id"] = str(uuid.uuid4())
|
||||||
@@ -104,12 +80,7 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
|||||||
result["version"] = 1
|
result["version"] = 1
|
||||||
if "is_active" not in result:
|
if "is_active" not in result:
|
||||||
result["is_active"] = True
|
result["is_active"] = True
|
||||||
|
return result
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating agent: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||||
@@ -284,108 +255,23 @@ async def get_agent_as_json(
|
|||||||
async def generate_agent_patch(
|
async def generate_agent_patch(
|
||||||
update_request: str, current_agent: dict[str, Any]
|
update_request: str, current_agent: dict[str, Any]
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Generate a patch to update an existing agent.
|
"""Update an existing agent using natural language.
|
||||||
|
|
||||||
|
The external Agent Generator service handles:
|
||||||
|
- Generating the patch
|
||||||
|
- Applying the patch
|
||||||
|
- Fixing and validating the result
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Patch dict or clarifying questions, or None on error
|
Updated agent JSON, clarifying questions dict, or None on error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
client = get_client()
|
_check_service_configured()
|
||||||
prompt = PATCH_PROMPT.format(
|
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||||
current_agent=json.dumps(current_agent, indent=2),
|
return await generate_agent_patch_external(update_request, current_agent)
|
||||||
block_summaries=get_block_summaries(),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.chat.completions.create(
|
|
||||||
model=AGENT_GENERATOR_MODEL,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": prompt},
|
|
||||||
{"role": "user", "content": update_request},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
if content is None:
|
|
||||||
logger.error("LLM returned empty content for patch generation")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return parse_json_from_llm(content)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating patch: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def apply_agent_patch(
|
|
||||||
current_agent: dict[str, Any], patch: dict[str, Any]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Apply a patch to an existing agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
current_agent: Current agent JSON
|
|
||||||
patch: Patch dict with operations
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated agent JSON
|
|
||||||
"""
|
|
||||||
agent = copy.deepcopy(current_agent)
|
|
||||||
patches = patch.get("patches", [])
|
|
||||||
|
|
||||||
for p in patches:
|
|
||||||
patch_type = p.get("type")
|
|
||||||
|
|
||||||
if patch_type == "modify":
|
|
||||||
node_id = p.get("node_id")
|
|
||||||
changes = p.get("changes", {})
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
if node["id"] == node_id:
|
|
||||||
_deep_update(node, changes)
|
|
||||||
logger.debug(f"Modified node {node_id}")
|
|
||||||
break
|
|
||||||
|
|
||||||
elif patch_type == "add":
|
|
||||||
new_nodes = p.get("new_nodes", [])
|
|
||||||
new_links = p.get("new_links", [])
|
|
||||||
|
|
||||||
agent["nodes"] = agent.get("nodes", []) + new_nodes
|
|
||||||
agent["links"] = agent.get("links", []) + new_links
|
|
||||||
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
|
|
||||||
|
|
||||||
elif patch_type == "remove":
|
|
||||||
node_ids_to_remove = set(p.get("node_ids", []))
|
|
||||||
link_ids_to_remove = set(p.get("link_ids", []))
|
|
||||||
|
|
||||||
# Remove nodes
|
|
||||||
agent["nodes"] = [
|
|
||||||
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
|
|
||||||
]
|
|
||||||
|
|
||||||
# Remove links (both explicit and those referencing removed nodes)
|
|
||||||
agent["links"] = [
|
|
||||||
link
|
|
||||||
for link in agent.get("links", [])
|
|
||||||
if link["id"] not in link_ids_to_remove
|
|
||||||
and link["source_id"] not in node_ids_to_remove
|
|
||||||
and link["sink_id"] not in node_ids_to_remove
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
|
|
||||||
)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def _deep_update(target: dict, source: dict) -> None:
|
|
||||||
"""Recursively update a dict with another dict."""
|
|
||||||
for key, value in source.items():
|
|
||||||
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
|
||||||
_deep_update(target[key], value)
|
|
||||||
else:
|
|
||||||
target[key] = value
|
|
||||||
|
|||||||
@@ -1,606 +0,0 @@
|
|||||||
"""Agent fixer - Fixes common LLM generation errors."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from .utils import (
|
|
||||||
ADDTODICTIONARY_BLOCK_ID,
|
|
||||||
ADDTOLIST_BLOCK_ID,
|
|
||||||
CODE_EXECUTION_BLOCK_ID,
|
|
||||||
CONDITION_BLOCK_ID,
|
|
||||||
CREATEDICT_BLOCK_ID,
|
|
||||||
CREATELIST_BLOCK_ID,
|
|
||||||
DATA_SAMPLING_BLOCK_ID,
|
|
||||||
DOUBLE_CURLY_BRACES_BLOCK_IDS,
|
|
||||||
GET_CURRENT_DATE_BLOCK_ID,
|
|
||||||
STORE_VALUE_BLOCK_ID,
|
|
||||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
|
||||||
get_blocks_info,
|
|
||||||
is_valid_uuid,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix invalid UUIDs in agent and link IDs."""
|
|
||||||
# Fix agent ID
|
|
||||||
if not is_valid_uuid(agent.get("id", "")):
|
|
||||||
agent["id"] = str(uuid.uuid4())
|
|
||||||
logger.debug(f"Fixed agent ID: {agent['id']}")
|
|
||||||
|
|
||||||
# Fix node IDs
|
|
||||||
id_mapping = {} # Old ID -> New ID
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
if not is_valid_uuid(node.get("id", "")):
|
|
||||||
old_id = node.get("id", "")
|
|
||||||
new_id = str(uuid.uuid4())
|
|
||||||
id_mapping[old_id] = new_id
|
|
||||||
node["id"] = new_id
|
|
||||||
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
|
|
||||||
|
|
||||||
# Fix link IDs and update references
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
if not is_valid_uuid(link.get("id", "")):
|
|
||||||
link["id"] = str(uuid.uuid4())
|
|
||||||
logger.debug(f"Fixed link ID: {link['id']}")
|
|
||||||
|
|
||||||
# Update source/sink IDs if they were remapped
|
|
||||||
if link.get("source_id") in id_mapping:
|
|
||||||
link["source_id"] = id_mapping[link["source_id"]]
|
|
||||||
if link.get("sink_id") in id_mapping:
|
|
||||||
link["sink_id"] = id_mapping[link["sink_id"]]
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix single curly braces to double in template blocks."""
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
|
|
||||||
continue
|
|
||||||
|
|
||||||
input_data = node.get("input_default", {})
|
|
||||||
for key in ("prompt", "format"):
|
|
||||||
if key in input_data and isinstance(input_data[key], str):
|
|
||||||
original = input_data[key]
|
|
||||||
# Fix simple variable references: {var} -> {{var}}
|
|
||||||
fixed = re.sub(
|
|
||||||
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
|
|
||||||
r"{{\1}}",
|
|
||||||
original,
|
|
||||||
)
|
|
||||||
if fixed != original:
|
|
||||||
input_data[key] = fixed
|
|
||||||
logger.debug(f"Fixed curly braces in {key}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
|
|
||||||
# Find all ConditionBlock nodes
|
|
||||||
condition_node_ids = {
|
|
||||||
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
|
|
||||||
}
|
|
||||||
|
|
||||||
if not condition_node_ids:
|
|
||||||
return agent
|
|
||||||
|
|
||||||
new_nodes = []
|
|
||||||
new_links = []
|
|
||||||
processed_conditions = set()
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
sink_id = link.get("sink_id")
|
|
||||||
sink_name = link.get("sink_name")
|
|
||||||
|
|
||||||
# Check if this link goes to a ConditionBlock's value2
|
|
||||||
if sink_id in condition_node_ids and sink_name == "value2":
|
|
||||||
source_node = next(
|
|
||||||
(n for n in nodes if n["id"] == link.get("source_id")), None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Skip if source is already a StoreValueBlock
|
|
||||||
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip if we already processed this condition
|
|
||||||
if sink_id in processed_conditions:
|
|
||||||
continue
|
|
||||||
|
|
||||||
processed_conditions.add(sink_id)
|
|
||||||
|
|
||||||
# Create StoreValueBlock
|
|
||||||
store_node_id = str(uuid.uuid4())
|
|
||||||
store_node = {
|
|
||||||
"id": store_node_id,
|
|
||||||
"block_id": STORE_VALUE_BLOCK_ID,
|
|
||||||
"input_default": {"data": None},
|
|
||||||
"metadata": {"position": {"x": 0, "y": -100}},
|
|
||||||
}
|
|
||||||
new_nodes.append(store_node)
|
|
||||||
|
|
||||||
# Create link: original source -> StoreValueBlock
|
|
||||||
new_links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": link["source_id"],
|
|
||||||
"source_name": link["source_name"],
|
|
||||||
"sink_id": store_node_id,
|
|
||||||
"sink_name": "input",
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update original link: StoreValueBlock -> ConditionBlock
|
|
||||||
link["source_id"] = store_node_id
|
|
||||||
link["source_name"] = "output"
|
|
||||||
|
|
||||||
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
|
|
||||||
|
|
||||||
if new_nodes:
|
|
||||||
agent["nodes"] = nodes + new_nodes
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
|
|
||||||
|
|
||||||
When an AddToList block is found:
|
|
||||||
1. Checks if there's a CreateListBlock before it
|
|
||||||
2. Removes CreateListBlock if linked directly to AddToList
|
|
||||||
3. Adds an empty AddToList block before the original
|
|
||||||
4. Ensures the original has a self-referencing link
|
|
||||||
"""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
new_nodes = []
|
|
||||||
original_addtolist_ids = set()
|
|
||||||
nodes_to_remove = set()
|
|
||||||
links_to_remove = []
|
|
||||||
|
|
||||||
# First pass: identify CreateListBlock nodes to remove
|
|
||||||
for link in links:
|
|
||||||
source_node = next(
|
|
||||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
|
||||||
)
|
|
||||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
|
||||||
|
|
||||||
if (
|
|
||||||
source_node
|
|
||||||
and sink_node
|
|
||||||
and source_node.get("block_id") == CREATELIST_BLOCK_ID
|
|
||||||
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
|
|
||||||
):
|
|
||||||
nodes_to_remove.add(source_node.get("id"))
|
|
||||||
links_to_remove.append(link)
|
|
||||||
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
|
|
||||||
|
|
||||||
# Second pass: process AddToList blocks
|
|
||||||
filtered_nodes = []
|
|
||||||
for node in nodes:
|
|
||||||
if node.get("id") in nodes_to_remove:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
|
|
||||||
original_addtolist_ids.add(node.get("id"))
|
|
||||||
node_id = node.get("id")
|
|
||||||
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
|
|
||||||
|
|
||||||
# Check if already has prerequisite
|
|
||||||
has_prereq = any(
|
|
||||||
link.get("sink_id") == node_id
|
|
||||||
and link.get("sink_name") == "list"
|
|
||||||
and link.get("source_name") == "updated_list"
|
|
||||||
for link in links
|
|
||||||
)
|
|
||||||
|
|
||||||
if not has_prereq:
|
|
||||||
# Remove links to "list" input (except self-reference)
|
|
||||||
for link in links:
|
|
||||||
if (
|
|
||||||
link.get("sink_id") == node_id
|
|
||||||
and link.get("sink_name") == "list"
|
|
||||||
and link.get("source_id") != node_id
|
|
||||||
and link not in links_to_remove
|
|
||||||
):
|
|
||||||
links_to_remove.append(link)
|
|
||||||
|
|
||||||
# Create prerequisite AddToList block
|
|
||||||
prereq_id = str(uuid.uuid4())
|
|
||||||
prereq_node = {
|
|
||||||
"id": prereq_id,
|
|
||||||
"block_id": ADDTOLIST_BLOCK_ID,
|
|
||||||
"input_default": {"list": [], "entry": None, "entries": []},
|
|
||||||
"metadata": {
|
|
||||||
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
new_nodes.append(prereq_node)
|
|
||||||
|
|
||||||
# Link prerequisite to original
|
|
||||||
links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": prereq_id,
|
|
||||||
"source_name": "updated_list",
|
|
||||||
"sink_id": node_id,
|
|
||||||
"sink_name": "list",
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
logger.debug(f"Added prerequisite AddToList block for {node_id}")
|
|
||||||
|
|
||||||
filtered_nodes.append(node)
|
|
||||||
|
|
||||||
# Remove marked links
|
|
||||||
filtered_links = [link for link in links if link not in links_to_remove]
|
|
||||||
|
|
||||||
# Add self-referencing links for original AddToList blocks
|
|
||||||
for node in filtered_nodes + new_nodes:
|
|
||||||
if (
|
|
||||||
node.get("block_id") == ADDTOLIST_BLOCK_ID
|
|
||||||
and node.get("id") in original_addtolist_ids
|
|
||||||
):
|
|
||||||
node_id = node.get("id")
|
|
||||||
has_self_ref = any(
|
|
||||||
link["source_id"] == node_id
|
|
||||||
and link["sink_id"] == node_id
|
|
||||||
and link["source_name"] == "updated_list"
|
|
||||||
and link["sink_name"] == "list"
|
|
||||||
for link in filtered_links
|
|
||||||
)
|
|
||||||
if not has_self_ref:
|
|
||||||
filtered_links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": node_id,
|
|
||||||
"source_name": "updated_list",
|
|
||||||
"sink_id": node_id,
|
|
||||||
"sink_name": "list",
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
logger.debug(f"Added self-reference for AddToList {node_id}")
|
|
||||||
|
|
||||||
agent["nodes"] = filtered_nodes + new_nodes
|
|
||||||
agent["links"] = filtered_links
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
nodes_to_remove = set()
|
|
||||||
links_to_remove = []
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
source_node = next(
|
|
||||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
|
||||||
)
|
|
||||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
|
||||||
|
|
||||||
if (
|
|
||||||
source_node
|
|
||||||
and sink_node
|
|
||||||
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
|
|
||||||
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
|
|
||||||
):
|
|
||||||
nodes_to_remove.add(source_node.get("id"))
|
|
||||||
links_to_remove.append(link)
|
|
||||||
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
|
|
||||||
|
|
||||||
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
|
|
||||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
source_node = next(
|
|
||||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
source_node
|
|
||||||
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
|
|
||||||
and link.get("source_name") == "response"
|
|
||||||
):
|
|
||||||
link["source_name"] = "stdout_logs"
|
|
||||||
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
links_to_remove = []
|
|
||||||
|
|
||||||
for node in nodes:
|
|
||||||
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
|
|
||||||
node_id = node.get("id")
|
|
||||||
input_default = node.get("input_default", {})
|
|
||||||
|
|
||||||
# Remove links to sample_size
|
|
||||||
for link in links:
|
|
||||||
if (
|
|
||||||
link.get("sink_id") == node_id
|
|
||||||
and link.get("sink_name") == "sample_size"
|
|
||||||
):
|
|
||||||
links_to_remove.append(link)
|
|
||||||
|
|
||||||
# Set default
|
|
||||||
input_default["sample_size"] = 1
|
|
||||||
node["input_default"] = input_default
|
|
||||||
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
|
|
||||||
|
|
||||||
if links_to_remove:
|
|
||||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
node_lookup = {n.get("id"): n for n in nodes}
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
source_id = link.get("source_id")
|
|
||||||
sink_id = link.get("sink_id")
|
|
||||||
|
|
||||||
source_node = node_lookup.get(source_id)
|
|
||||||
sink_node = node_lookup.get(sink_id)
|
|
||||||
|
|
||||||
if not source_node or not sink_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_pos = source_node.get("metadata", {}).get("position", {})
|
|
||||||
sink_pos = sink_node.get("metadata", {}).get("position", {})
|
|
||||||
|
|
||||||
source_x = source_pos.get("x", 0)
|
|
||||||
sink_x = sink_pos.get("x", 0)
|
|
||||||
|
|
||||||
if abs(sink_x - source_x) < 800:
|
|
||||||
new_x = source_x + 800
|
|
||||||
if "metadata" not in sink_node:
|
|
||||||
sink_node["metadata"] = {}
|
|
||||||
if "position" not in sink_node["metadata"]:
|
|
||||||
sink_node["metadata"]["position"] = {}
|
|
||||||
sink_node["metadata"]["position"]["x"] = new_x
|
|
||||||
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
|
|
||||||
input_default = node.get("input_default", {})
|
|
||||||
if "offset" in input_default:
|
|
||||||
offset = input_default["offset"]
|
|
||||||
if isinstance(offset, (int, float)) and offset < 0:
|
|
||||||
input_default["offset"] = abs(offset)
|
|
||||||
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_ai_model_parameter(
|
|
||||||
agent: dict[str, Any],
|
|
||||||
blocks_info: list[dict[str, Any]],
|
|
||||||
default_model: str = "gpt-4o",
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Add default model parameter to AI blocks if missing."""
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
block_id = node.get("block_id")
|
|
||||||
block = block_map.get(block_id)
|
|
||||||
|
|
||||||
if not block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if block has AI category
|
|
||||||
categories = block.get("categories", [])
|
|
||||||
is_ai_block = any(
|
|
||||||
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_ai_block:
|
|
||||||
input_default = node.get("input_default", {})
|
|
||||||
if "model" not in input_default:
|
|
||||||
input_default["model"] = default_model
|
|
||||||
node["input_default"] = input_default
|
|
||||||
logger.debug(
|
|
||||||
f"Added model '{default_model}' to AI block {node.get('id')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_link_static_properties(
|
|
||||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Fix is_static property based on source block's staticOutput."""
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
|
||||||
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
source_node = node_lookup.get(link.get("source_id"))
|
|
||||||
if not source_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_block = block_map.get(source_node.get("block_id"))
|
|
||||||
if not source_block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
static_output = source_block.get("staticOutput", False)
|
|
||||||
if link.get("is_static") != static_output:
|
|
||||||
link["is_static"] = static_output
|
|
||||||
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def fix_data_type_mismatch(
|
|
||||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
|
|
||||||
nodes = agent.get("nodes", [])
|
|
||||||
links = agent.get("links", [])
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
node_lookup = {n.get("id"): n for n in nodes}
|
|
||||||
|
|
||||||
def get_property_type(schema: dict, name: str) -> str | None:
|
|
||||||
if "_#_" in name:
|
|
||||||
parent, child = name.split("_#_", 1)
|
|
||||||
parent_schema = schema.get(parent, {})
|
|
||||||
if "properties" in parent_schema:
|
|
||||||
return parent_schema["properties"].get(child, {}).get("type")
|
|
||||||
return None
|
|
||||||
return schema.get(name, {}).get("type")
|
|
||||||
|
|
||||||
def are_types_compatible(src: str, sink: str) -> bool:
|
|
||||||
if {src, sink} <= {"integer", "number"}:
|
|
||||||
return True
|
|
||||||
return src == sink
|
|
||||||
|
|
||||||
type_mapping = {
|
|
||||||
"string": "string",
|
|
||||||
"text": "string",
|
|
||||||
"integer": "number",
|
|
||||||
"number": "number",
|
|
||||||
"float": "number",
|
|
||||||
"boolean": "boolean",
|
|
||||||
"bool": "boolean",
|
|
||||||
"array": "list",
|
|
||||||
"list": "list",
|
|
||||||
"object": "dictionary",
|
|
||||||
"dict": "dictionary",
|
|
||||||
"dictionary": "dictionary",
|
|
||||||
}
|
|
||||||
|
|
||||||
new_links = []
|
|
||||||
nodes_to_add = []
|
|
||||||
|
|
||||||
for link in links:
|
|
||||||
source_node = node_lookup.get(link.get("source_id"))
|
|
||||||
sink_node = node_lookup.get(link.get("sink_id"))
|
|
||||||
|
|
||||||
if not source_node or not sink_node:
|
|
||||||
new_links.append(link)
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_block = block_map.get(source_node.get("block_id"))
|
|
||||||
sink_block = block_map.get(sink_node.get("block_id"))
|
|
||||||
|
|
||||||
if not source_block or not sink_block:
|
|
||||||
new_links.append(link)
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
|
||||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
|
||||||
|
|
||||||
source_type = get_property_type(source_outputs, link.get("source_name", ""))
|
|
||||||
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
|
|
||||||
|
|
||||||
if (
|
|
||||||
source_type
|
|
||||||
and sink_type
|
|
||||||
and not are_types_compatible(source_type, sink_type)
|
|
||||||
):
|
|
||||||
# Insert type converter
|
|
||||||
converter_id = str(uuid.uuid4())
|
|
||||||
target_type = type_mapping.get(sink_type, sink_type)
|
|
||||||
|
|
||||||
converter_node = {
|
|
||||||
"id": converter_id,
|
|
||||||
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
|
||||||
"input_default": {"type": target_type},
|
|
||||||
"metadata": {"position": {"x": 0, "y": 100}},
|
|
||||||
}
|
|
||||||
nodes_to_add.append(converter_node)
|
|
||||||
|
|
||||||
# source -> converter
|
|
||||||
new_links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": link["source_id"],
|
|
||||||
"source_name": link["source_name"],
|
|
||||||
"sink_id": converter_id,
|
|
||||||
"sink_name": "value",
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# converter -> sink
|
|
||||||
new_links.append(
|
|
||||||
{
|
|
||||||
"id": str(uuid.uuid4()),
|
|
||||||
"source_id": converter_id,
|
|
||||||
"source_name": "value",
|
|
||||||
"sink_id": link["sink_id"],
|
|
||||||
"sink_name": link["sink_name"],
|
|
||||||
"is_static": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
|
|
||||||
else:
|
|
||||||
new_links.append(link)
|
|
||||||
|
|
||||||
if nodes_to_add:
|
|
||||||
agent["nodes"] = nodes + nodes_to_add
|
|
||||||
agent["links"] = new_links
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def apply_all_fixes(
|
|
||||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Apply all fixes to an agent JSON.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent: Agent JSON dict
|
|
||||||
blocks_info: Optional list of block info dicts for advanced fixes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Fixed agent JSON
|
|
||||||
"""
|
|
||||||
# Basic fixes (no block info needed)
|
|
||||||
agent = fix_agent_ids(agent)
|
|
||||||
agent = fix_double_curly_braces(agent)
|
|
||||||
agent = fix_storevalue_before_condition(agent)
|
|
||||||
agent = fix_addtolist_blocks(agent)
|
|
||||||
agent = fix_addtodictionary_blocks(agent)
|
|
||||||
agent = fix_code_execution_output(agent)
|
|
||||||
agent = fix_data_sampling_sample_size(agent)
|
|
||||||
agent = fix_node_x_coordinates(agent)
|
|
||||||
agent = fix_getcurrentdate_offset(agent)
|
|
||||||
|
|
||||||
# Advanced fixes (require block info)
|
|
||||||
if blocks_info is None:
|
|
||||||
blocks_info = get_blocks_info()
|
|
||||||
|
|
||||||
agent = fix_ai_model_parameter(agent, blocks_info)
|
|
||||||
agent = fix_link_static_properties(agent, blocks_info)
|
|
||||||
agent = fix_data_type_mismatch(agent, blocks_info)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
"""Prompt templates for agent generation."""
|
|
||||||
|
|
||||||
DECOMPOSITION_PROMPT = """
|
|
||||||
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
|
|
||||||
|
|
||||||
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
FIRST: Analyze the user's goal and determine:
|
|
||||||
1) Design-time configuration (fixed settings that won't change per run)
|
|
||||||
2) Runtime inputs (values the agent's end-user will provide each time it runs)
|
|
||||||
|
|
||||||
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
|
|
||||||
- DO NOT ask for the actual value
|
|
||||||
- Instead, define it as an Agent Input with a clear name, type, and description
|
|
||||||
|
|
||||||
Only ask clarifying questions about design-time config that affects how you build the workflow:
|
|
||||||
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
|
|
||||||
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
|
|
||||||
- Business rules that must be hard-coded
|
|
||||||
|
|
||||||
IMPORTANT CLARIFICATIONS POLICY:
|
|
||||||
- Ask no more than five essential questions
|
|
||||||
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
|
|
||||||
- Do not ask for API keys or credentials; the platform handles those directly
|
|
||||||
- If there is enough information to infer reasonable defaults, prefer to propose defaults
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
GUIDELINES:
|
|
||||||
1. List each step as a numbered item
|
|
||||||
2. Describe the action clearly and specify inputs/outputs
|
|
||||||
3. Ensure steps are in logical, sequential order
|
|
||||||
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
|
|
||||||
5. Help the user reach their goal efficiently
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
RULES:
|
|
||||||
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
|
|
||||||
2. USE ONLY THE BLOCKS PROVIDED
|
|
||||||
3. ALL required_input fields must be provided
|
|
||||||
4. Data types of linked properties must match
|
|
||||||
5. Write expert-level prompts for AI-related blocks
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
CRITICAL BLOCK RESTRICTIONS:
|
|
||||||
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
|
|
||||||
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
|
|
||||||
3. ConditionBlock: value2 is reference, value1 is contrast
|
|
||||||
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
|
|
||||||
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
OUTPUT FORMAT:
|
|
||||||
|
|
||||||
If more information is needed:
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"type": "clarifying_questions",
|
|
||||||
"questions": [
|
|
||||||
{{
|
|
||||||
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
|
|
||||||
"keyword": "email_provider",
|
|
||||||
"example": "Gmail"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
If ready to proceed:
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"type": "instructions",
|
|
||||||
"steps": [
|
|
||||||
{{
|
|
||||||
"step_number": 1,
|
|
||||||
"block_name": "AgentShortTextInputBlock",
|
|
||||||
"description": "Get the URL of the content to analyze.",
|
|
||||||
"inputs": [{{"name": "name", "value": "URL"}}],
|
|
||||||
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
AVAILABLE BLOCKS:
|
|
||||||
{block_summaries}
|
|
||||||
"""
|
|
||||||
|
|
||||||
GENERATION_PROMPT = """
|
|
||||||
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
NODES:
|
|
||||||
Each node must include:
|
|
||||||
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
|
|
||||||
- `block_id`: The block identifier (must match an Allowed Block)
|
|
||||||
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
|
|
||||||
- `metadata`: Must contain:
|
|
||||||
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
|
|
||||||
- `customized_name`: Clear name describing this block's purpose in the workflow
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
LINKS:
|
|
||||||
Each link connects a source node's output to a sink node's input:
|
|
||||||
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
|
|
||||||
- `source_id`: ID of the source node
|
|
||||||
- `source_name`: Output field name from the source block
|
|
||||||
- `sink_id`: ID of the sink node
|
|
||||||
- `sink_name`: Input field name on the sink block
|
|
||||||
- `is_static`: true only if source block has static_output: true
|
|
||||||
|
|
||||||
CRITICAL: All IDs must be valid UUID v4 format!
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
AGENT (GRAPH):
|
|
||||||
Wrap nodes and links in:
|
|
||||||
- `id`: UUID of the agent
|
|
||||||
- `name`: Short, generic name (avoid specific company names, URLs)
|
|
||||||
- `description`: Short, generic description
|
|
||||||
- `nodes`: List of all nodes
|
|
||||||
- `links`: List of all links
|
|
||||||
- `version`: 1
|
|
||||||
- `is_active`: true
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
TIPS:
|
|
||||||
- All required_input fields must be provided via input_default or a valid link
|
|
||||||
- Ensure consistent source_id and sink_id references
|
|
||||||
- Avoid dangling links
|
|
||||||
- Input/output pins must match block schemas
|
|
||||||
- Do not invent unknown block_ids
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
ALLOWED BLOCKS:
|
|
||||||
{block_summaries}
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCH_PROMPT = """
|
|
||||||
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
|
|
||||||
|
|
||||||
CURRENT AGENT:
|
|
||||||
{current_agent}
|
|
||||||
|
|
||||||
AVAILABLE BLOCKS:
|
|
||||||
{block_summaries}
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
PATCH FORMAT:
|
|
||||||
Return a JSON object with the following structure:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"type": "patch",
|
|
||||||
"intent": "Brief description of what the patch does",
|
|
||||||
"patches": [
|
|
||||||
{{
|
|
||||||
"type": "modify",
|
|
||||||
"node_id": "uuid-of-node-to-modify",
|
|
||||||
"changes": {{
|
|
||||||
"input_default": {{"field": "new_value"}},
|
|
||||||
"metadata": {{"customized_name": "New Name"}}
|
|
||||||
}}
|
|
||||||
}},
|
|
||||||
{{
|
|
||||||
"type": "add",
|
|
||||||
"new_nodes": [
|
|
||||||
{{
|
|
||||||
"id": "new-uuid",
|
|
||||||
"block_id": "block-uuid",
|
|
||||||
"input_default": {{}},
|
|
||||||
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
|
|
||||||
}}
|
|
||||||
],
|
|
||||||
"new_links": [
|
|
||||||
{{
|
|
||||||
"id": "link-uuid",
|
|
||||||
"source_id": "source-node-id",
|
|
||||||
"source_name": "output_field",
|
|
||||||
"sink_id": "sink-node-id",
|
|
||||||
"sink_name": "input_field"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}},
|
|
||||||
{{
|
|
||||||
"type": "remove",
|
|
||||||
"node_ids": ["uuid-of-node-to-remove"],
|
|
||||||
"link_ids": ["uuid-of-link-to-remove"]
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
If you need more information, return:
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"type": "clarifying_questions",
|
|
||||||
"questions": [
|
|
||||||
{{
|
|
||||||
"question": "What specific change do you want?",
|
|
||||||
"keyword": "change_type",
|
|
||||||
"example": "Add error handling"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
Generate the minimal patch needed. Output ONLY valid JSON.
|
|
||||||
"""
|
|
||||||
@@ -0,0 +1,269 @@
|
|||||||
|
"""External Agent Generator service client.
|
||||||
|
|
||||||
|
This module provides a client for communicating with the external Agent Generator
|
||||||
|
microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions
|
||||||
|
will delegate to the external service instead of using the built-in LLM-based implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_client: httpx.AsyncClient | None = None
|
||||||
|
_settings: Settings | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_settings() -> Settings:
|
||||||
|
"""Get or create settings singleton."""
|
||||||
|
global _settings
|
||||||
|
if _settings is None:
|
||||||
|
_settings = Settings()
|
||||||
|
return _settings
|
||||||
|
|
||||||
|
|
||||||
|
def is_external_service_configured() -> bool:
|
||||||
|
"""Check if external Agent Generator service is configured."""
|
||||||
|
settings = _get_settings()
|
||||||
|
return bool(settings.config.agentgenerator_host)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_base_url() -> str:
|
||||||
|
"""Get the base URL for the external service."""
|
||||||
|
settings = _get_settings()
|
||||||
|
host = settings.config.agentgenerator_host
|
||||||
|
port = settings.config.agentgenerator_port
|
||||||
|
return f"http://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client() -> httpx.AsyncClient:
|
||||||
|
"""Get or create the HTTP client for the external service."""
|
||||||
|
global _client
|
||||||
|
if _client is None:
|
||||||
|
settings = _get_settings()
|
||||||
|
_client = httpx.AsyncClient(
|
||||||
|
base_url=_get_base_url(),
|
||||||
|
timeout=httpx.Timeout(settings.config.agentgenerator_timeout),
|
||||||
|
)
|
||||||
|
return _client
|
||||||
|
|
||||||
|
|
||||||
|
async def decompose_goal_external(
|
||||||
|
description: str, context: str = ""
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Call the external service to decompose a goal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
description: Natural language goal description
|
||||||
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with either:
|
||||||
|
- {"type": "clarifying_questions", "questions": [...]}
|
||||||
|
- {"type": "instructions", "steps": [...]}
|
||||||
|
- {"type": "unachievable_goal", ...}
|
||||||
|
- {"type": "vague_goal", ...}
|
||||||
|
Or None on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
# Build the request payload
|
||||||
|
payload: dict[str, Any] = {"description": description}
|
||||||
|
if context:
|
||||||
|
# The external service uses user_instruction for additional context
|
||||||
|
payload["user_instruction"] = context
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post("/api/decompose-description", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
logger.error(f"External service returned error: {data.get('error')}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Map the response to the expected format
|
||||||
|
response_type = data.get("type")
|
||||||
|
if response_type == "instructions":
|
||||||
|
return {"type": "instructions", "steps": data.get("steps", [])}
|
||||||
|
elif response_type == "clarifying_questions":
|
||||||
|
return {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": data.get("questions", []),
|
||||||
|
}
|
||||||
|
elif response_type == "unachievable_goal":
|
||||||
|
return {
|
||||||
|
"type": "unachievable_goal",
|
||||||
|
"reason": data.get("reason"),
|
||||||
|
"suggested_goal": data.get("suggested_goal"),
|
||||||
|
}
|
||||||
|
elif response_type == "vague_goal":
|
||||||
|
return {
|
||||||
|
"type": "vague_goal",
|
||||||
|
"suggested_goal": data.get("suggested_goal"),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Unknown response type from external service: {response_type}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_agent_external(
|
||||||
|
instructions: dict[str, Any]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Call the external service to generate an agent from instructions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instructions: Structured instructions from decompose_goal
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent JSON dict or None on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/generate-agent", json={"instructions": instructions}
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
logger.error(f"External service returned error: {data.get('error')}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data.get("agent_json")
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_agent_patch_external(
|
||||||
|
update_request: str, current_agent: dict[str, Any]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Call the external service to generate a patch for an existing agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_request: Natural language description of changes
|
||||||
|
current_agent: Current agent JSON
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated agent JSON, clarifying questions dict, or None on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/update-agent",
|
||||||
|
json={
|
||||||
|
"update_request": update_request,
|
||||||
|
"current_agent_json": current_agent,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
logger.error(f"External service returned error: {data.get('error')}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if it's clarifying questions
|
||||||
|
if data.get("type") == "clarifying_questions":
|
||||||
|
return {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": data.get("questions", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Otherwise return the updated agent JSON
|
||||||
|
return data.get("agent_json")
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||||
|
"""Get available blocks from the external service.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of block info dicts or None on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get("/api/blocks")
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
logger.error("External service returned error getting blocks")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data.get("blocks", [])
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"HTTP error getting blocks from external service: {e}")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error getting blocks from external service: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error getting blocks from external service: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def health_check() -> bool:
|
||||||
|
"""Check if the external service is healthy.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if healthy, False otherwise
|
||||||
|
"""
|
||||||
|
if not is_external_service_configured():
|
||||||
|
return False
|
||||||
|
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get("/health")
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return data.get("status") == "healthy" and data.get("blocks_loaded", False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"External agent generator health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def close_client() -> None:
|
||||||
|
"""Close the HTTP client."""
|
||||||
|
global _client
|
||||||
|
if _client is not None:
|
||||||
|
await _client.aclose()
|
||||||
|
_client = None
|
||||||
@@ -1,213 +0,0 @@
|
|||||||
"""Utilities for agent generation."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.data.block import get_blocks
|
|
||||||
|
|
||||||
# UUID validation regex
|
|
||||||
UUID_REGEX = re.compile(
|
|
||||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Block IDs for various fixes
|
|
||||||
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
|
|
||||||
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
|
|
||||||
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
|
|
||||||
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
|
|
||||||
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
|
|
||||||
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
|
|
||||||
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
|
|
||||||
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
|
|
||||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
|
|
||||||
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
|
||||||
|
|
||||||
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
|
|
||||||
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
|
|
||||||
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
|
||||||
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
|
||||||
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
|
|
||||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
|
||||||
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
|
|
||||||
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
|
|
||||||
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
|
|
||||||
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
|
|
||||||
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
|
|
||||||
"716a67b3-6760-42e7-86dc-18645c6e00fc",
|
|
||||||
"530cf046-2ce0-4854-ae2c-659db17c7a46",
|
|
||||||
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
|
||||||
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
|
|
||||||
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_uuid(value: str) -> bool:
|
|
||||||
"""Check if a string is a valid UUID v4."""
|
|
||||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def _compact_schema(schema: dict) -> dict[str, str]:
|
|
||||||
"""Extract compact type info from a JSON schema properties dict.
|
|
||||||
|
|
||||||
Returns a dict of {field_name: type_string} for essential info only.
|
|
||||||
"""
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
result = {}
|
|
||||||
|
|
||||||
for name, prop in props.items():
|
|
||||||
# Skip internal/complex fields
|
|
||||||
if name.startswith("_"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get type string
|
|
||||||
type_str = prop.get("type", "any")
|
|
||||||
|
|
||||||
# Handle anyOf/oneOf (optional types)
|
|
||||||
if "anyOf" in prop:
|
|
||||||
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
|
|
||||||
type_str = "|".join(types) if types else "any"
|
|
||||||
elif "allOf" in prop:
|
|
||||||
type_str = "object"
|
|
||||||
|
|
||||||
# Add array item type if present
|
|
||||||
if type_str == "array" and "items" in prop:
|
|
||||||
items = prop["items"]
|
|
||||||
if isinstance(items, dict):
|
|
||||||
item_type = items.get("type", "any")
|
|
||||||
type_str = f"array[{item_type}]"
|
|
||||||
|
|
||||||
result[name] = type_str
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_block_summaries(include_schemas: bool = True) -> str:
|
|
||||||
"""Generate compact block summaries for prompts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
include_schemas: Whether to include input/output type info
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string of block summaries (compact format)
|
|
||||||
"""
|
|
||||||
blocks = get_blocks()
|
|
||||||
summaries = []
|
|
||||||
|
|
||||||
for block_id, block_cls in blocks.items():
|
|
||||||
block = block_cls()
|
|
||||||
name = block.name
|
|
||||||
desc = getattr(block, "description", "") or ""
|
|
||||||
|
|
||||||
# Truncate description
|
|
||||||
if len(desc) > 150:
|
|
||||||
desc = desc[:147] + "..."
|
|
||||||
|
|
||||||
if not include_schemas:
|
|
||||||
summaries.append(f"- {name} (id: {block_id}): {desc}")
|
|
||||||
else:
|
|
||||||
# Compact format with type info only
|
|
||||||
inputs = {}
|
|
||||||
outputs = {}
|
|
||||||
required = []
|
|
||||||
|
|
||||||
if hasattr(block, "input_schema"):
|
|
||||||
try:
|
|
||||||
schema = block.input_schema.jsonschema()
|
|
||||||
inputs = _compact_schema(schema)
|
|
||||||
required = schema.get("required", [])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if hasattr(block, "output_schema"):
|
|
||||||
try:
|
|
||||||
schema = block.output_schema.jsonschema()
|
|
||||||
outputs = _compact_schema(schema)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Build compact line format
|
|
||||||
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
|
|
||||||
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
|
|
||||||
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
|
|
||||||
req_str = f" req=[{','.join(required)}]" if required else ""
|
|
||||||
|
|
||||||
static = " [static]" if getattr(block, "static_output", False) else ""
|
|
||||||
|
|
||||||
line = f"- {name} (id: {block_id}): {desc}"
|
|
||||||
if in_str:
|
|
||||||
line += f"\n in: {{{in_str}}}{req_str}"
|
|
||||||
if out_str:
|
|
||||||
line += f"\n out: {{{out_str}}}{static}"
|
|
||||||
|
|
||||||
summaries.append(line)
|
|
||||||
|
|
||||||
return "\n".join(summaries)
|
|
||||||
|
|
||||||
|
|
||||||
def get_blocks_info() -> list[dict[str, Any]]:
|
|
||||||
"""Get block information with schemas for validation and fixing."""
|
|
||||||
blocks = get_blocks()
|
|
||||||
blocks_info = []
|
|
||||||
for block_id, block_cls in blocks.items():
|
|
||||||
block = block_cls()
|
|
||||||
blocks_info.append(
|
|
||||||
{
|
|
||||||
"id": block_id,
|
|
||||||
"name": block.name,
|
|
||||||
"description": getattr(block, "description", ""),
|
|
||||||
"categories": getattr(block, "categories", []),
|
|
||||||
"staticOutput": getattr(block, "static_output", False),
|
|
||||||
"inputSchema": (
|
|
||||||
block.input_schema.jsonschema()
|
|
||||||
if hasattr(block, "input_schema")
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
"outputSchema": (
|
|
||||||
block.output_schema.jsonschema()
|
|
||||||
if hasattr(block, "output_schema")
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return blocks_info
|
|
||||||
|
|
||||||
|
|
||||||
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
|
|
||||||
"""Extract JSON from LLM response (handles markdown code blocks)."""
|
|
||||||
if not text:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try fenced code block
|
|
||||||
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
return json.loads(match.group(1).strip())
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Try raw text
|
|
||||||
try:
|
|
||||||
return json.loads(text.strip())
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Try finding {...} span
|
|
||||||
start = text.find("{")
|
|
||||||
end = text.rfind("}")
|
|
||||||
if start != -1 and end > start:
|
|
||||||
try:
|
|
||||||
return json.loads(text[start : end + 1])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Try finding [...] span
|
|
||||||
start = text.find("[")
|
|
||||||
end = text.rfind("]")
|
|
||||||
if start != -1 and end > start:
|
|
||||||
try:
|
|
||||||
return json.loads(text[start : end + 1])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return None
|
|
||||||
@@ -1,279 +0,0 @@
|
|||||||
"""Agent validator - Validates agent structure and connections."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from .utils import get_blocks_info
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentValidator:
|
|
||||||
"""Validator for AutoGPT agents with detailed error reporting."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.errors: list[str] = []
|
|
||||||
|
|
||||||
def add_error(self, error: str) -> None:
|
|
||||||
"""Add an error message."""
|
|
||||||
self.errors.append(error)
|
|
||||||
|
|
||||||
def validate_block_existence(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
"""Validate all block IDs exist in the blocks library."""
|
|
||||||
valid = True
|
|
||||||
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
block_id = node.get("block_id")
|
|
||||||
node_id = node.get("id")
|
|
||||||
|
|
||||||
if not block_id:
|
|
||||||
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
|
|
||||||
valid = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
if block_id not in valid_block_ids:
|
|
||||||
self.add_error(
|
|
||||||
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
|
|
||||||
"""Validate all node IDs referenced in links exist."""
|
|
||||||
valid = True
|
|
||||||
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
|
|
||||||
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
link_id = link.get("id", "Unknown")
|
|
||||||
source_id = link.get("source_id")
|
|
||||||
sink_id = link.get("sink_id")
|
|
||||||
|
|
||||||
if not source_id:
|
|
||||||
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
|
|
||||||
valid = False
|
|
||||||
elif source_id not in valid_node_ids:
|
|
||||||
self.add_error(
|
|
||||||
f"Link '{link_id}' references non-existent source_id '{source_id}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
if not sink_id:
|
|
||||||
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
|
|
||||||
valid = False
|
|
||||||
elif sink_id not in valid_node_ids:
|
|
||||||
self.add_error(
|
|
||||||
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_required_inputs(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
"""Validate required inputs are provided."""
|
|
||||||
valid = True
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
block_id = node.get("block_id")
|
|
||||||
block = block_map.get(block_id)
|
|
||||||
|
|
||||||
if not block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
|
||||||
input_defaults = node.get("input_default", {})
|
|
||||||
node_id = node.get("id")
|
|
||||||
|
|
||||||
# Get linked inputs
|
|
||||||
linked_inputs = {
|
|
||||||
link["sink_name"]
|
|
||||||
for link in agent.get("links", [])
|
|
||||||
if link.get("sink_id") == node_id
|
|
||||||
}
|
|
||||||
|
|
||||||
for req_input in required_inputs:
|
|
||||||
if (
|
|
||||||
req_input not in input_defaults
|
|
||||||
and req_input not in linked_inputs
|
|
||||||
and req_input != "credentials"
|
|
||||||
):
|
|
||||||
block_name = block.get("name", "Unknown Block")
|
|
||||||
self.add_error(
|
|
||||||
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_data_type_compatibility(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
"""Validate linked data types are compatible."""
|
|
||||||
valid = True
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
|
||||||
|
|
||||||
def get_type(schema: dict, name: str) -> str | None:
|
|
||||||
if "_#_" in name:
|
|
||||||
parent, child = name.split("_#_", 1)
|
|
||||||
parent_schema = schema.get(parent, {})
|
|
||||||
if "properties" in parent_schema:
|
|
||||||
return parent_schema["properties"].get(child, {}).get("type")
|
|
||||||
return None
|
|
||||||
return schema.get(name, {}).get("type")
|
|
||||||
|
|
||||||
def are_compatible(src: str, sink: str) -> bool:
|
|
||||||
if {src, sink} <= {"integer", "number"}:
|
|
||||||
return True
|
|
||||||
return src == sink
|
|
||||||
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
source_node = node_lookup.get(link.get("source_id"))
|
|
||||||
sink_node = node_lookup.get(link.get("sink_id"))
|
|
||||||
|
|
||||||
if not source_node or not sink_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_block = block_map.get(source_node.get("block_id"))
|
|
||||||
sink_block = block_map.get(sink_node.get("block_id"))
|
|
||||||
|
|
||||||
if not source_block or not sink_block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
|
||||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
|
||||||
|
|
||||||
source_type = get_type(source_outputs, link.get("source_name", ""))
|
|
||||||
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
|
|
||||||
|
|
||||||
if source_type and sink_type and not are_compatible(source_type, sink_type):
|
|
||||||
self.add_error(
|
|
||||||
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
|
|
||||||
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_nested_sink_links(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
"""Validate nested sink links (with _#_ notation)."""
|
|
||||||
valid = True
|
|
||||||
block_map = {b.get("id"): b for b in blocks_info}
|
|
||||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
|
||||||
|
|
||||||
for link in agent.get("links", []):
|
|
||||||
sink_name = link.get("sink_name", "")
|
|
||||||
|
|
||||||
if "_#_" in sink_name:
|
|
||||||
parent, child = sink_name.split("_#_", 1)
|
|
||||||
|
|
||||||
sink_node = node_lookup.get(link.get("sink_id"))
|
|
||||||
if not sink_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
block = block_map.get(sink_node.get("block_id"))
|
|
||||||
if not block:
|
|
||||||
continue
|
|
||||||
|
|
||||||
input_props = block.get("inputSchema", {}).get("properties", {})
|
|
||||||
parent_schema = input_props.get(parent)
|
|
||||||
|
|
||||||
if not parent_schema:
|
|
||||||
self.add_error(
|
|
||||||
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not parent_schema.get("additionalProperties"):
|
|
||||||
if not (
|
|
||||||
isinstance(parent_schema, dict)
|
|
||||||
and "properties" in parent_schema
|
|
||||||
and child in parent_schema.get("properties", {})
|
|
||||||
):
|
|
||||||
self.add_error(
|
|
||||||
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
|
|
||||||
"""Validate prompts don't have spaces in template variables."""
|
|
||||||
valid = True
|
|
||||||
|
|
||||||
for node in agent.get("nodes", []):
|
|
||||||
input_default = node.get("input_default", {})
|
|
||||||
prompt = input_default.get("prompt", "")
|
|
||||||
|
|
||||||
if not isinstance(prompt, str):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Find {{...}} with spaces
|
|
||||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
|
|
||||||
for match in matches:
|
|
||||||
content = match.group(1)
|
|
||||||
if " " in content:
|
|
||||||
self.add_error(
|
|
||||||
f"Node '{node.get('id')}' has spaces in template variable: "
|
|
||||||
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
|
|
||||||
)
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def validate(
|
|
||||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
|
||||||
) -> tuple[bool, str | None]:
|
|
||||||
"""Run all validations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_valid, error_message)
|
|
||||||
"""
|
|
||||||
self.errors = []
|
|
||||||
|
|
||||||
if blocks_info is None:
|
|
||||||
blocks_info = get_blocks_info()
|
|
||||||
|
|
||||||
checks = [
|
|
||||||
self.validate_block_existence(agent, blocks_info),
|
|
||||||
self.validate_link_node_references(agent),
|
|
||||||
self.validate_required_inputs(agent, blocks_info),
|
|
||||||
self.validate_data_type_compatibility(agent, blocks_info),
|
|
||||||
self.validate_nested_sink_links(agent, blocks_info),
|
|
||||||
self.validate_prompt_spaces(agent),
|
|
||||||
]
|
|
||||||
|
|
||||||
all_passed = all(checks)
|
|
||||||
|
|
||||||
if all_passed:
|
|
||||||
logger.info("Agent validation successful")
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
error_message = "Agent validation failed:\n"
|
|
||||||
for i, error in enumerate(self.errors, 1):
|
|
||||||
error_message += f"{i}. {error}\n"
|
|
||||||
|
|
||||||
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
|
|
||||||
return False, error_message
|
|
||||||
|
|
||||||
|
|
||||||
def validate_agent(
|
|
||||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
|
||||||
) -> tuple[bool, str | None]:
|
|
||||||
"""Convenience function to validate an agent.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_valid, error_message)
|
|
||||||
"""
|
|
||||||
validator = AgentValidator()
|
|
||||||
return validator.validate(agent, blocks_info)
|
|
||||||
@@ -8,12 +8,10 @@ from langfuse import observe
|
|||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
apply_all_fixes,
|
AgentGeneratorNotConfiguredError,
|
||||||
decompose_goal,
|
decompose_goal,
|
||||||
generate_agent,
|
generate_agent,
|
||||||
get_blocks_info,
|
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
validate_agent,
|
|
||||||
)
|
)
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -27,9 +25,6 @@ from .models import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Maximum retries for agent generation with validation feedback
|
|
||||||
MAX_GENERATION_RETRIES = 2
|
|
||||||
|
|
||||||
|
|
||||||
class CreateAgentTool(BaseTool):
|
class CreateAgentTool(BaseTool):
|
||||||
"""Tool for creating agents from natural language descriptions."""
|
"""Tool for creating agents from natural language descriptions."""
|
||||||
@@ -91,9 +86,8 @@ class CreateAgentTool(BaseTool):
|
|||||||
|
|
||||||
Flow:
|
Flow:
|
||||||
1. Decompose the description into steps (may return clarifying questions)
|
1. Decompose the description into steps (may return clarifying questions)
|
||||||
2. Generate agent JSON from the steps
|
2. Generate agent JSON (external service handles fixing and validation)
|
||||||
3. Apply fixes to correct common LLM errors
|
3. Preview or save based on the save parameter
|
||||||
4. Preview or save based on the save parameter
|
|
||||||
"""
|
"""
|
||||||
description = kwargs.get("description", "").strip()
|
description = kwargs.get("description", "").strip()
|
||||||
context = kwargs.get("context", "")
|
context = kwargs.get("context", "")
|
||||||
@@ -110,11 +104,13 @@ class CreateAgentTool(BaseTool):
|
|||||||
# Step 1: Decompose goal into steps
|
# Step 1: Decompose goal into steps
|
||||||
try:
|
try:
|
||||||
decomposition_result = await decompose_goal(description, context)
|
decomposition_result = await decompose_goal(description, context)
|
||||||
except ValueError as e:
|
except AgentGeneratorNotConfiguredError:
|
||||||
# Handle missing API key or configuration errors
|
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Agent generation is not configured: {str(e)}",
|
message=(
|
||||||
error="configuration_error",
|
"Agent generation is not available. "
|
||||||
|
"The Agent Generator service is not configured."
|
||||||
|
),
|
||||||
|
error="service_not_configured",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -171,72 +167,32 @@ class CreateAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 2: Generate agent JSON with retry on validation failure
|
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
||||||
blocks_info = get_blocks_info()
|
try:
|
||||||
agent_json = None
|
agent_json = await generate_agent(decomposition_result)
|
||||||
validation_errors = None
|
except AgentGeneratorNotConfiguredError:
|
||||||
|
return ErrorResponse(
|
||||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
message=(
|
||||||
# Generate agent (include validation errors from previous attempt)
|
"Agent generation is not available. "
|
||||||
if attempt == 0:
|
"The Agent Generator service is not configured."
|
||||||
agent_json = await generate_agent(decomposition_result)
|
),
|
||||||
else:
|
error="service_not_configured",
|
||||||
# Retry with validation error feedback
|
session_id=session_id,
|
||||||
logger.info(
|
|
||||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
|
||||||
)
|
|
||||||
retry_instructions = {
|
|
||||||
**decomposition_result,
|
|
||||||
"previous_errors": validation_errors,
|
|
||||||
"retry_instructions": (
|
|
||||||
"The previous generation had validation errors. "
|
|
||||||
"Please fix these issues in the new generation:\n"
|
|
||||||
f"{validation_errors}"
|
|
||||||
),
|
|
||||||
}
|
|
||||||
agent_json = await generate_agent(retry_instructions)
|
|
||||||
|
|
||||||
if agent_json is None:
|
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to generate the agent. Please try again.",
|
|
||||||
error="Generation failed",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Step 3: Apply fixes to correct common errors
|
|
||||||
agent_json = apply_all_fixes(agent_json, blocks_info)
|
|
||||||
|
|
||||||
# Step 4: Validate the agent
|
|
||||||
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
|
|
||||||
|
|
||||||
if is_valid:
|
|
||||||
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
if agent_json is None:
|
||||||
# Return error with validation details
|
return ErrorResponse(
|
||||||
return ErrorResponse(
|
message="Failed to generate the agent. Please try again.",
|
||||||
message=(
|
error="Generation failed",
|
||||||
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
|
session_id=session_id,
|
||||||
f"Please try rephrasing your request or simplify the workflow."
|
)
|
||||||
),
|
|
||||||
error="validation_failed",
|
|
||||||
details={"validation_errors": validation_errors},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_name = agent_json.get("name", "Generated Agent")
|
agent_name = agent_json.get("name", "Generated Agent")
|
||||||
agent_description = agent_json.get("description", "")
|
agent_description = agent_json.get("description", "")
|
||||||
node_count = len(agent_json.get("nodes", []))
|
node_count = len(agent_json.get("nodes", []))
|
||||||
link_count = len(agent_json.get("links", []))
|
link_count = len(agent_json.get("links", []))
|
||||||
|
|
||||||
# Step 4: Preview or save
|
# Step 3: Preview or save
|
||||||
if not save:
|
if not save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
|
|||||||
@@ -8,13 +8,10 @@ from langfuse import observe
|
|||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
apply_agent_patch,
|
AgentGeneratorNotConfiguredError,
|
||||||
apply_all_fixes,
|
|
||||||
generate_agent_patch,
|
generate_agent_patch,
|
||||||
get_agent_as_json,
|
get_agent_as_json,
|
||||||
get_blocks_info,
|
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
validate_agent,
|
|
||||||
)
|
)
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .models import (
|
from .models import (
|
||||||
@@ -28,9 +25,6 @@ from .models import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Maximum retries for patch generation with validation feedback
|
|
||||||
MAX_GENERATION_RETRIES = 2
|
|
||||||
|
|
||||||
|
|
||||||
class EditAgentTool(BaseTool):
|
class EditAgentTool(BaseTool):
|
||||||
"""Tool for editing existing agents using natural language."""
|
"""Tool for editing existing agents using natural language."""
|
||||||
@@ -43,7 +37,7 @@ class EditAgentTool(BaseTool):
|
|||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Edit an existing agent from the user's library using natural language. "
|
"Edit an existing agent from the user's library using natural language. "
|
||||||
"Generates a patch to update the agent while preserving unchanged parts."
|
"Generates updates to the agent while preserving unchanged parts."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -98,9 +92,8 @@ class EditAgentTool(BaseTool):
|
|||||||
|
|
||||||
Flow:
|
Flow:
|
||||||
1. Fetch the current agent
|
1. Fetch the current agent
|
||||||
2. Generate a patch based on the requested changes
|
2. Generate updated agent (external service handles fixing and validation)
|
||||||
3. Apply the patch to create an updated agent
|
3. Preview or save based on the save parameter
|
||||||
4. Preview or save based on the save parameter
|
|
||||||
"""
|
"""
|
||||||
agent_id = kwargs.get("agent_id", "").strip()
|
agent_id = kwargs.get("agent_id", "").strip()
|
||||||
changes = kwargs.get("changes", "").strip()
|
changes = kwargs.get("changes", "").strip()
|
||||||
@@ -137,121 +130,58 @@ class EditAgentTool(BaseTool):
|
|||||||
if context:
|
if context:
|
||||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||||
|
|
||||||
# Step 2: Generate patch with retry on validation failure
|
# Step 2: Generate updated agent (external service handles fixing and validation)
|
||||||
blocks_info = get_blocks_info()
|
try:
|
||||||
updated_agent = None
|
result = await generate_agent_patch(update_request, current_agent)
|
||||||
validation_errors = None
|
except AgentGeneratorNotConfiguredError:
|
||||||
intent = "Applied requested changes"
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
"Agent editing is not available. "
|
||||||
# Generate patch (include validation errors from previous attempt)
|
"The Agent Generator service is not configured."
|
||||||
try:
|
),
|
||||||
if attempt == 0:
|
error="service_not_configured",
|
||||||
patch_result = await generate_agent_patch(
|
session_id=session_id,
|
||||||
update_request, current_agent
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Retry with validation error feedback
|
|
||||||
logger.info(
|
|
||||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
|
||||||
)
|
|
||||||
retry_request = (
|
|
||||||
f"{update_request}\n\n"
|
|
||||||
f"IMPORTANT: The previous edit had validation errors. "
|
|
||||||
f"Please fix these issues:\n{validation_errors}"
|
|
||||||
)
|
|
||||||
patch_result = await generate_agent_patch(
|
|
||||||
retry_request, current_agent
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
# Handle missing API key or configuration errors
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Agent generation is not configured: {str(e)}",
|
|
||||||
error="configuration_error",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if patch_result is None:
|
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to generate changes. Please try rephrasing.",
|
|
||||||
error="Patch generation failed",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if LLM returned clarifying questions
|
|
||||||
if patch_result.get("type") == "clarifying_questions":
|
|
||||||
questions = patch_result.get("questions", [])
|
|
||||||
return ClarificationNeededResponse(
|
|
||||||
message=(
|
|
||||||
"I need some more information about the changes. "
|
|
||||||
"Please answer the following questions:"
|
|
||||||
),
|
|
||||||
questions=[
|
|
||||||
ClarifyingQuestion(
|
|
||||||
question=q.get("question", ""),
|
|
||||||
keyword=q.get("keyword", ""),
|
|
||||||
example=q.get("example"),
|
|
||||||
)
|
|
||||||
for q in questions
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 3: Apply patch and fixes
|
|
||||||
try:
|
|
||||||
updated_agent = apply_agent_patch(current_agent, patch_result)
|
|
||||||
updated_agent = apply_all_fixes(updated_agent, blocks_info)
|
|
||||||
except Exception as e:
|
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to apply changes: {str(e)}",
|
|
||||||
error="patch_apply_failed",
|
|
||||||
details={"exception": str(e)},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
validation_errors = str(e)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Step 4: Validate the updated agent
|
|
||||||
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
|
|
||||||
|
|
||||||
if is_valid:
|
|
||||||
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
|
|
||||||
intent = patch_result.get("intent", "Applied requested changes")
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if attempt == MAX_GENERATION_RETRIES:
|
if result is None:
|
||||||
# Return error with validation details
|
return ErrorResponse(
|
||||||
return ErrorResponse(
|
message="Failed to generate changes. Please try rephrasing.",
|
||||||
message=(
|
error="Update generation failed",
|
||||||
f"Updated agent has validation errors after "
|
session_id=session_id,
|
||||||
f"{MAX_GENERATION_RETRIES + 1} attempts. "
|
)
|
||||||
f"Please try rephrasing your request or simplify the changes."
|
|
||||||
),
|
|
||||||
error="validation_failed",
|
|
||||||
details={"validation_errors": validation_errors},
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
|
# Check if LLM returned clarifying questions
|
||||||
assert updated_agent is not None
|
if result.get("type") == "clarifying_questions":
|
||||||
|
questions = result.get("questions", [])
|
||||||
|
return ClarificationNeededResponse(
|
||||||
|
message=(
|
||||||
|
"I need some more information about the changes. "
|
||||||
|
"Please answer the following questions:"
|
||||||
|
),
|
||||||
|
questions=[
|
||||||
|
ClarifyingQuestion(
|
||||||
|
question=q.get("question", ""),
|
||||||
|
keyword=q.get("keyword", ""),
|
||||||
|
example=q.get("example"),
|
||||||
|
)
|
||||||
|
for q in questions
|
||||||
|
],
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Result is the updated agent JSON
|
||||||
|
updated_agent = result
|
||||||
|
|
||||||
agent_name = updated_agent.get("name", "Updated Agent")
|
agent_name = updated_agent.get("name", "Updated Agent")
|
||||||
agent_description = updated_agent.get("description", "")
|
agent_description = updated_agent.get("description", "")
|
||||||
node_count = len(updated_agent.get("nodes", []))
|
node_count = len(updated_agent.get("nodes", []))
|
||||||
link_count = len(updated_agent.get("links", []))
|
link_count = len(updated_agent.get("links", []))
|
||||||
|
|
||||||
# Step 5: Preview or save
|
# Step 3: Preview or save
|
||||||
if not save:
|
if not save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
f"I've updated the agent. Changes: {intent}. "
|
f"I've updated the agent. "
|
||||||
f"The agent now has {node_count} blocks. "
|
f"The agent now has {node_count} blocks. "
|
||||||
f"Review it and call edit_agent with save=true to save the changes."
|
f"Review it and call edit_agent with save=true to save the changes."
|
||||||
),
|
),
|
||||||
@@ -277,10 +207,7 @@ class EditAgentTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return AgentSavedResponse(
|
return AgentSavedResponse(
|
||||||
message=(
|
message=f"Updated agent '{created_graph.name}' has been saved to your library!",
|
||||||
f"Updated agent '{created_graph.name}' has been saved to your library! "
|
|
||||||
f"Changes: {intent}"
|
|
||||||
),
|
|
||||||
agent_id=created_graph.id,
|
agent_id=created_graph.id,
|
||||||
agent_name=created_graph.name,
|
agent_name=created_graph.name,
|
||||||
library_agent_id=library_agent.id,
|
library_agent_id=library_agent.id,
|
||||||
|
|||||||
@@ -350,6 +350,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
description="Whether to mark failed scans as clean or not",
|
description="Whether to mark failed scans as clean or not",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agentgenerator_host: str = Field(
|
||||||
|
default="",
|
||||||
|
description="The host for the Agent Generator service (empty to use built-in)",
|
||||||
|
)
|
||||||
|
agentgenerator_port: int = Field(
|
||||||
|
default=8000,
|
||||||
|
description="The port for the Agent Generator service",
|
||||||
|
)
|
||||||
|
agentgenerator_timeout: int = Field(
|
||||||
|
default=120,
|
||||||
|
description="The timeout in seconds for Agent Generator service requests",
|
||||||
|
)
|
||||||
|
|
||||||
enable_example_blocks: bool = Field(
|
enable_example_blocks: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to enable example blocks in production",
|
description="Whether to enable example blocks in production",
|
||||||
|
|||||||
@@ -1,37 +1,12 @@
|
|||||||
-- CreateExtension
|
-- CreateExtension
|
||||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||||
-- Ensures vector extension is in the current schema (from DATABASE_URL ?schema= param)
|
-- Creates extension in current schema (determined by search_path from DATABASE_URL ?schema= param)
|
||||||
-- If it exists in a different schema (e.g., public), we drop and recreate it in the current schema
|
|
||||||
-- This ensures vector type is in the same schema as tables, making ::vector work without explicit qualification
|
-- This ensures vector type is in the same schema as tables, making ::vector work without explicit qualification
|
||||||
DO $$
|
DO $$
|
||||||
DECLARE
|
|
||||||
current_schema_name text;
|
|
||||||
vector_schema text;
|
|
||||||
BEGIN
|
BEGIN
|
||||||
-- Get the current schema from search_path
|
CREATE EXTENSION IF NOT EXISTS "vector";
|
||||||
SELECT current_schema() INTO current_schema_name;
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
||||||
-- Check if vector extension exists and which schema it's in
|
|
||||||
SELECT n.nspname INTO vector_schema
|
|
||||||
FROM pg_extension e
|
|
||||||
JOIN pg_namespace n ON e.extnamespace = n.oid
|
|
||||||
WHERE e.extname = 'vector';
|
|
||||||
|
|
||||||
-- Handle removal if in wrong schema
|
|
||||||
IF vector_schema IS NOT NULL AND vector_schema != current_schema_name THEN
|
|
||||||
BEGIN
|
|
||||||
-- Vector exists in a different schema, drop it first
|
|
||||||
RAISE WARNING 'pgvector found in schema "%" but need it in "%". Dropping and reinstalling...',
|
|
||||||
vector_schema, current_schema_name;
|
|
||||||
EXECUTE 'DROP EXTENSION IF EXISTS vector CASCADE';
|
|
||||||
EXCEPTION WHEN OTHERS THEN
|
|
||||||
RAISE EXCEPTION 'Failed to drop pgvector from schema "%": %. You may need to drop it manually.',
|
|
||||||
vector_schema, SQLERRM;
|
|
||||||
END;
|
|
||||||
END IF;
|
|
||||||
|
|
||||||
-- Create extension in current schema (let it fail naturally if not available)
|
|
||||||
EXECUTE format('CREATE EXTENSION IF NOT EXISTS vector SCHEMA %I', current_schema_name);
|
|
||||||
END $$;
|
END $$;
|
||||||
|
|
||||||
-- CreateEnum
|
-- CreateEnum
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
||||||
|
-- These extensions are pre-installed by Supabase in specific schemas
|
||||||
|
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
||||||
|
|
||||||
|
-- Create schemas (safe in both CI and Supabase)
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "extensions";
|
||||||
|
|
||||||
|
-- Extensions that exist in both CI and Supabase
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
-- Supabase-specific extensions (skip gracefully in CI)
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pg_net extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pgjwt extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "graphql";
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'pgsodium extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "vault";
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
||||||
|
END $$;
|
||||||
|
|
||||||
|
|
||||||
|
-- Return to platform
|
||||||
|
CREATE SCHEMA IF NOT EXISTS "platform";
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for agent generator module."""
|
||||||
@@ -0,0 +1,273 @@
|
|||||||
|
"""
|
||||||
|
Tests for the Agent Generator core module.
|
||||||
|
|
||||||
|
This test suite verifies that the core functions correctly delegate to
|
||||||
|
the external Agent Generator service.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.agent_generator import core
|
||||||
|
from backend.api.features.chat.tools.agent_generator.core import (
|
||||||
|
AgentGeneratorNotConfiguredError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestServiceNotConfigured:
|
||||||
|
"""Test that functions raise AgentGeneratorNotConfiguredError when service is not configured."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_raises_when_not_configured(self):
|
||||||
|
"""Test that decompose_goal raises error when service not configured."""
|
||||||
|
with patch.object(core, "is_external_service_configured", return_value=False):
|
||||||
|
with pytest.raises(AgentGeneratorNotConfiguredError):
|
||||||
|
await core.decompose_goal("Build a chatbot")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_agent_raises_when_not_configured(self):
|
||||||
|
"""Test that generate_agent raises error when service not configured."""
|
||||||
|
with patch.object(core, "is_external_service_configured", return_value=False):
|
||||||
|
with pytest.raises(AgentGeneratorNotConfiguredError):
|
||||||
|
await core.generate_agent({"steps": []})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_agent_patch_raises_when_not_configured(self):
|
||||||
|
"""Test that generate_agent_patch raises error when service not configured."""
|
||||||
|
with patch.object(core, "is_external_service_configured", return_value=False):
|
||||||
|
with pytest.raises(AgentGeneratorNotConfiguredError):
|
||||||
|
await core.generate_agent_patch("Add a node", {"nodes": []})
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecomposeGoal:
|
||||||
|
"""Test decompose_goal function service delegation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_external_service(self):
|
||||||
|
"""Test that decompose_goal calls the external service."""
|
||||||
|
expected_result = {"type": "instructions", "steps": ["Step 1"]}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core, "is_external_service_configured", return_value=True
|
||||||
|
), patch.object(
|
||||||
|
core, "decompose_goal_external", new_callable=AsyncMock
|
||||||
|
) as mock_external:
|
||||||
|
mock_external.return_value = expected_result
|
||||||
|
|
||||||
|
result = await core.decompose_goal("Build a chatbot")
|
||||||
|
|
||||||
|
mock_external.assert_called_once_with("Build a chatbot", "")
|
||||||
|
assert result == expected_result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_passes_context_to_external_service(self):
|
||||||
|
"""Test that decompose_goal passes context to external service."""
|
||||||
|
expected_result = {"type": "instructions", "steps": ["Step 1"]}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core, "is_external_service_configured", return_value=True
|
||||||
|
), patch.object(
|
||||||
|
core, "decompose_goal_external", new_callable=AsyncMock
|
||||||
|
) as mock_external:
|
||||||
|
mock_external.return_value = expected_result
|
||||||
|
|
||||||
|
await core.decompose_goal("Build a chatbot", "Use Python")
|
||||||
|
|
||||||
|
mock_external.assert_called_once_with("Build a chatbot", "Use Python")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_on_service_failure(self):
|
||||||
|
"""Test that decompose_goal returns None when external service fails."""
|
||||||
|
with patch.object(
|
||||||
|
core, "is_external_service_configured", return_value=True
|
||||||
|
), patch.object(
|
||||||
|
core, "decompose_goal_external", new_callable=AsyncMock
|
||||||
|
) as mock_external:
|
||||||
|
mock_external.return_value = None
|
||||||
|
|
||||||
|
result = await core.decompose_goal("Build a chatbot")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateAgent:
|
||||||
|
"""Test generate_agent function service delegation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_external_service(self):
|
||||||
|
"""Test that generate_agent calls the external service."""
|
||||||
|
expected_result = {"name": "Test Agent", "nodes": [], "links": []}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core, "is_external_service_configured", return_value=True
|
||||||
|
), patch.object(
|
||||||
|
core, "generate_agent_external", new_callable=AsyncMock
|
||||||
|
) as mock_external:
|
||||||
|
mock_external.return_value = expected_result
|
||||||
|
|
||||||
|
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||||
|
result = await core.generate_agent(instructions)
|
||||||
|
|
||||||
|
mock_external.assert_called_once_with(instructions)
|
||||||
|
# Result should have id, version, is_active added if not present
|
||||||
|
assert result is not None
|
||||||
|
assert result["name"] == "Test Agent"
|
||||||
|
assert "id" in result
|
||||||
|
assert result["version"] == 1
|
||||||
|
assert result["is_active"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preserves_existing_id_and_version(self):
|
||||||
|
"""Test that external service result preserves existing id and version."""
|
||||||
|
expected_result = {
|
||||||
|
"id": "existing-id",
|
||||||
|
"version": 3,
|
||||||
|
"is_active": False,
|
||||||
|
"name": "Test Agent",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core, "is_external_service_configured", return_value=True
|
||||||
|
), patch.object(
|
||||||
|
core, "generate_agent_external", new_callable=AsyncMock
|
||||||
|
) as mock_external:
|
||||||
|
mock_external.return_value = expected_result.copy()
|
||||||
|
|
||||||
|
result = await core.generate_agent({"steps": []})
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["id"] == "existing-id"
|
||||||
|
assert result["version"] == 3
|
||||||
|
assert result["is_active"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_external_service_fails(self):
|
||||||
|
"""Test that generate_agent returns None when external service fails."""
|
||||||
|
with patch.object(
|
||||||
|
core, "is_external_service_configured", return_value=True
|
||||||
|
), patch.object(
|
||||||
|
core, "generate_agent_external", new_callable=AsyncMock
|
||||||
|
) as mock_external:
|
||||||
|
mock_external.return_value = None
|
||||||
|
|
||||||
|
result = await core.generate_agent({"steps": []})
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateAgentPatch:
|
||||||
|
"""Test generate_agent_patch function service delegation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_external_service(self):
|
||||||
|
"""Test that generate_agent_patch calls the external service."""
|
||||||
|
expected_result = {"name": "Updated Agent", "nodes": [], "links": []}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core, "is_external_service_configured", return_value=True
|
||||||
|
), patch.object(
|
||||||
|
core, "generate_agent_patch_external", new_callable=AsyncMock
|
||||||
|
) as mock_external:
|
||||||
|
mock_external.return_value = expected_result
|
||||||
|
|
||||||
|
current_agent = {"nodes": [], "links": []}
|
||||||
|
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||||
|
|
||||||
|
mock_external.assert_called_once_with("Add a node", current_agent)
|
||||||
|
assert result == expected_result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_clarifying_questions(self):
|
||||||
|
"""Test that generate_agent_patch returns clarifying questions."""
|
||||||
|
expected_result = {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": [{"question": "What type of node?"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core, "is_external_service_configured", return_value=True
|
||||||
|
), patch.object(
|
||||||
|
core, "generate_agent_patch_external", new_callable=AsyncMock
|
||||||
|
) as mock_external:
|
||||||
|
mock_external.return_value = expected_result
|
||||||
|
|
||||||
|
result = await core.generate_agent_patch("Add a node", {"nodes": []})
|
||||||
|
|
||||||
|
assert result == expected_result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_external_service_fails(self):
|
||||||
|
"""Test that generate_agent_patch returns None when service fails."""
|
||||||
|
with patch.object(
|
||||||
|
core, "is_external_service_configured", return_value=True
|
||||||
|
), patch.object(
|
||||||
|
core, "generate_agent_patch_external", new_callable=AsyncMock
|
||||||
|
) as mock_external:
|
||||||
|
mock_external.return_value = None
|
||||||
|
|
||||||
|
result = await core.generate_agent_patch("Add a node", {"nodes": []})
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonToGraph:
|
||||||
|
"""Test json_to_graph function."""
|
||||||
|
|
||||||
|
def test_converts_agent_json_to_graph(self):
|
||||||
|
"""Test conversion of agent JSON to Graph model."""
|
||||||
|
agent_json = {
|
||||||
|
"id": "test-id",
|
||||||
|
"version": 2,
|
||||||
|
"is_active": True,
|
||||||
|
"name": "Test Agent",
|
||||||
|
"description": "A test agent",
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": "node1",
|
||||||
|
"block_id": "block1",
|
||||||
|
"input_default": {"key": "value"},
|
||||||
|
"metadata": {"x": 100},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"links": [
|
||||||
|
{
|
||||||
|
"id": "link1",
|
||||||
|
"source_id": "node1",
|
||||||
|
"sink_id": "output",
|
||||||
|
"source_name": "result",
|
||||||
|
"sink_name": "input",
|
||||||
|
"is_static": False,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = core.json_to_graph(agent_json)
|
||||||
|
|
||||||
|
assert graph.id == "test-id"
|
||||||
|
assert graph.version == 2
|
||||||
|
assert graph.is_active is True
|
||||||
|
assert graph.name == "Test Agent"
|
||||||
|
assert graph.description == "A test agent"
|
||||||
|
assert len(graph.nodes) == 1
|
||||||
|
assert graph.nodes[0].id == "node1"
|
||||||
|
assert graph.nodes[0].block_id == "block1"
|
||||||
|
assert len(graph.links) == 1
|
||||||
|
assert graph.links[0].source_id == "node1"
|
||||||
|
|
||||||
|
def test_generates_ids_if_missing(self):
|
||||||
|
"""Test that missing IDs are generated."""
|
||||||
|
agent_json = {
|
||||||
|
"name": "Test Agent",
|
||||||
|
"nodes": [{"block_id": "block1"}],
|
||||||
|
"links": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = core.json_to_graph(agent_json)
|
||||||
|
|
||||||
|
assert graph.id is not None
|
||||||
|
assert graph.nodes[0].id is not None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
422
autogpt_platform/backend/test/agent_generator/test_service.py
Normal file
422
autogpt_platform/backend/test/agent_generator/test_service.py
Normal file
@@ -0,0 +1,422 @@
|
|||||||
|
"""
|
||||||
|
Tests for the Agent Generator external service client.
|
||||||
|
|
||||||
|
This test suite verifies the external Agent Generator service integration,
|
||||||
|
including service detection, API calls, and error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.agent_generator import service
|
||||||
|
|
||||||
|
|
||||||
|
class TestServiceConfiguration:
|
||||||
|
"""Test service configuration detection."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset settings singleton before each test."""
|
||||||
|
service._settings = None
|
||||||
|
service._client = None
|
||||||
|
|
||||||
|
def test_external_service_not_configured_when_host_empty(self):
|
||||||
|
"""Test that external service is not configured when host is empty."""
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.config.agentgenerator_host = ""
|
||||||
|
|
||||||
|
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||||
|
assert service.is_external_service_configured() is False
|
||||||
|
|
||||||
|
def test_external_service_configured_when_host_set(self):
|
||||||
|
"""Test that external service is configured when host is set."""
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.config.agentgenerator_host = "agent-generator.local"
|
||||||
|
|
||||||
|
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||||
|
assert service.is_external_service_configured() is True
|
||||||
|
|
||||||
|
def test_get_base_url(self):
|
||||||
|
"""Test base URL construction."""
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.config.agentgenerator_host = "agent-generator.local"
|
||||||
|
mock_settings.config.agentgenerator_port = 8000
|
||||||
|
|
||||||
|
with patch.object(service, "_get_settings", return_value=mock_settings):
|
||||||
|
url = service._get_base_url()
|
||||||
|
assert url == "http://agent-generator.local:8000"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecomposeGoalExternal:
|
||||||
|
"""Test decompose_goal_external function."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset client singleton before each test."""
|
||||||
|
service._settings = None
|
||||||
|
service._client = None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_returns_instructions(self):
|
||||||
|
"""Test successful decomposition returning instructions."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": ["Step 1", "Step 2"],
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.decompose_goal_external("Build a chatbot")
|
||||||
|
|
||||||
|
assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]}
|
||||||
|
mock_client.post.assert_called_once_with(
|
||||||
|
"/api/decompose-description", json={"description": "Build a chatbot"}
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_returns_clarifying_questions(self):
|
||||||
|
"""Test decomposition returning clarifying questions."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": ["What platform?", "What language?"],
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.decompose_goal_external("Build something")
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": ["What platform?", "What language?"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_with_context(self):
|
||||||
|
"""Test decomposition with additional context."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": ["Step 1"],
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
await service.decompose_goal_external(
|
||||||
|
"Build a chatbot", context="Use Python"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.post.assert_called_once_with(
|
||||||
|
"/api/decompose-description",
|
||||||
|
json={"description": "Build a chatbot", "user_instruction": "Use Python"},
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_returns_unachievable_goal(self):
|
||||||
|
"""Test decomposition returning unachievable goal response."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"type": "unachievable_goal",
|
||||||
|
"reason": "Cannot do X",
|
||||||
|
"suggested_goal": "Try Y instead",
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.decompose_goal_external("Do something impossible")
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"type": "unachievable_goal",
|
||||||
|
"reason": "Cannot do X",
|
||||||
|
"suggested_goal": "Try Y instead",
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_handles_http_error(self):
|
||||||
|
"""Test decomposition handles HTTP errors gracefully."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = httpx.HTTPStatusError(
|
||||||
|
"Server error", request=MagicMock(), response=MagicMock()
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.decompose_goal_external("Build a chatbot")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_handles_request_error(self):
|
||||||
|
"""Test decomposition handles request errors gracefully."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.decompose_goal_external("Build a chatbot")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_handles_service_error(self):
|
||||||
|
"""Test decomposition handles service returning error."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": False,
|
||||||
|
"error": "Internal error",
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.decompose_goal_external("Build a chatbot")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateAgentExternal:
|
||||||
|
"""Test generate_agent_external function."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset client singleton before each test."""
|
||||||
|
service._settings = None
|
||||||
|
service._client = None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_agent_success(self):
|
||||||
|
"""Test successful agent generation."""
|
||||||
|
agent_json = {
|
||||||
|
"name": "Test Agent",
|
||||||
|
"nodes": [],
|
||||||
|
"links": [],
|
||||||
|
}
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"agent_json": agent_json,
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.generate_agent_external(instructions)
|
||||||
|
|
||||||
|
assert result == agent_json
|
||||||
|
mock_client.post.assert_called_once_with(
|
||||||
|
"/api/generate-agent", json={"instructions": instructions}
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_agent_handles_error(self):
|
||||||
|
"""Test agent generation handles errors gracefully."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = httpx.RequestError("Connection failed")
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.generate_agent_external({"steps": []})
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateAgentPatchExternal:
|
||||||
|
"""Test generate_agent_patch_external function."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset client singleton before each test."""
|
||||||
|
service._settings = None
|
||||||
|
service._client = None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_patch_returns_updated_agent(self):
|
||||||
|
"""Test successful patch generation returning updated agent."""
|
||||||
|
updated_agent = {
|
||||||
|
"name": "Updated Agent",
|
||||||
|
"nodes": [{"id": "1", "block_id": "test"}],
|
||||||
|
"links": [],
|
||||||
|
}
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"agent_json": updated_agent,
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
current_agent = {"name": "Old Agent", "nodes": [], "links": []}
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.generate_agent_patch_external(
|
||||||
|
"Add a new node", current_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == updated_agent
|
||||||
|
mock_client.post.assert_called_once_with(
|
||||||
|
"/api/update-agent",
|
||||||
|
json={
|
||||||
|
"update_request": "Add a new node",
|
||||||
|
"current_agent_json": current_agent,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_patch_returns_clarifying_questions(self):
|
||||||
|
"""Test patch generation returning clarifying questions."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": ["What type of node?"],
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.generate_agent_patch_external(
|
||||||
|
"Add something", {"nodes": []}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": ["What type of node?"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthCheck:
|
||||||
|
"""Test health_check function."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset singletons before each test."""
|
||||||
|
service._settings = None
|
||||||
|
service._client = None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_returns_false_when_not_configured(self):
|
||||||
|
"""Test health check returns False when service not configured."""
|
||||||
|
with patch.object(
|
||||||
|
service, "is_external_service_configured", return_value=False
|
||||||
|
):
|
||||||
|
result = await service.health_check()
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_returns_true_when_healthy(self):
|
||||||
|
"""Test health check returns True when service is healthy."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"status": "healthy",
|
||||||
|
"blocks_loaded": True,
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.health_check()
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_client.get.assert_called_once_with("/health")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_returns_false_when_not_healthy(self):
|
||||||
|
"""Test health check returns False when service is not healthy."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"status": "unhealthy",
|
||||||
|
"blocks_loaded": False,
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.health_check()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_returns_false_on_error(self):
|
||||||
|
"""Test health check returns False on connection error."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||||
|
|
||||||
|
with patch.object(service, "is_external_service_configured", return_value=True):
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.health_check()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetBlocksExternal:
|
||||||
|
"""Test get_blocks_external function."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset client singleton before each test."""
|
||||||
|
service._settings = None
|
||||||
|
service._client = None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_blocks_success(self):
|
||||||
|
"""Test successful blocks retrieval."""
|
||||||
|
blocks = [
|
||||||
|
{"id": "block1", "name": "Block 1"},
|
||||||
|
{"id": "block2", "name": "Block 2"},
|
||||||
|
]
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"blocks": blocks,
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.get_blocks_external()
|
||||||
|
|
||||||
|
assert result == blocks
|
||||||
|
mock_client.get.assert_called_once_with("/api/blocks")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_blocks_handles_error(self):
|
||||||
|
"""Test blocks retrieval handles errors gracefully."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get.side_effect = httpx.RequestError("Connection failed")
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
result = await service.get_blocks_external()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
@@ -175,8 +175,6 @@ While server components and actions are cool and cutting-edge, they introduce a
|
|||||||
|
|
||||||
- Prefer [React Query](https://tanstack.com/query/latest/docs/framework/react/overview) for server state, colocated near consumers (see [state colocation](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster))
|
- Prefer [React Query](https://tanstack.com/query/latest/docs/framework/react/overview) for server state, colocated near consumers (see [state colocation](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster))
|
||||||
- Co-locate UI state inside components/hooks; keep global state minimal
|
- Co-locate UI state inside components/hooks; keep global state minimal
|
||||||
- Avoid `useMemo` and `useCallback` unless you have a measured performance issue
|
|
||||||
- Do not abuse `useEffect`; prefer state colocation and derive values directly when possible
|
|
||||||
|
|
||||||
### Styling and components
|
### Styling and components
|
||||||
|
|
||||||
@@ -551,48 +549,9 @@ Files:
|
|||||||
Types:
|
Types:
|
||||||
|
|
||||||
- Prefer `interface` for object shapes
|
- Prefer `interface` for object shapes
|
||||||
- Component props should be `interface Props { ... }` (not exported)
|
- Component props should be `interface Props { ... }`
|
||||||
- Only use specific exported names (e.g., `export interface MyComponentProps`) when the interface needs to be used outside the component
|
|
||||||
- Keep type definitions inline with the component - do not create separate `types.ts` files unless types are shared across multiple files
|
|
||||||
- Use precise types; avoid `any` and unsafe casts
|
- Use precise types; avoid `any` and unsafe casts
|
||||||
|
|
||||||
**Props naming examples:**
|
|
||||||
|
|
||||||
```tsx
|
|
||||||
// ✅ Good - internal props, not exported
|
|
||||||
interface Props {
|
|
||||||
title: string;
|
|
||||||
onClose: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function Modal({ title, onClose }: Props) {
|
|
||||||
// ...
|
|
||||||
}
|
|
||||||
|
|
||||||
// ✅ Good - exported when needed externally
|
|
||||||
export interface ModalProps {
|
|
||||||
title: string;
|
|
||||||
onClose: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function Modal({ title, onClose }: ModalProps) {
|
|
||||||
// ...
|
|
||||||
}
|
|
||||||
|
|
||||||
// ❌ Bad - unnecessarily specific name for internal use
|
|
||||||
interface ModalComponentProps {
|
|
||||||
title: string;
|
|
||||||
onClose: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ❌ Bad - separate types.ts file for single component
|
|
||||||
// types.ts
|
|
||||||
export interface ModalProps { ... }
|
|
||||||
|
|
||||||
// Modal.tsx
|
|
||||||
import type { ModalProps } from './types';
|
|
||||||
```
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
|
|
||||||
- If more than one parameter is needed, pass a single `Args` object for clarity
|
- If more than one parameter is needed, pass a single `Args` object for clarity
|
||||||
|
|||||||
@@ -16,12 +16,6 @@ export default defineConfig({
|
|||||||
client: "react-query",
|
client: "react-query",
|
||||||
httpClient: "fetch",
|
httpClient: "fetch",
|
||||||
indexFiles: false,
|
indexFiles: false,
|
||||||
mock: {
|
|
||||||
type: "msw",
|
|
||||||
baseUrl: "http://localhost:3000/api/proxy",
|
|
||||||
generateEachHttpStatus: true,
|
|
||||||
delay: 0,
|
|
||||||
},
|
|
||||||
override: {
|
override: {
|
||||||
mutator: {
|
mutator: {
|
||||||
path: "./mutators/custom-mutator.ts",
|
path: "./mutators/custom-mutator.ts",
|
||||||
|
|||||||
@@ -15,8 +15,6 @@
|
|||||||
"types": "tsc --noEmit",
|
"types": "tsc --noEmit",
|
||||||
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
|
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
|
||||||
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
|
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
|
||||||
"test:unit": "vitest run",
|
|
||||||
"test:unit:watch": "vitest",
|
|
||||||
"test:no-build": "playwright test",
|
"test:no-build": "playwright test",
|
||||||
"gentests": "playwright codegen http://localhost:3000",
|
"gentests": "playwright codegen http://localhost:3000",
|
||||||
"storybook": "storybook dev -p 6006",
|
"storybook": "storybook dev -p 6006",
|
||||||
@@ -120,7 +118,6 @@
|
|||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@chromatic-com/storybook": "4.1.2",
|
"@chromatic-com/storybook": "4.1.2",
|
||||||
"happy-dom": "20.3.4",
|
|
||||||
"@opentelemetry/instrumentation": "0.209.0",
|
"@opentelemetry/instrumentation": "0.209.0",
|
||||||
"@playwright/test": "1.56.1",
|
"@playwright/test": "1.56.1",
|
||||||
"@storybook/addon-a11y": "9.1.5",
|
"@storybook/addon-a11y": "9.1.5",
|
||||||
@@ -130,8 +127,6 @@
|
|||||||
"@storybook/nextjs": "9.1.5",
|
"@storybook/nextjs": "9.1.5",
|
||||||
"@tanstack/eslint-plugin-query": "5.91.2",
|
"@tanstack/eslint-plugin-query": "5.91.2",
|
||||||
"@tanstack/react-query-devtools": "5.90.2",
|
"@tanstack/react-query-devtools": "5.90.2",
|
||||||
"@testing-library/dom": "10.4.1",
|
|
||||||
"@testing-library/react": "16.3.2",
|
|
||||||
"@types/canvas-confetti": "1.9.0",
|
"@types/canvas-confetti": "1.9.0",
|
||||||
"@types/lodash": "4.17.20",
|
"@types/lodash": "4.17.20",
|
||||||
"@types/negotiator": "0.6.4",
|
"@types/negotiator": "0.6.4",
|
||||||
@@ -140,7 +135,6 @@
|
|||||||
"@types/react-dom": "18.3.5",
|
"@types/react-dom": "18.3.5",
|
||||||
"@types/react-modal": "3.16.3",
|
"@types/react-modal": "3.16.3",
|
||||||
"@types/react-window": "1.8.8",
|
"@types/react-window": "1.8.8",
|
||||||
"@vitejs/plugin-react": "5.1.2",
|
|
||||||
"axe-playwright": "2.2.2",
|
"axe-playwright": "2.2.2",
|
||||||
"chromatic": "13.3.3",
|
"chromatic": "13.3.3",
|
||||||
"concurrently": "9.2.1",
|
"concurrently": "9.2.1",
|
||||||
@@ -159,9 +153,7 @@
|
|||||||
"require-in-the-middle": "8.0.1",
|
"require-in-the-middle": "8.0.1",
|
||||||
"storybook": "9.1.5",
|
"storybook": "9.1.5",
|
||||||
"tailwindcss": "3.4.17",
|
"tailwindcss": "3.4.17",
|
||||||
"typescript": "5.9.3",
|
"typescript": "5.9.3"
|
||||||
"vite-tsconfig-paths": "6.0.4",
|
|
||||||
"vitest": "4.0.17"
|
|
||||||
},
|
},
|
||||||
"msw": {
|
"msw": {
|
||||||
"workerDirectory": [
|
"workerDirectory": [
|
||||||
|
|||||||
1118
autogpt_platform/frontend/pnpm-lock.yaml
generated
1118
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -1,58 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
|
||||||
import { useRouter } from "next/navigation";
|
|
||||||
import { useEffect, useRef } from "react";
|
|
||||||
|
|
||||||
const LOGOUT_REDIRECT_DELAY_MS = 400;
|
|
||||||
|
|
||||||
function wait(ms: number): Promise<void> {
|
|
||||||
return new Promise(function resolveAfterDelay(resolve) {
|
|
||||||
setTimeout(resolve, ms);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
export default function LogoutPage() {
|
|
||||||
const { logOut } = useSupabase();
|
|
||||||
const { toast } = useToast();
|
|
||||||
const router = useRouter();
|
|
||||||
const hasStartedRef = useRef(false);
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function handleLogoutEffect() {
|
|
||||||
if (hasStartedRef.current) return;
|
|
||||||
hasStartedRef.current = true;
|
|
||||||
|
|
||||||
async function runLogout() {
|
|
||||||
try {
|
|
||||||
await logOut();
|
|
||||||
} catch {
|
|
||||||
toast({
|
|
||||||
title: "Failed to log out. Redirecting to login.",
|
|
||||||
variant: "destructive",
|
|
||||||
});
|
|
||||||
} finally {
|
|
||||||
await wait(LOGOUT_REDIRECT_DELAY_MS);
|
|
||||||
router.replace("/login");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void runLogout();
|
|
||||||
},
|
|
||||||
[logOut, router, toast],
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="flex min-h-screen items-center justify-center px-4">
|
|
||||||
<div className="flex flex-col items-center justify-center gap-4 py-8">
|
|
||||||
<LoadingSpinner size="large" />
|
|
||||||
<Text variant="body" className="text-center">
|
|
||||||
Logging you out...
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -9,7 +9,7 @@ export async function GET(request: Request) {
|
|||||||
const { searchParams, origin } = new URL(request.url);
|
const { searchParams, origin } = new URL(request.url);
|
||||||
const code = searchParams.get("code");
|
const code = searchParams.get("code");
|
||||||
|
|
||||||
let next = "/";
|
let next = "/marketplace";
|
||||||
|
|
||||||
if (code) {
|
if (code) {
|
||||||
const supabase = await getServerSupabase();
|
const supabase = await getServerSupabase();
|
||||||
|
|||||||
@@ -0,0 +1,134 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { List } from "@phosphor-icons/react";
|
||||||
|
import React, { useState } from "react";
|
||||||
|
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||||
|
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
|
||||||
|
import { ChatLoadingState } from "./components/ChatLoadingState/ChatLoadingState";
|
||||||
|
import { SessionsDrawer } from "./components/SessionsDrawer/SessionsDrawer";
|
||||||
|
import { useChat } from "./useChat";
|
||||||
|
|
||||||
|
export interface ChatProps {
|
||||||
|
className?: string;
|
||||||
|
headerTitle?: React.ReactNode;
|
||||||
|
showHeader?: boolean;
|
||||||
|
showSessionInfo?: boolean;
|
||||||
|
showNewChatButton?: boolean;
|
||||||
|
onNewChat?: () => void;
|
||||||
|
headerActions?: React.ReactNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function Chat({
|
||||||
|
className,
|
||||||
|
headerTitle = "AutoGPT Copilot",
|
||||||
|
showHeader = true,
|
||||||
|
showSessionInfo = true,
|
||||||
|
showNewChatButton = true,
|
||||||
|
onNewChat,
|
||||||
|
headerActions,
|
||||||
|
}: ChatProps) {
|
||||||
|
const {
|
||||||
|
messages,
|
||||||
|
isLoading,
|
||||||
|
isCreating,
|
||||||
|
error,
|
||||||
|
sessionId,
|
||||||
|
createSession,
|
||||||
|
clearSession,
|
||||||
|
loadSession,
|
||||||
|
} = useChat();
|
||||||
|
|
||||||
|
const [isSessionsDrawerOpen, setIsSessionsDrawerOpen] = useState(false);
|
||||||
|
|
||||||
|
const handleNewChat = () => {
|
||||||
|
clearSession();
|
||||||
|
onNewChat?.();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSelectSession = async (sessionId: string) => {
|
||||||
|
try {
|
||||||
|
await loadSession(sessionId);
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Failed to load session:", err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn("flex h-full flex-col", className)}>
|
||||||
|
{/* Header */}
|
||||||
|
{showHeader && (
|
||||||
|
<header className="shrink-0 border-t border-zinc-200 bg-white p-3">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<div className="flex items-center gap-3">
|
||||||
|
<button
|
||||||
|
aria-label="View sessions"
|
||||||
|
onClick={() => setIsSessionsDrawerOpen(true)}
|
||||||
|
className="flex size-8 items-center justify-center rounded hover:bg-zinc-100"
|
||||||
|
>
|
||||||
|
<List width="1.25rem" height="1.25rem" />
|
||||||
|
</button>
|
||||||
|
{typeof headerTitle === "string" ? (
|
||||||
|
<Text variant="h2" className="text-lg font-semibold">
|
||||||
|
{headerTitle}
|
||||||
|
</Text>
|
||||||
|
) : (
|
||||||
|
headerTitle
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-3">
|
||||||
|
{showSessionInfo && sessionId && (
|
||||||
|
<>
|
||||||
|
{showNewChatButton && (
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
size="small"
|
||||||
|
onClick={handleNewChat}
|
||||||
|
>
|
||||||
|
New Chat
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{headerActions}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</header>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Main Content */}
|
||||||
|
<main className="flex min-h-0 flex-1 flex-col overflow-hidden">
|
||||||
|
{/* Loading State - show when explicitly loading/creating OR when we don't have a session yet and no error */}
|
||||||
|
{(isLoading || isCreating || (!sessionId && !error)) && (
|
||||||
|
<ChatLoadingState
|
||||||
|
message={isCreating ? "Creating session..." : "Loading..."}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Error State */}
|
||||||
|
{error && !isLoading && (
|
||||||
|
<ChatErrorState error={error} onRetry={createSession} />
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Session Content */}
|
||||||
|
{sessionId && !isLoading && !error && (
|
||||||
|
<ChatContainer
|
||||||
|
sessionId={sessionId}
|
||||||
|
initialMessages={messages}
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</main>
|
||||||
|
|
||||||
|
{/* Sessions Drawer */}
|
||||||
|
<SessionsDrawer
|
||||||
|
isOpen={isSessionsDrawerOpen}
|
||||||
|
onClose={() => setIsSessionsDrawerOpen(false)}
|
||||||
|
onSelectSession={handleSelectSession}
|
||||||
|
currentSessionId={sessionId}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -21,7 +21,7 @@ export function AuthPromptWidget({
|
|||||||
message,
|
message,
|
||||||
sessionId,
|
sessionId,
|
||||||
agentInfo,
|
agentInfo,
|
||||||
returnUrl = "/copilot/chat",
|
returnUrl = "/chat",
|
||||||
className,
|
className,
|
||||||
}: AuthPromptWidgetProps) {
|
}: AuthPromptWidgetProps) {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { useCallback } from "react";
|
||||||
|
import { usePageContext } from "../../usePageContext";
|
||||||
|
import { ChatInput } from "../ChatInput/ChatInput";
|
||||||
|
import { MessageList } from "../MessageList/MessageList";
|
||||||
|
import { QuickActionsWelcome } from "../QuickActionsWelcome/QuickActionsWelcome";
|
||||||
|
import { useChatContainer } from "./useChatContainer";
|
||||||
|
|
||||||
|
export interface ChatContainerProps {
|
||||||
|
sessionId: string | null;
|
||||||
|
initialMessages: SessionDetailResponse["messages"];
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ChatContainer({
|
||||||
|
sessionId,
|
||||||
|
initialMessages,
|
||||||
|
className,
|
||||||
|
}: ChatContainerProps) {
|
||||||
|
const { messages, streamingChunks, isStreaming, sendMessage } =
|
||||||
|
useChatContainer({
|
||||||
|
sessionId,
|
||||||
|
initialMessages,
|
||||||
|
});
|
||||||
|
const { capturePageContext } = usePageContext();
|
||||||
|
|
||||||
|
// Wrap sendMessage to automatically capture page context
|
||||||
|
const sendMessageWithContext = useCallback(
|
||||||
|
async (content: string, isUserMessage: boolean = true) => {
|
||||||
|
const context = capturePageContext();
|
||||||
|
await sendMessage(content, isUserMessage, context);
|
||||||
|
},
|
||||||
|
[sendMessage, capturePageContext],
|
||||||
|
);
|
||||||
|
|
||||||
|
const quickActions = [
|
||||||
|
"Find agents for social media management",
|
||||||
|
"Show me agents for content creation",
|
||||||
|
"Help me automate my business",
|
||||||
|
"What can you help me with?",
|
||||||
|
];
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn("flex h-full min-h-0 flex-col", className)}
|
||||||
|
style={{
|
||||||
|
backgroundColor: "#ffffff",
|
||||||
|
backgroundImage:
|
||||||
|
"radial-gradient(#e5e5e5 0.5px, transparent 0.5px), radial-gradient(#e5e5e5 0.5px, #ffffff 0.5px)",
|
||||||
|
backgroundSize: "20px 20px",
|
||||||
|
backgroundPosition: "0 0, 10px 10px",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{/* Messages or Welcome Screen */}
|
||||||
|
<div className="flex min-h-0 flex-1 flex-col overflow-hidden pb-24">
|
||||||
|
{messages.length === 0 ? (
|
||||||
|
<QuickActionsWelcome
|
||||||
|
title="Welcome to AutoGPT Copilot"
|
||||||
|
description="Start a conversation to discover and run AI agents."
|
||||||
|
actions={quickActions}
|
||||||
|
onActionClick={sendMessageWithContext}
|
||||||
|
disabled={isStreaming || !sessionId}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<MessageList
|
||||||
|
messages={messages}
|
||||||
|
streamingChunks={streamingChunks}
|
||||||
|
isStreaming={isStreaming}
|
||||||
|
onSendMessage={sendMessageWithContext}
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Input - Always visible */}
|
||||||
|
<div className="fixed bottom-0 left-0 right-0 z-50 border-t border-zinc-200 bg-white p-4">
|
||||||
|
<ChatInput
|
||||||
|
onSend={sendMessageWithContext}
|
||||||
|
disabled={isStreaming || !sessionId}
|
||||||
|
placeholder={
|
||||||
|
sessionId ? "Type your message..." : "Creating session..."
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { StreamChunk } from "../../useChatStream";
|
import { StreamChunk } from "../../useChatStream";
|
||||||
import type { HandlerDependencies } from "./handlers";
|
import type { HandlerDependencies } from "./useChatContainer.handlers";
|
||||||
import {
|
import {
|
||||||
handleError,
|
handleError,
|
||||||
handleLoginNeeded,
|
handleLoginNeeded,
|
||||||
@@ -9,30 +9,12 @@ import {
|
|||||||
handleTextEnded,
|
handleTextEnded,
|
||||||
handleToolCallStart,
|
handleToolCallStart,
|
||||||
handleToolResponse,
|
handleToolResponse,
|
||||||
isRegionBlockedError,
|
} from "./useChatContainer.handlers";
|
||||||
} from "./handlers";
|
|
||||||
|
|
||||||
export function createStreamEventDispatcher(
|
export function createStreamEventDispatcher(
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
): (chunk: StreamChunk) => void {
|
): (chunk: StreamChunk) => void {
|
||||||
return function dispatchStreamEvent(chunk: StreamChunk): void {
|
return function dispatchStreamEvent(chunk: StreamChunk): void {
|
||||||
if (
|
|
||||||
chunk.type === "text_chunk" ||
|
|
||||||
chunk.type === "tool_call_start" ||
|
|
||||||
chunk.type === "tool_response" ||
|
|
||||||
chunk.type === "login_needed" ||
|
|
||||||
chunk.type === "need_login" ||
|
|
||||||
chunk.type === "error"
|
|
||||||
) {
|
|
||||||
if (!deps.hasResponseRef.current) {
|
|
||||||
console.info("[ChatStream] First response chunk:", {
|
|
||||||
type: chunk.type,
|
|
||||||
sessionId: deps.sessionId,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
deps.hasResponseRef.current = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (chunk.type) {
|
switch (chunk.type) {
|
||||||
case "text_chunk":
|
case "text_chunk":
|
||||||
handleTextChunk(chunk, deps);
|
handleTextChunk(chunk, deps);
|
||||||
@@ -56,23 +38,15 @@ export function createStreamEventDispatcher(
|
|||||||
break;
|
break;
|
||||||
|
|
||||||
case "stream_end":
|
case "stream_end":
|
||||||
console.info("[ChatStream] Stream ended:", {
|
|
||||||
sessionId: deps.sessionId,
|
|
||||||
hasResponse: deps.hasResponseRef.current,
|
|
||||||
chunkCount: deps.streamingChunksRef.current.length,
|
|
||||||
});
|
|
||||||
handleStreamEnd(chunk, deps);
|
handleStreamEnd(chunk, deps);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "error":
|
case "error":
|
||||||
const isRegionBlocked = isRegionBlockedError(chunk);
|
|
||||||
handleError(chunk, deps);
|
handleError(chunk, deps);
|
||||||
// Show toast at dispatcher level to avoid circular dependencies
|
// Show toast at dispatcher level to avoid circular dependencies
|
||||||
if (!isRegionBlocked) {
|
toast.error("Chat Error", {
|
||||||
toast.error("Chat Error", {
|
description: chunk.message || chunk.content || "An error occurred",
|
||||||
description: chunk.message || chunk.content || "An error occurred",
|
});
|
||||||
});
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "usage":
|
case "usage":
|
||||||
@@ -1,33 +1,6 @@
|
|||||||
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
|
||||||
import type { ToolResult } from "@/types/chat";
|
import type { ToolResult } from "@/types/chat";
|
||||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||||
|
|
||||||
export function hasSentInitialPrompt(sessionId: string): boolean {
|
|
||||||
try {
|
|
||||||
const sent = JSON.parse(
|
|
||||||
sessionStorage.get(SessionKey.CHAT_SENT_INITIAL_PROMPTS) || "{}",
|
|
||||||
);
|
|
||||||
return sent[sessionId] === true;
|
|
||||||
} catch {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export function markInitialPromptSent(sessionId: string): void {
|
|
||||||
try {
|
|
||||||
const sent = JSON.parse(
|
|
||||||
sessionStorage.get(SessionKey.CHAT_SENT_INITIAL_PROMPTS) || "{}",
|
|
||||||
);
|
|
||||||
sent[sessionId] = true;
|
|
||||||
sessionStorage.set(
|
|
||||||
SessionKey.CHAT_SENT_INITIAL_PROMPTS,
|
|
||||||
JSON.stringify(sent),
|
|
||||||
);
|
|
||||||
} catch {
|
|
||||||
// Ignore storage errors
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export function removePageContext(content: string): string {
|
export function removePageContext(content: string): string {
|
||||||
// Remove "Page URL: ..." pattern at start of line (case insensitive, handles various formats)
|
// Remove "Page URL: ..." pattern at start of line (case insensitive, handles various formats)
|
||||||
let cleaned = content.replace(/^\s*Page URL:\s*[^\n\r]*/gim, "");
|
let cleaned = content.replace(/^\s*Page URL:\s*[^\n\r]*/gim, "");
|
||||||
@@ -234,22 +207,12 @@ export function parseToolResponse(
|
|||||||
if (responseType === "setup_requirements") {
|
if (responseType === "setup_requirements") {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
if (responseType === "understanding_updated") {
|
|
||||||
return {
|
|
||||||
type: "tool_response",
|
|
||||||
toolId,
|
|
||||||
toolName,
|
|
||||||
result: (parsedResult || result) as ToolResult,
|
|
||||||
success: true,
|
|
||||||
timestamp: timestamp || new Date(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
type: "tool_response",
|
type: "tool_response",
|
||||||
toolId,
|
toolId,
|
||||||
toolName,
|
toolName,
|
||||||
result: parsedResult ? (parsedResult as ToolResult) : result,
|
result,
|
||||||
success: true,
|
success: true,
|
||||||
timestamp: timestamp || new Date(),
|
timestamp: timestamp || new Date(),
|
||||||
};
|
};
|
||||||
@@ -7,30 +7,15 @@ import {
|
|||||||
parseToolResponse,
|
parseToolResponse,
|
||||||
} from "./helpers";
|
} from "./helpers";
|
||||||
|
|
||||||
function isToolCallMessage(
|
|
||||||
message: ChatMessageData,
|
|
||||||
): message is Extract<ChatMessageData, { type: "tool_call" }> {
|
|
||||||
return message.type === "tool_call";
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface HandlerDependencies {
|
export interface HandlerDependencies {
|
||||||
setHasTextChunks: Dispatch<SetStateAction<boolean>>;
|
setHasTextChunks: Dispatch<SetStateAction<boolean>>;
|
||||||
setStreamingChunks: Dispatch<SetStateAction<string[]>>;
|
setStreamingChunks: Dispatch<SetStateAction<string[]>>;
|
||||||
streamingChunksRef: MutableRefObject<string[]>;
|
streamingChunksRef: MutableRefObject<string[]>;
|
||||||
hasResponseRef: MutableRefObject<boolean>;
|
|
||||||
setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
|
setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
|
||||||
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
||||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
|
||||||
sessionId: string;
|
sessionId: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
|
||||||
if (chunk.code === "MODEL_NOT_AVAILABLE_REGION") return true;
|
|
||||||
const message = chunk.message || chunk.content;
|
|
||||||
if (typeof message !== "string") return false;
|
|
||||||
return message.toLowerCase().includes("not available in your region");
|
|
||||||
}
|
|
||||||
|
|
||||||
export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
|
export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||||
if (!chunk.content) return;
|
if (!chunk.content) return;
|
||||||
deps.setHasTextChunks(true);
|
deps.setHasTextChunks(true);
|
||||||
@@ -45,17 +30,16 @@ export function handleTextEnded(
|
|||||||
_chunk: StreamChunk,
|
_chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
|
console.log("[Text Ended] Saving streamed text as assistant message");
|
||||||
const completedText = deps.streamingChunksRef.current.join("");
|
const completedText = deps.streamingChunksRef.current.join("");
|
||||||
if (completedText.trim()) {
|
if (completedText.trim()) {
|
||||||
deps.setMessages((prev) => {
|
const assistantMessage: ChatMessageData = {
|
||||||
const assistantMessage: ChatMessageData = {
|
type: "message",
|
||||||
type: "message",
|
role: "assistant",
|
||||||
role: "assistant",
|
content: completedText,
|
||||||
content: completedText,
|
timestamp: new Date(),
|
||||||
timestamp: new Date(),
|
};
|
||||||
};
|
deps.setMessages((prev) => [...prev, assistantMessage]);
|
||||||
return [...prev, assistantMessage];
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
deps.setStreamingChunks([]);
|
deps.setStreamingChunks([]);
|
||||||
deps.streamingChunksRef.current = [];
|
deps.streamingChunksRef.current = [];
|
||||||
@@ -66,45 +50,30 @@ export function handleToolCallStart(
|
|||||||
chunk: StreamChunk,
|
chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
|
const toolCallMessage: ChatMessageData = {
|
||||||
type: "tool_call",
|
type: "tool_call",
|
||||||
toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
|
toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
|
||||||
toolName: chunk.tool_name || "Executing",
|
toolName: chunk.tool_name || "Executing...",
|
||||||
arguments: chunk.arguments || {},
|
arguments: chunk.arguments || {},
|
||||||
timestamp: new Date(),
|
timestamp: new Date(),
|
||||||
};
|
};
|
||||||
|
deps.setMessages((prev) => [...prev, toolCallMessage]);
|
||||||
function updateToolCallMessages(prev: ChatMessageData[]) {
|
console.log("[Tool Call Start]", {
|
||||||
const existingIndex = prev.findIndex(function findToolCallIndex(msg) {
|
toolId: toolCallMessage.toolId,
|
||||||
return isToolCallMessage(msg) && msg.toolId === toolCallMessage.toolId;
|
toolName: toolCallMessage.toolName,
|
||||||
});
|
timestamp: new Date().toISOString(),
|
||||||
if (existingIndex === -1) {
|
});
|
||||||
return [...prev, toolCallMessage];
|
|
||||||
}
|
|
||||||
const nextMessages = [...prev];
|
|
||||||
const existing = nextMessages[existingIndex];
|
|
||||||
if (!isToolCallMessage(existing)) return prev;
|
|
||||||
const nextArguments =
|
|
||||||
toolCallMessage.arguments &&
|
|
||||||
Object.keys(toolCallMessage.arguments).length > 0
|
|
||||||
? toolCallMessage.arguments
|
|
||||||
: existing.arguments;
|
|
||||||
nextMessages[existingIndex] = {
|
|
||||||
...existing,
|
|
||||||
toolName: toolCallMessage.toolName || existing.toolName,
|
|
||||||
arguments: nextArguments,
|
|
||||||
timestamp: toolCallMessage.timestamp,
|
|
||||||
};
|
|
||||||
return nextMessages;
|
|
||||||
}
|
|
||||||
|
|
||||||
deps.setMessages(updateToolCallMessages);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function handleToolResponse(
|
export function handleToolResponse(
|
||||||
chunk: StreamChunk,
|
chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
|
console.log("[Tool Response] Received:", {
|
||||||
|
toolId: chunk.tool_id,
|
||||||
|
toolName: chunk.tool_name,
|
||||||
|
timestamp: new Date().toISOString(),
|
||||||
|
});
|
||||||
let toolName = chunk.tool_name || "unknown";
|
let toolName = chunk.tool_name || "unknown";
|
||||||
if (!chunk.tool_name || chunk.tool_name === "unknown") {
|
if (!chunk.tool_name || chunk.tool_name === "unknown") {
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
@@ -158,15 +127,22 @@ export function handleToolResponse(
|
|||||||
const toolCallIndex = prev.findIndex(
|
const toolCallIndex = prev.findIndex(
|
||||||
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
||||||
);
|
);
|
||||||
const hasResponse = prev.some(
|
|
||||||
(msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id,
|
|
||||||
);
|
|
||||||
if (hasResponse) return prev;
|
|
||||||
if (toolCallIndex !== -1) {
|
if (toolCallIndex !== -1) {
|
||||||
const newMessages = [...prev];
|
const newMessages = [...prev];
|
||||||
newMessages.splice(toolCallIndex + 1, 0, responseMessage);
|
newMessages[toolCallIndex] = responseMessage;
|
||||||
|
console.log(
|
||||||
|
"[Tool Response] Replaced tool_call with matching tool_id:",
|
||||||
|
chunk.tool_id,
|
||||||
|
"at index:",
|
||||||
|
toolCallIndex,
|
||||||
|
);
|
||||||
return newMessages;
|
return newMessages;
|
||||||
}
|
}
|
||||||
|
console.warn(
|
||||||
|
"[Tool Response] No tool_call found with tool_id:",
|
||||||
|
chunk.tool_id,
|
||||||
|
"appending instead",
|
||||||
|
);
|
||||||
return [...prev, responseMessage];
|
return [...prev, responseMessage];
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -191,38 +167,55 @@ export function handleStreamEnd(
|
|||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
const completedContent = deps.streamingChunksRef.current.join("");
|
const completedContent = deps.streamingChunksRef.current.join("");
|
||||||
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
// Only save message if there are uncommitted chunks
|
||||||
deps.setMessages((prev) => [
|
// (text_ended already saved if there were tool calls)
|
||||||
...prev,
|
|
||||||
{
|
|
||||||
type: "message",
|
|
||||||
role: "assistant",
|
|
||||||
content: "No response received. Please try again.",
|
|
||||||
timestamp: new Date(),
|
|
||||||
},
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
if (completedContent.trim()) {
|
if (completedContent.trim()) {
|
||||||
|
console.log(
|
||||||
|
"[Stream End] Saving remaining streamed text as assistant message",
|
||||||
|
);
|
||||||
const assistantMessage: ChatMessageData = {
|
const assistantMessage: ChatMessageData = {
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: completedContent,
|
content: completedContent,
|
||||||
timestamp: new Date(),
|
timestamp: new Date(),
|
||||||
};
|
};
|
||||||
deps.setMessages((prev) => [...prev, assistantMessage]);
|
deps.setMessages((prev) => {
|
||||||
|
const updated = [...prev, assistantMessage];
|
||||||
|
console.log("[Stream End] Final state:", {
|
||||||
|
localMessages: updated.map((m) => ({
|
||||||
|
type: m.type,
|
||||||
|
...(m.type === "message" && {
|
||||||
|
role: m.role,
|
||||||
|
contentLength: m.content.length,
|
||||||
|
}),
|
||||||
|
...(m.type === "tool_call" && {
|
||||||
|
toolId: m.toolId,
|
||||||
|
toolName: m.toolName,
|
||||||
|
}),
|
||||||
|
...(m.type === "tool_response" && {
|
||||||
|
toolId: m.toolId,
|
||||||
|
toolName: m.toolName,
|
||||||
|
success: m.success,
|
||||||
|
}),
|
||||||
|
})),
|
||||||
|
streamingChunks: deps.streamingChunksRef.current,
|
||||||
|
timestamp: new Date().toISOString(),
|
||||||
|
});
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
console.log("[Stream End] No uncommitted chunks, message already saved");
|
||||||
}
|
}
|
||||||
deps.setStreamingChunks([]);
|
deps.setStreamingChunks([]);
|
||||||
deps.streamingChunksRef.current = [];
|
deps.streamingChunksRef.current = [];
|
||||||
deps.setHasTextChunks(false);
|
deps.setHasTextChunks(false);
|
||||||
deps.setIsStreamingInitiated(false);
|
deps.setIsStreamingInitiated(false);
|
||||||
|
console.log("[Stream End] Stream complete, messages in local state");
|
||||||
}
|
}
|
||||||
|
|
||||||
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||||
const errorMessage = chunk.message || chunk.content || "An error occurred";
|
const errorMessage = chunk.message || chunk.content || "An error occurred";
|
||||||
console.error("Stream error:", errorMessage);
|
console.error("Stream error:", errorMessage);
|
||||||
if (isRegionBlockedError(chunk)) {
|
|
||||||
deps.setIsRegionBlockedModalOpen(true);
|
|
||||||
}
|
|
||||||
deps.setIsStreamingInitiated(false);
|
deps.setIsStreamingInitiated(false);
|
||||||
deps.setHasTextChunks(false);
|
deps.setHasTextChunks(false);
|
||||||
deps.setStreamingChunks([]);
|
deps.setStreamingChunks([]);
|
||||||
@@ -1,17 +1,14 @@
|
|||||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
import { useCallback, useMemo, useRef, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useChatStream } from "../../useChatStream";
|
import { useChatStream } from "../../useChatStream";
|
||||||
import { usePageContext } from "../../usePageContext";
|
|
||||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||||
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
||||||
import {
|
import {
|
||||||
createUserMessage,
|
createUserMessage,
|
||||||
filterAuthMessages,
|
filterAuthMessages,
|
||||||
hasSentInitialPrompt,
|
|
||||||
isToolCallArray,
|
isToolCallArray,
|
||||||
isValidMessage,
|
isValidMessage,
|
||||||
markInitialPromptSent,
|
|
||||||
parseToolResponse,
|
parseToolResponse,
|
||||||
removePageContext,
|
removePageContext,
|
||||||
} from "./helpers";
|
} from "./helpers";
|
||||||
@@ -19,45 +16,20 @@ import {
|
|||||||
interface Args {
|
interface Args {
|
||||||
sessionId: string | null;
|
sessionId: string | null;
|
||||||
initialMessages: SessionDetailResponse["messages"];
|
initialMessages: SessionDetailResponse["messages"];
|
||||||
initialPrompt?: string;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useChatContainer({
|
export function useChatContainer({ sessionId, initialMessages }: Args) {
|
||||||
sessionId,
|
|
||||||
initialMessages,
|
|
||||||
initialPrompt,
|
|
||||||
}: Args) {
|
|
||||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||||
const [hasTextChunks, setHasTextChunks] = useState(false);
|
const [hasTextChunks, setHasTextChunks] = useState(false);
|
||||||
const [isStreamingInitiated, setIsStreamingInitiated] = useState(false);
|
const [isStreamingInitiated, setIsStreamingInitiated] = useState(false);
|
||||||
const [isRegionBlockedModalOpen, setIsRegionBlockedModalOpen] =
|
|
||||||
useState(false);
|
|
||||||
const hasResponseRef = useRef(false);
|
|
||||||
const streamingChunksRef = useRef<string[]>([]);
|
const streamingChunksRef = useRef<string[]>([]);
|
||||||
const previousSessionIdRef = useRef<string | null>(null);
|
const { error, sendMessage: sendStreamMessage } = useChatStream();
|
||||||
const {
|
|
||||||
error,
|
|
||||||
sendMessage: sendStreamMessage,
|
|
||||||
stopStreaming,
|
|
||||||
} = useChatStream();
|
|
||||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (sessionId !== previousSessionIdRef.current) {
|
|
||||||
stopStreaming(previousSessionIdRef.current ?? undefined, true);
|
|
||||||
previousSessionIdRef.current = sessionId;
|
|
||||||
setMessages([]);
|
|
||||||
setStreamingChunks([]);
|
|
||||||
streamingChunksRef.current = [];
|
|
||||||
setHasTextChunks(false);
|
|
||||||
setIsStreamingInitiated(false);
|
|
||||||
hasResponseRef.current = false;
|
|
||||||
}
|
|
||||||
}, [sessionId, stopStreaming]);
|
|
||||||
|
|
||||||
const allMessages = useMemo(() => {
|
const allMessages = useMemo(() => {
|
||||||
const processedInitialMessages: ChatMessageData[] = [];
|
const processedInitialMessages: ChatMessageData[] = [];
|
||||||
|
// Map to track tool calls by their ID so we can look up tool names for tool responses
|
||||||
const toolCallMap = new Map<string, string>();
|
const toolCallMap = new Map<string, string>();
|
||||||
|
|
||||||
for (const msg of initialMessages) {
|
for (const msg of initialMessages) {
|
||||||
@@ -73,9 +45,13 @@ export function useChatContainer({
|
|||||||
? new Date(msg.timestamp as string)
|
? new Date(msg.timestamp as string)
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
|
// Remove page context from user messages when loading existing sessions
|
||||||
if (role === "user") {
|
if (role === "user") {
|
||||||
content = removePageContext(content);
|
content = removePageContext(content);
|
||||||
if (!content.trim()) continue;
|
// Skip user messages that become empty after removing page context
|
||||||
|
if (!content.trim()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
processedInitialMessages.push({
|
processedInitialMessages.push({
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "user",
|
role: "user",
|
||||||
@@ -85,15 +61,19 @@ export function useChatContainer({
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle assistant messages first (before tool messages) to build tool call map
|
||||||
if (role === "assistant") {
|
if (role === "assistant") {
|
||||||
|
// Strip <thinking> tags from content
|
||||||
content = content
|
content = content
|
||||||
.replace(/<thinking>[\s\S]*?<\/thinking>/gi, "")
|
.replace(/<thinking>[\s\S]*?<\/thinking>/gi, "")
|
||||||
.trim();
|
.trim();
|
||||||
|
|
||||||
|
// If assistant has tool calls, create tool_call messages for each
|
||||||
if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) {
|
if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) {
|
||||||
for (const toolCall of toolCalls) {
|
for (const toolCall of toolCalls) {
|
||||||
const toolName = toolCall.function.name;
|
const toolName = toolCall.function.name;
|
||||||
const toolId = toolCall.id;
|
const toolId = toolCall.id;
|
||||||
|
// Store tool name for later lookup
|
||||||
toolCallMap.set(toolId, toolName);
|
toolCallMap.set(toolId, toolName);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -116,6 +96,7 @@ export function useChatContainer({
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Only add assistant message if there's content after stripping thinking tags
|
||||||
if (content.trim()) {
|
if (content.trim()) {
|
||||||
processedInitialMessages.push({
|
processedInitialMessages.push({
|
||||||
type: "message",
|
type: "message",
|
||||||
@@ -125,6 +106,7 @@ export function useChatContainer({
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else if (content.trim()) {
|
} else if (content.trim()) {
|
||||||
|
// Assistant message without tool calls, but with content
|
||||||
processedInitialMessages.push({
|
processedInitialMessages.push({
|
||||||
type: "message",
|
type: "message",
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
@@ -135,6 +117,7 @@ export function useChatContainer({
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle tool messages - look up tool name from tool call map
|
||||||
if (role === "tool") {
|
if (role === "tool") {
|
||||||
const toolCallId = (msg.tool_call_id as string) || "";
|
const toolCallId = (msg.tool_call_id as string) || "";
|
||||||
const toolName = toolCallMap.get(toolCallId) || "unknown";
|
const toolName = toolCallMap.get(toolCallId) || "unknown";
|
||||||
@@ -150,6 +133,7 @@ export function useChatContainer({
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle other message types (system, etc.)
|
||||||
if (content.trim()) {
|
if (content.trim()) {
|
||||||
processedInitialMessages.push({
|
processedInitialMessages.push({
|
||||||
type: "message",
|
type: "message",
|
||||||
@@ -170,10 +154,9 @@ export function useChatContainer({
|
|||||||
context?: { url: string; content: string },
|
context?: { url: string; content: string },
|
||||||
) {
|
) {
|
||||||
if (!sessionId) {
|
if (!sessionId) {
|
||||||
console.error("[useChatContainer] Cannot send message: no session ID");
|
console.error("Cannot send message: no session ID");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
setIsRegionBlockedModalOpen(false);
|
|
||||||
if (isUserMessage) {
|
if (isUserMessage) {
|
||||||
const userMessage = createUserMessage(content);
|
const userMessage = createUserMessage(content);
|
||||||
setMessages((prev) => [...filterAuthMessages(prev), userMessage]);
|
setMessages((prev) => [...filterAuthMessages(prev), userMessage]);
|
||||||
@@ -184,19 +167,14 @@ export function useChatContainer({
|
|||||||
streamingChunksRef.current = [];
|
streamingChunksRef.current = [];
|
||||||
setHasTextChunks(false);
|
setHasTextChunks(false);
|
||||||
setIsStreamingInitiated(true);
|
setIsStreamingInitiated(true);
|
||||||
hasResponseRef.current = false;
|
|
||||||
|
|
||||||
const dispatcher = createStreamEventDispatcher({
|
const dispatcher = createStreamEventDispatcher({
|
||||||
setHasTextChunks,
|
setHasTextChunks,
|
||||||
setStreamingChunks,
|
setStreamingChunks,
|
||||||
streamingChunksRef,
|
streamingChunksRef,
|
||||||
hasResponseRef,
|
|
||||||
setMessages,
|
setMessages,
|
||||||
setIsRegionBlockedModalOpen,
|
|
||||||
sessionId,
|
sessionId,
|
||||||
setIsStreamingInitiated,
|
setIsStreamingInitiated,
|
||||||
});
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await sendStreamMessage(
|
await sendStreamMessage(
|
||||||
sessionId,
|
sessionId,
|
||||||
@@ -206,12 +184,8 @@ export function useChatContainer({
|
|||||||
context,
|
context,
|
||||||
);
|
);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error("[useChatContainer] Failed to send message:", err);
|
console.error("Failed to send message:", err);
|
||||||
setIsStreamingInitiated(false);
|
setIsStreamingInitiated(false);
|
||||||
|
|
||||||
// Don't show error toast for AbortError (expected during cleanup)
|
|
||||||
if (err instanceof Error && err.name === "AbortError") return;
|
|
||||||
|
|
||||||
const errorMessage =
|
const errorMessage =
|
||||||
err instanceof Error ? err.message : "Failed to send message";
|
err instanceof Error ? err.message : "Failed to send message";
|
||||||
toast.error("Failed to send message", {
|
toast.error("Failed to send message", {
|
||||||
@@ -222,63 +196,11 @@ export function useChatContainer({
|
|||||||
[sessionId, sendStreamMessage],
|
[sessionId, sendStreamMessage],
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleStopStreaming = useCallback(() => {
|
|
||||||
stopStreaming();
|
|
||||||
setStreamingChunks([]);
|
|
||||||
streamingChunksRef.current = [];
|
|
||||||
setHasTextChunks(false);
|
|
||||||
setIsStreamingInitiated(false);
|
|
||||||
}, [stopStreaming]);
|
|
||||||
|
|
||||||
const { capturePageContext } = usePageContext();
|
|
||||||
|
|
||||||
// Send initial prompt if provided (for new sessions from homepage)
|
|
||||||
useEffect(
|
|
||||||
function handleInitialPrompt() {
|
|
||||||
if (!initialPrompt || !sessionId) return;
|
|
||||||
if (initialMessages.length > 0) return;
|
|
||||||
if (hasSentInitialPrompt(sessionId)) return;
|
|
||||||
|
|
||||||
markInitialPromptSent(sessionId);
|
|
||||||
const context = capturePageContext();
|
|
||||||
sendMessage(initialPrompt, true, context);
|
|
||||||
},
|
|
||||||
[
|
|
||||||
initialPrompt,
|
|
||||||
sessionId,
|
|
||||||
initialMessages.length,
|
|
||||||
sendMessage,
|
|
||||||
capturePageContext,
|
|
||||||
],
|
|
||||||
);
|
|
||||||
|
|
||||||
async function sendMessageWithContext(
|
|
||||||
content: string,
|
|
||||||
isUserMessage: boolean = true,
|
|
||||||
) {
|
|
||||||
const context = capturePageContext();
|
|
||||||
await sendMessage(content, isUserMessage, context);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleRegionModalOpenChange(open: boolean) {
|
|
||||||
setIsRegionBlockedModalOpen(open);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleRegionModalClose() {
|
|
||||||
setIsRegionBlockedModalOpen(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
messages: allMessages,
|
messages: allMessages,
|
||||||
streamingChunks,
|
streamingChunks,
|
||||||
isStreaming,
|
isStreaming,
|
||||||
error,
|
error,
|
||||||
isRegionBlockedModalOpen,
|
|
||||||
setIsRegionBlockedModalOpen,
|
|
||||||
sendMessageWithContext,
|
|
||||||
handleRegionModalOpenChange,
|
|
||||||
handleRegionModalClose,
|
|
||||||
sendMessage,
|
sendMessage,
|
||||||
stopStreaming: handleStopStreaming,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
import { Input } from "@/components/atoms/Input/Input";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { ArrowUpIcon } from "@phosphor-icons/react";
|
||||||
|
import { useChatInput } from "./useChatInput";
|
||||||
|
|
||||||
|
export interface ChatInputProps {
|
||||||
|
onSend: (message: string) => void;
|
||||||
|
disabled?: boolean;
|
||||||
|
placeholder?: string;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ChatInput({
|
||||||
|
onSend,
|
||||||
|
disabled = false,
|
||||||
|
placeholder = "Type your message...",
|
||||||
|
className,
|
||||||
|
}: ChatInputProps) {
|
||||||
|
const inputId = "chat-input";
|
||||||
|
const { value, setValue, handleKeyDown, handleSend } = useChatInput({
|
||||||
|
onSend,
|
||||||
|
disabled,
|
||||||
|
maxRows: 5,
|
||||||
|
inputId,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn("relative flex-1", className)}>
|
||||||
|
<Input
|
||||||
|
id={inputId}
|
||||||
|
label="Chat message input"
|
||||||
|
hideLabel
|
||||||
|
type="textarea"
|
||||||
|
value={value}
|
||||||
|
onChange={(e) => setValue(e.target.value)}
|
||||||
|
onKeyDown={handleKeyDown}
|
||||||
|
placeholder={placeholder}
|
||||||
|
disabled={disabled}
|
||||||
|
rows={1}
|
||||||
|
wrapperClassName="mb-0 relative"
|
||||||
|
className="pr-12"
|
||||||
|
/>
|
||||||
|
<span id="chat-input-hint" className="sr-only">
|
||||||
|
Press Enter to send, Shift+Enter for new line
|
||||||
|
</span>
|
||||||
|
|
||||||
|
<button
|
||||||
|
onClick={handleSend}
|
||||||
|
disabled={disabled || !value.trim()}
|
||||||
|
className={cn(
|
||||||
|
"absolute right-3 top-1/2 flex h-8 w-8 -translate-y-1/2 items-center justify-center rounded-full",
|
||||||
|
"border border-zinc-800 bg-zinc-800 text-white",
|
||||||
|
"hover:border-zinc-900 hover:bg-zinc-900",
|
||||||
|
"disabled:border-zinc-200 disabled:bg-zinc-200 disabled:text-white disabled:opacity-50",
|
||||||
|
"transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-neutral-950",
|
||||||
|
"disabled:pointer-events-none",
|
||||||
|
)}
|
||||||
|
aria-label="Send message"
|
||||||
|
>
|
||||||
|
<ArrowUpIcon className="h-3 w-3" weight="bold" />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
import { KeyboardEvent, useCallback, useEffect, useState } from "react";
|
||||||
|
|
||||||
|
interface UseChatInputArgs {
|
||||||
|
onSend: (message: string) => void;
|
||||||
|
disabled?: boolean;
|
||||||
|
maxRows?: number;
|
||||||
|
inputId?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useChatInput({
|
||||||
|
onSend,
|
||||||
|
disabled = false,
|
||||||
|
maxRows = 5,
|
||||||
|
inputId = "chat-input",
|
||||||
|
}: UseChatInputArgs) {
|
||||||
|
const [value, setValue] = useState("");
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
||||||
|
if (!textarea) return;
|
||||||
|
textarea.style.height = "auto";
|
||||||
|
const lineHeight = parseInt(
|
||||||
|
window.getComputedStyle(textarea).lineHeight,
|
||||||
|
10,
|
||||||
|
);
|
||||||
|
const maxHeight = lineHeight * maxRows;
|
||||||
|
const newHeight = Math.min(textarea.scrollHeight, maxHeight);
|
||||||
|
textarea.style.height = `${newHeight}px`;
|
||||||
|
textarea.style.overflowY =
|
||||||
|
textarea.scrollHeight > maxHeight ? "auto" : "hidden";
|
||||||
|
}, [value, maxRows, inputId]);
|
||||||
|
|
||||||
|
const handleSend = useCallback(() => {
|
||||||
|
if (disabled || !value.trim()) return;
|
||||||
|
onSend(value.trim());
|
||||||
|
setValue("");
|
||||||
|
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
||||||
|
if (textarea) {
|
||||||
|
textarea.style.height = "auto";
|
||||||
|
}
|
||||||
|
}, [value, onSend, disabled, inputId]);
|
||||||
|
|
||||||
|
const handleKeyDown = useCallback(
|
||||||
|
(event: KeyboardEvent<HTMLInputElement | HTMLTextAreaElement>) => {
|
||||||
|
if (event.key === "Enter" && !event.shiftKey) {
|
||||||
|
event.preventDefault();
|
||||||
|
handleSend();
|
||||||
|
}
|
||||||
|
// Shift+Enter allows default behavior (new line) - no need to handle explicitly
|
||||||
|
},
|
||||||
|
[handleSend],
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
value,
|
||||||
|
setValue,
|
||||||
|
handleKeyDown,
|
||||||
|
handleSend,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -1,65 +1,48 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
|
import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/store";
|
||||||
|
import Avatar, {
|
||||||
|
AvatarFallback,
|
||||||
|
AvatarImage,
|
||||||
|
} from "@/components/atoms/Avatar/Avatar";
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import {
|
import {
|
||||||
ArrowsClockwiseIcon,
|
ArrowClockwise,
|
||||||
CheckCircleIcon,
|
CheckCircleIcon,
|
||||||
CheckIcon,
|
CheckIcon,
|
||||||
CopyIcon,
|
CopyIcon,
|
||||||
|
RobotIcon,
|
||||||
} from "@phosphor-icons/react";
|
} from "@phosphor-icons/react";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { useCallback, useState } from "react";
|
import { useCallback, useState } from "react";
|
||||||
|
import { getToolActionPhrase } from "../../helpers";
|
||||||
import { AgentCarouselMessage } from "../AgentCarouselMessage/AgentCarouselMessage";
|
import { AgentCarouselMessage } from "../AgentCarouselMessage/AgentCarouselMessage";
|
||||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
|
||||||
import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
|
import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget";
|
||||||
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
|
import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup";
|
||||||
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
|
import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage";
|
||||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||||
|
import { MessageBubble } from "../MessageBubble/MessageBubble";
|
||||||
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
|
import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage";
|
||||||
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
|
import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage";
|
||||||
import { ToolResponseMessage } from "../ToolResponseMessage/ToolResponseMessage";
|
import { ToolResponseMessage } from "../ToolResponseMessage/ToolResponseMessage";
|
||||||
import { UserChatBubble } from "../UserChatBubble/UserChatBubble";
|
|
||||||
import { useChatMessage, type ChatMessageData } from "./useChatMessage";
|
import { useChatMessage, type ChatMessageData } from "./useChatMessage";
|
||||||
|
|
||||||
function stripInternalReasoning(content: string): string {
|
|
||||||
const cleaned = content.replace(
|
|
||||||
/<internal_reasoning>[\s\S]*?<\/internal_reasoning>/gi,
|
|
||||||
"",
|
|
||||||
);
|
|
||||||
return cleaned.replace(/\n{3,}/g, "\n\n").trim();
|
|
||||||
}
|
|
||||||
|
|
||||||
function getDisplayContent(message: ChatMessageData, isUser: boolean): string {
|
|
||||||
if (message.type !== "message") return "";
|
|
||||||
if (isUser) return message.content;
|
|
||||||
return stripInternalReasoning(message.content);
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface ChatMessageProps {
|
export interface ChatMessageProps {
|
||||||
message: ChatMessageData;
|
message: ChatMessageData;
|
||||||
messages?: ChatMessageData[];
|
|
||||||
index?: number;
|
|
||||||
isStreaming?: boolean;
|
|
||||||
className?: string;
|
className?: string;
|
||||||
onDismissLogin?: () => void;
|
onDismissLogin?: () => void;
|
||||||
onDismissCredentials?: () => void;
|
onDismissCredentials?: () => void;
|
||||||
onSendMessage?: (content: string, isUserMessage?: boolean) => void;
|
onSendMessage?: (content: string, isUserMessage?: boolean) => void;
|
||||||
agentOutput?: ChatMessageData;
|
agentOutput?: ChatMessageData;
|
||||||
isFinalMessage?: boolean;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatMessage({
|
export function ChatMessage({
|
||||||
message,
|
message,
|
||||||
messages = [],
|
|
||||||
index = -1,
|
|
||||||
isStreaming = false,
|
|
||||||
className,
|
className,
|
||||||
onDismissCredentials,
|
onDismissCredentials,
|
||||||
onSendMessage,
|
onSendMessage,
|
||||||
agentOutput,
|
agentOutput,
|
||||||
isFinalMessage = true,
|
|
||||||
}: ChatMessageProps) {
|
}: ChatMessageProps) {
|
||||||
const { user } = useSupabase();
|
const { user } = useSupabase();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -71,7 +54,14 @@ export function ChatMessage({
|
|||||||
isLoginNeeded,
|
isLoginNeeded,
|
||||||
isCredentialsNeeded,
|
isCredentialsNeeded,
|
||||||
} = useChatMessage(message);
|
} = useChatMessage(message);
|
||||||
const displayContent = getDisplayContent(message, isUser);
|
|
||||||
|
const { data: profile } = useGetV2GetUserProfile({
|
||||||
|
query: {
|
||||||
|
select: (res) => (res.status === 200 ? res.data : null),
|
||||||
|
enabled: isUser && !!user,
|
||||||
|
queryKey: ["/api/store/profile", user?.id],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
const handleAllCredentialsComplete = useCallback(
|
const handleAllCredentialsComplete = useCallback(
|
||||||
function handleAllCredentialsComplete() {
|
function handleAllCredentialsComplete() {
|
||||||
@@ -97,25 +87,17 @@ export function ChatMessage({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleCopy = useCallback(
|
const handleCopy = useCallback(async () => {
|
||||||
async function handleCopy() {
|
if (message.type !== "message") return;
|
||||||
if (message.type !== "message") return;
|
|
||||||
if (!displayContent) return;
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await navigator.clipboard.writeText(displayContent);
|
await navigator.clipboard.writeText(message.content);
|
||||||
setCopied(true);
|
setCopied(true);
|
||||||
setTimeout(() => setCopied(false), 2000);
|
setTimeout(() => setCopied(false), 2000);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Failed to copy:", error);
|
console.error("Failed to copy:", error);
|
||||||
}
|
}
|
||||||
},
|
}, [message]);
|
||||||
[displayContent, message],
|
|
||||||
);
|
|
||||||
|
|
||||||
function isLongResponse(content: string): boolean {
|
|
||||||
return content.split("\n").length > 5;
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleTryAgain = useCallback(() => {
|
const handleTryAgain = useCallback(() => {
|
||||||
if (message.type !== "message" || !onSendMessage) return;
|
if (message.type !== "message" || !onSendMessage) return;
|
||||||
@@ -187,45 +169,9 @@ export function ChatMessage({
|
|||||||
|
|
||||||
// Render tool call messages
|
// Render tool call messages
|
||||||
if (isToolCall && message.type === "tool_call") {
|
if (isToolCall && message.type === "tool_call") {
|
||||||
// Check if this tool call is currently streaming
|
|
||||||
// A tool call is streaming if:
|
|
||||||
// 1. isStreaming is true
|
|
||||||
// 2. This is the last tool_call message
|
|
||||||
// 3. There's no tool_response for this tool call yet
|
|
||||||
const isToolCallStreaming =
|
|
||||||
isStreaming &&
|
|
||||||
index >= 0 &&
|
|
||||||
(() => {
|
|
||||||
// Find the last tool_call index
|
|
||||||
let lastToolCallIndex = -1;
|
|
||||||
for (let i = messages.length - 1; i >= 0; i--) {
|
|
||||||
if (messages[i].type === "tool_call") {
|
|
||||||
lastToolCallIndex = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Check if this is the last tool_call and there's no response yet
|
|
||||||
if (index === lastToolCallIndex) {
|
|
||||||
// Check if there's a tool_response for this tool call
|
|
||||||
const hasResponse = messages
|
|
||||||
.slice(index + 1)
|
|
||||||
.some(
|
|
||||||
(msg) =>
|
|
||||||
msg.type === "tool_response" && msg.toolId === message.toolId,
|
|
||||||
);
|
|
||||||
return !hasResponse;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
})();
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={cn("px-4 py-2", className)}>
|
<div className={cn("px-4 py-2", className)}>
|
||||||
<ToolCallMessage
|
<ToolCallMessage toolName={message.toolName} />
|
||||||
toolId={message.toolId}
|
|
||||||
toolName={message.toolName}
|
|
||||||
arguments={message.arguments}
|
|
||||||
isStreaming={isToolCallStreaming}
|
|
||||||
/>
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -272,11 +218,27 @@ export function ChatMessage({
|
|||||||
|
|
||||||
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
|
// Render tool response messages (but skip agent_output if it's being rendered inside assistant message)
|
||||||
if (isToolResponse && message.type === "tool_response") {
|
if (isToolResponse && message.type === "tool_response") {
|
||||||
|
// Check if this is an agent_output that should be rendered inside assistant message
|
||||||
|
if (message.result) {
|
||||||
|
let parsedResult: Record<string, unknown> | null = null;
|
||||||
|
try {
|
||||||
|
parsedResult =
|
||||||
|
typeof message.result === "string"
|
||||||
|
? JSON.parse(message.result)
|
||||||
|
: (message.result as Record<string, unknown>);
|
||||||
|
} catch {
|
||||||
|
parsedResult = null;
|
||||||
|
}
|
||||||
|
if (parsedResult?.type === "agent_output") {
|
||||||
|
// Skip rendering - this will be rendered inside the assistant message
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={cn("px-4 py-2", className)}>
|
<div className={cn("px-4 py-2", className)}>
|
||||||
<ToolResponseMessage
|
<ToolResponseMessage
|
||||||
toolId={message.toolId}
|
toolName={getToolActionPhrase(message.toolName)}
|
||||||
toolName={message.toolName}
|
|
||||||
result={message.result}
|
result={message.result}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
@@ -294,33 +256,40 @@ export function ChatMessage({
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<div className="flex w-full max-w-3xl gap-3">
|
<div className="flex w-full max-w-3xl gap-3">
|
||||||
|
{!isUser && (
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-500">
|
||||||
|
<RobotIcon className="h-4 w-4 text-indigo-50" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
"flex min-w-0 flex-1 flex-col",
|
"flex min-w-0 flex-1 flex-col",
|
||||||
isUser && "items-end",
|
isUser && "items-end",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{isUser ? (
|
<MessageBubble variant={isUser ? "user" : "assistant"}>
|
||||||
<UserChatBubble>
|
<MarkdownContent content={message.content} />
|
||||||
<MarkdownContent content={displayContent} />
|
{agentOutput &&
|
||||||
</UserChatBubble>
|
agentOutput.type === "tool_response" &&
|
||||||
) : (
|
!isUser && (
|
||||||
<AIChatBubble>
|
|
||||||
<MarkdownContent content={displayContent} />
|
|
||||||
{agentOutput && agentOutput.type === "tool_response" && (
|
|
||||||
<div className="mt-4">
|
<div className="mt-4">
|
||||||
<ToolResponseMessage
|
<ToolResponseMessage
|
||||||
toolId={agentOutput.toolId}
|
toolName={
|
||||||
toolName={agentOutput.toolName || "Agent Output"}
|
agentOutput.toolName
|
||||||
|
? getToolActionPhrase(agentOutput.toolName)
|
||||||
|
: "Agent Output"
|
||||||
|
}
|
||||||
result={agentOutput.result}
|
result={agentOutput.result}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</AIChatBubble>
|
</MessageBubble>
|
||||||
)}
|
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
"flex gap-0",
|
"mt-1 flex gap-1",
|
||||||
isUser ? "justify-end" : "justify-start",
|
isUser ? "justify-end" : "justify-start",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
@@ -331,25 +300,37 @@ export function ChatMessage({
|
|||||||
onClick={handleTryAgain}
|
onClick={handleTryAgain}
|
||||||
aria-label="Try again"
|
aria-label="Try again"
|
||||||
>
|
>
|
||||||
<ArrowsClockwiseIcon className="size-4 text-zinc-600" />
|
<ArrowClockwise className="size-3 text-neutral-500" />
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
{!isUser && isFinalMessage && isLongResponse(displayContent) && (
|
|
||||||
<Button
|
|
||||||
variant="ghost"
|
|
||||||
size="icon"
|
|
||||||
onClick={handleCopy}
|
|
||||||
aria-label="Copy message"
|
|
||||||
>
|
|
||||||
{copied ? (
|
|
||||||
<CheckIcon className="size-4 text-green-600" />
|
|
||||||
) : (
|
|
||||||
<CopyIcon className="size-4 text-zinc-600" />
|
|
||||||
)}
|
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
onClick={handleCopy}
|
||||||
|
aria-label="Copy message"
|
||||||
|
>
|
||||||
|
{copied ? (
|
||||||
|
<CheckIcon className="size-3 text-green-600" />
|
||||||
|
) : (
|
||||||
|
<CopyIcon className="size-3 text-neutral-500" />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{isUser && (
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<Avatar className="h-7 w-7">
|
||||||
|
<AvatarImage
|
||||||
|
src={profile?.avatar_url ?? ""}
|
||||||
|
alt={profile?.username ?? "User"}
|
||||||
|
/>
|
||||||
|
<AvatarFallback className="rounded-lg bg-neutral-200 text-neutral-600">
|
||||||
|
{profile?.username?.charAt(0)?.toUpperCase() || "U"}
|
||||||
|
</AvatarFallback>
|
||||||
|
</Avatar>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
@@ -13,9 +13,10 @@ export function MessageBubble({
|
|||||||
className,
|
className,
|
||||||
}: MessageBubbleProps) {
|
}: MessageBubbleProps) {
|
||||||
const userTheme = {
|
const userTheme = {
|
||||||
bg: "bg-purple-100",
|
bg: "bg-slate-900",
|
||||||
border: "border-purple-100",
|
border: "border-slate-800",
|
||||||
text: "text-slate-900",
|
gradient: "from-slate-900/30 via-slate-800/20 to-transparent",
|
||||||
|
text: "text-slate-50",
|
||||||
};
|
};
|
||||||
|
|
||||||
const assistantTheme = {
|
const assistantTheme = {
|
||||||
@@ -39,7 +40,9 @@ export function MessageBubble({
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{/* Gradient flare background */}
|
{/* Gradient flare background */}
|
||||||
<div className={cn("absolute inset-0 bg-gradient-to-br")} />
|
<div
|
||||||
|
className={cn("absolute inset-0 bg-gradient-to-br", theme.gradient)}
|
||||||
|
/>
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
"relative z-10 transition-all duration-500 ease-in-out",
|
"relative z-10 transition-all duration-500 ease-in-out",
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { ChatMessage } from "../ChatMessage/ChatMessage";
|
||||||
|
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||||
|
import { StreamingMessage } from "../StreamingMessage/StreamingMessage";
|
||||||
|
import { ThinkingMessage } from "../ThinkingMessage/ThinkingMessage";
|
||||||
|
import { useMessageList } from "./useMessageList";
|
||||||
|
|
||||||
|
export interface MessageListProps {
|
||||||
|
messages: ChatMessageData[];
|
||||||
|
streamingChunks?: string[];
|
||||||
|
isStreaming?: boolean;
|
||||||
|
className?: string;
|
||||||
|
onStreamComplete?: () => void;
|
||||||
|
onSendMessage?: (content: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function MessageList({
|
||||||
|
messages,
|
||||||
|
streamingChunks = [],
|
||||||
|
isStreaming = false,
|
||||||
|
className,
|
||||||
|
onStreamComplete,
|
||||||
|
onSendMessage,
|
||||||
|
}: MessageListProps) {
|
||||||
|
const { messagesEndRef, messagesContainerRef } = useMessageList({
|
||||||
|
messageCount: messages.length,
|
||||||
|
isStreaming,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
ref={messagesContainerRef}
|
||||||
|
className={cn(
|
||||||
|
"flex-1 overflow-y-auto",
|
||||||
|
"scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="mx-auto flex max-w-3xl flex-col py-4">
|
||||||
|
{/* Render all persisted messages */}
|
||||||
|
{messages.map((message, index) => {
|
||||||
|
// Check if current message is an agent_output tool_response
|
||||||
|
// and if previous message is an assistant message
|
||||||
|
let agentOutput: ChatMessageData | undefined;
|
||||||
|
|
||||||
|
if (message.type === "tool_response" && message.result) {
|
||||||
|
let parsedResult: Record<string, unknown> | null = null;
|
||||||
|
try {
|
||||||
|
parsedResult =
|
||||||
|
typeof message.result === "string"
|
||||||
|
? JSON.parse(message.result)
|
||||||
|
: (message.result as Record<string, unknown>);
|
||||||
|
} catch {
|
||||||
|
parsedResult = null;
|
||||||
|
}
|
||||||
|
if (parsedResult?.type === "agent_output") {
|
||||||
|
const prevMessage = messages[index - 1];
|
||||||
|
if (
|
||||||
|
prevMessage &&
|
||||||
|
prevMessage.type === "message" &&
|
||||||
|
prevMessage.role === "assistant"
|
||||||
|
) {
|
||||||
|
// This agent output will be rendered inside the previous assistant message
|
||||||
|
// Skip rendering this message separately
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if next message is an agent_output tool_response to include in current assistant message
|
||||||
|
if (message.type === "message" && message.role === "assistant") {
|
||||||
|
const nextMessage = messages[index + 1];
|
||||||
|
if (
|
||||||
|
nextMessage &&
|
||||||
|
nextMessage.type === "tool_response" &&
|
||||||
|
nextMessage.result
|
||||||
|
) {
|
||||||
|
let parsedResult: Record<string, unknown> | null = null;
|
||||||
|
try {
|
||||||
|
parsedResult =
|
||||||
|
typeof nextMessage.result === "string"
|
||||||
|
? JSON.parse(nextMessage.result)
|
||||||
|
: (nextMessage.result as Record<string, unknown>);
|
||||||
|
} catch {
|
||||||
|
parsedResult = null;
|
||||||
|
}
|
||||||
|
if (parsedResult?.type === "agent_output") {
|
||||||
|
agentOutput = nextMessage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ChatMessage
|
||||||
|
key={index}
|
||||||
|
message={message}
|
||||||
|
onSendMessage={onSendMessage}
|
||||||
|
agentOutput={agentOutput}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
|
||||||
|
{/* Render thinking message when streaming but no chunks yet */}
|
||||||
|
{isStreaming && streamingChunks.length === 0 && <ThinkingMessage />}
|
||||||
|
|
||||||
|
{/* Render streaming message if active */}
|
||||||
|
{isStreaming && streamingChunks.length > 0 && (
|
||||||
|
<StreamingMessage
|
||||||
|
chunks={streamingChunks}
|
||||||
|
onComplete={onStreamComplete}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Invisible div to scroll to */}
|
||||||
|
<div ref={messagesEndRef} />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -81,9 +81,9 @@ export function SessionsDrawer({
|
|||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
) : sessions.length === 0 ? (
|
) : sessions.length === 0 ? (
|
||||||
<div className="flex h-full items-center justify-center">
|
<div className="flex items-center justify-center py-8">
|
||||||
<Text variant="body" className="text-zinc-500">
|
<Text variant="body" className="text-zinc-500">
|
||||||
You don't have previously started chats
|
No sessions found
|
||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
import { RobotIcon } from "@phosphor-icons/react";
|
||||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||||
|
import { MessageBubble } from "../MessageBubble/MessageBubble";
|
||||||
import { useStreamingMessage } from "./useStreamingMessage";
|
import { useStreamingMessage } from "./useStreamingMessage";
|
||||||
|
|
||||||
export interface StreamingMessageProps {
|
export interface StreamingMessageProps {
|
||||||
@@ -24,10 +25,16 @@ export function StreamingMessage({
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<div className="flex w-full max-w-3xl gap-3">
|
<div className="flex w-full max-w-3xl gap-3">
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-600">
|
||||||
|
<RobotIcon className="h-4 w-4 text-indigo-50" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div className="flex min-w-0 flex-1 flex-col">
|
<div className="flex min-w-0 flex-1 flex-col">
|
||||||
<AIChatBubble>
|
<MessageBubble variant="assistant">
|
||||||
<MarkdownContent content={displayText} />
|
<MarkdownContent content={displayText} />
|
||||||
</AIChatBubble>
|
</MessageBubble>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
import { RobotIcon } from "@phosphor-icons/react";
|
||||||
import { useEffect, useRef, useState } from "react";
|
import { useEffect, useRef, useState } from "react";
|
||||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
import { MessageBubble } from "../MessageBubble/MessageBubble";
|
||||||
import { ChatLoader } from "../ChatLoader/ChatLoader";
|
|
||||||
|
|
||||||
export interface ThinkingMessageProps {
|
export interface ThinkingMessageProps {
|
||||||
className?: string;
|
className?: string;
|
||||||
@@ -34,11 +34,22 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<div className="flex w-full max-w-3xl gap-3">
|
<div className="flex w-full max-w-3xl gap-3">
|
||||||
|
<div className="flex-shrink-0">
|
||||||
|
<div className="flex h-7 w-7 items-center justify-center rounded-lg bg-indigo-500">
|
||||||
|
<RobotIcon className="h-4 w-4 text-indigo-50" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div className="flex min-w-0 flex-1 flex-col">
|
<div className="flex min-w-0 flex-1 flex-col">
|
||||||
<AIChatBubble>
|
<MessageBubble variant="assistant">
|
||||||
<div className="transition-all duration-500 ease-in-out">
|
<div className="transition-all duration-500 ease-in-out">
|
||||||
{showSlowLoader ? (
|
{showSlowLoader ? (
|
||||||
<ChatLoader />
|
<div className="flex flex-col items-center gap-3 py-2">
|
||||||
|
<div className="loader" style={{ flexShrink: 0 }} />
|
||||||
|
<p className="text-sm text-slate-700">
|
||||||
|
Taking a bit longer to think, wait a moment please
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<span
|
<span
|
||||||
className="inline-block bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-clip-text text-transparent"
|
className="inline-block bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-clip-text text-transparent"
|
||||||
@@ -51,7 +62,7 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
|||||||
</span>
|
</span>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</AIChatBubble>
|
</MessageBubble>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { WrenchIcon } from "@phosphor-icons/react";
|
||||||
|
import { getToolActionPhrase } from "../../helpers";
|
||||||
|
|
||||||
|
export interface ToolCallMessageProps {
|
||||||
|
toolName: string;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ToolCallMessage({ toolName, className }: ToolCallMessageProps) {
|
||||||
|
return (
|
||||||
|
<div className={cn("flex items-center justify-center gap-2", className)}>
|
||||||
|
<WrenchIcon
|
||||||
|
size={14}
|
||||||
|
weight="bold"
|
||||||
|
className="flex-shrink-0 text-neutral-500"
|
||||||
|
/>
|
||||||
|
<Text variant="small" className="text-neutral-500">
|
||||||
|
{getToolActionPhrase(toolName)}...
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,260 @@
|
|||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
|
import "@/components/contextual/OutputRenderers";
|
||||||
|
import {
|
||||||
|
globalRegistry,
|
||||||
|
OutputItem,
|
||||||
|
} from "@/components/contextual/OutputRenderers";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import type { ToolResult } from "@/types/chat";
|
||||||
|
import { WrenchIcon } from "@phosphor-icons/react";
|
||||||
|
import { getToolActionPhrase } from "../../helpers";
|
||||||
|
|
||||||
|
export interface ToolResponseMessageProps {
|
||||||
|
toolName: string;
|
||||||
|
result?: ToolResult;
|
||||||
|
success?: boolean;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ToolResponseMessage({
|
||||||
|
toolName,
|
||||||
|
result,
|
||||||
|
success: _success = true,
|
||||||
|
className,
|
||||||
|
}: ToolResponseMessageProps) {
|
||||||
|
if (!result) {
|
||||||
|
return (
|
||||||
|
<div className={cn("flex items-center justify-center gap-2", className)}>
|
||||||
|
<WrenchIcon
|
||||||
|
size={14}
|
||||||
|
weight="bold"
|
||||||
|
className="flex-shrink-0 text-neutral-500"
|
||||||
|
/>
|
||||||
|
<Text variant="small" className="text-neutral-500">
|
||||||
|
{getToolActionPhrase(toolName)}...
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let parsedResult: Record<string, unknown> | null = null;
|
||||||
|
try {
|
||||||
|
parsedResult =
|
||||||
|
typeof result === "string"
|
||||||
|
? JSON.parse(result)
|
||||||
|
: (result as Record<string, unknown>);
|
||||||
|
} catch {
|
||||||
|
parsedResult = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parsedResult && typeof parsedResult === "object") {
|
||||||
|
const responseType = parsedResult.type as string | undefined;
|
||||||
|
|
||||||
|
if (responseType === "agent_output") {
|
||||||
|
const execution = parsedResult.execution as
|
||||||
|
| {
|
||||||
|
outputs?: Record<string, unknown[]>;
|
||||||
|
}
|
||||||
|
| null
|
||||||
|
| undefined;
|
||||||
|
const outputs = execution?.outputs || {};
|
||||||
|
const message = parsedResult.message as string | undefined;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn("space-y-4 px-4 py-2", className)}>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<WrenchIcon
|
||||||
|
size={14}
|
||||||
|
weight="bold"
|
||||||
|
className="flex-shrink-0 text-neutral-500"
|
||||||
|
/>
|
||||||
|
<Text variant="small" className="text-neutral-500">
|
||||||
|
{getToolActionPhrase(toolName)}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
{message && (
|
||||||
|
<div className="rounded border p-4">
|
||||||
|
<Text variant="small" className="text-neutral-600">
|
||||||
|
{message}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{Object.keys(outputs).length > 0 && (
|
||||||
|
<div className="space-y-4">
|
||||||
|
{Object.entries(outputs).map(([outputName, values]) =>
|
||||||
|
values.map((value, index) => {
|
||||||
|
const renderer = globalRegistry.getRenderer(value);
|
||||||
|
if (renderer) {
|
||||||
|
return (
|
||||||
|
<OutputItem
|
||||||
|
key={`${outputName}-${index}`}
|
||||||
|
value={value}
|
||||||
|
renderer={renderer}
|
||||||
|
label={outputName}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
key={`${outputName}-${index}`}
|
||||||
|
className="rounded border p-4"
|
||||||
|
>
|
||||||
|
<Text variant="large-medium" className="mb-2 capitalize">
|
||||||
|
{outputName}
|
||||||
|
</Text>
|
||||||
|
<pre className="overflow-auto text-sm">
|
||||||
|
{JSON.stringify(value, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}),
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (responseType === "block_output" && parsedResult.outputs) {
|
||||||
|
const outputs = parsedResult.outputs as Record<string, unknown[]>;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn("space-y-4 px-4 py-2", className)}>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<WrenchIcon
|
||||||
|
size={14}
|
||||||
|
weight="bold"
|
||||||
|
className="flex-shrink-0 text-neutral-500"
|
||||||
|
/>
|
||||||
|
<Text variant="small" className="text-neutral-500">
|
||||||
|
{getToolActionPhrase(toolName)}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-4">
|
||||||
|
{Object.entries(outputs).map(([outputName, values]) =>
|
||||||
|
values.map((value, index) => {
|
||||||
|
const renderer = globalRegistry.getRenderer(value);
|
||||||
|
if (renderer) {
|
||||||
|
return (
|
||||||
|
<OutputItem
|
||||||
|
key={`${outputName}-${index}`}
|
||||||
|
value={value}
|
||||||
|
renderer={renderer}
|
||||||
|
label={outputName}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
key={`${outputName}-${index}`}
|
||||||
|
className="rounded border p-4"
|
||||||
|
>
|
||||||
|
<Text variant="large-medium" className="mb-2 capitalize">
|
||||||
|
{outputName}
|
||||||
|
</Text>
|
||||||
|
<pre className="overflow-auto text-sm">
|
||||||
|
{JSON.stringify(value, null, 2)}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}),
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle other response types with a message field (e.g., understanding_updated)
|
||||||
|
if (parsedResult.message && typeof parsedResult.message === "string") {
|
||||||
|
// Format tool name from snake_case to Title Case
|
||||||
|
const formattedToolName = toolName
|
||||||
|
.split("_")
|
||||||
|
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
||||||
|
.join(" ");
|
||||||
|
|
||||||
|
// Clean up message - remove incomplete user_name references
|
||||||
|
let cleanedMessage = parsedResult.message;
|
||||||
|
// Remove "Updated understanding with: user_name" pattern if user_name is just a placeholder
|
||||||
|
cleanedMessage = cleanedMessage.replace(
|
||||||
|
/Updated understanding with:\s*user_name\.?\s*/gi,
|
||||||
|
"",
|
||||||
|
);
|
||||||
|
// Remove standalone user_name references
|
||||||
|
cleanedMessage = cleanedMessage.replace(/\buser_name\b\.?\s*/gi, "");
|
||||||
|
cleanedMessage = cleanedMessage.trim();
|
||||||
|
|
||||||
|
// Only show message if it has content after cleaning
|
||||||
|
if (!cleanedMessage) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex items-center justify-center gap-2 px-4 py-2",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<WrenchIcon
|
||||||
|
size={14}
|
||||||
|
weight="bold"
|
||||||
|
className="flex-shrink-0 text-neutral-500"
|
||||||
|
/>
|
||||||
|
<Text variant="small" className="text-neutral-500">
|
||||||
|
{formattedToolName}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn("space-y-2 px-4 py-2", className)}>
|
||||||
|
<div className="flex items-center justify-center gap-2">
|
||||||
|
<WrenchIcon
|
||||||
|
size={14}
|
||||||
|
weight="bold"
|
||||||
|
className="flex-shrink-0 text-neutral-500"
|
||||||
|
/>
|
||||||
|
<Text variant="small" className="text-neutral-500">
|
||||||
|
{formattedToolName}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
<div className="rounded border p-4">
|
||||||
|
<Text variant="small" className="text-neutral-600">
|
||||||
|
{cleanedMessage}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderer = globalRegistry.getRenderer(result);
|
||||||
|
if (renderer) {
|
||||||
|
return (
|
||||||
|
<div className={cn("px-4 py-2", className)}>
|
||||||
|
<div className="mb-2 flex items-center gap-2">
|
||||||
|
<WrenchIcon
|
||||||
|
size={14}
|
||||||
|
weight="bold"
|
||||||
|
className="flex-shrink-0 text-neutral-500"
|
||||||
|
/>
|
||||||
|
<Text variant="small" className="text-neutral-500">
|
||||||
|
{getToolActionPhrase(toolName)}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
<OutputItem value={result} renderer={renderer} />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn("flex items-center justify-center gap-2", className)}>
|
||||||
|
<WrenchIcon
|
||||||
|
size={14}
|
||||||
|
weight="bold"
|
||||||
|
className="flex-shrink-0 text-neutral-500"
|
||||||
|
/>
|
||||||
|
<Text variant="small" className="text-neutral-500">
|
||||||
|
{getToolActionPhrase(toolName)}...
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
/**
|
||||||
|
* Maps internal tool names to user-friendly display names with emojis.
|
||||||
|
* @deprecated Use getToolActionPhrase or getToolCompletionPhrase for status messages
|
||||||
|
*
|
||||||
|
* @param toolName - The internal tool name from the backend
|
||||||
|
* @returns A user-friendly display name with an emoji prefix
|
||||||
|
*/
|
||||||
|
export function getToolDisplayName(toolName: string): string {
|
||||||
|
const toolDisplayNames: Record<string, string> = {
|
||||||
|
find_agent: "🔍 Search Marketplace",
|
||||||
|
get_agent_details: "📋 Get Agent Details",
|
||||||
|
check_credentials: "🔑 Check Credentials",
|
||||||
|
setup_agent: "⚙️ Setup Agent",
|
||||||
|
run_agent: "▶️ Run Agent",
|
||||||
|
get_required_setup_info: "📝 Get Setup Requirements",
|
||||||
|
};
|
||||||
|
return toolDisplayNames[toolName] || toolName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps internal tool names to human-friendly action phrases (present continuous).
|
||||||
|
* Used for tool call messages to indicate what action is currently happening.
|
||||||
|
*
|
||||||
|
* @param toolName - The internal tool name from the backend
|
||||||
|
* @returns A human-friendly action phrase in present continuous tense
|
||||||
|
*/
|
||||||
|
export function getToolActionPhrase(toolName: string): string {
|
||||||
|
const toolActionPhrases: Record<string, string> = {
|
||||||
|
find_agent: "Looking for agents in the marketplace",
|
||||||
|
agent_carousel: "Looking for agents in the marketplace",
|
||||||
|
get_agent_details: "Learning about the agent",
|
||||||
|
check_credentials: "Checking your credentials",
|
||||||
|
setup_agent: "Setting up the agent",
|
||||||
|
execution_started: "Running the agent",
|
||||||
|
run_agent: "Running the agent",
|
||||||
|
get_required_setup_info: "Getting setup requirements",
|
||||||
|
schedule_agent: "Scheduling the agent to run",
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return mapped phrase or generate human-friendly fallback
|
||||||
|
return toolActionPhrases[toolName] || toolName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps internal tool names to human-friendly completion phrases (past tense).
|
||||||
|
* Used for tool response messages to indicate what action was completed.
|
||||||
|
*
|
||||||
|
* @param toolName - The internal tool name from the backend
|
||||||
|
* @returns A human-friendly completion phrase in past tense
|
||||||
|
*/
|
||||||
|
export function getToolCompletionPhrase(toolName: string): string {
|
||||||
|
const toolCompletionPhrases: Record<string, string> = {
|
||||||
|
find_agent: "Finished searching the marketplace",
|
||||||
|
get_agent_details: "Got agent details",
|
||||||
|
check_credentials: "Checked credentials",
|
||||||
|
setup_agent: "Agent setup complete",
|
||||||
|
run_agent: "Agent execution started",
|
||||||
|
get_required_setup_info: "Got setup requirements",
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return mapped phrase or generate human-friendly fallback
|
||||||
|
return (
|
||||||
|
toolCompletionPhrases[toolName] ||
|
||||||
|
`Finished ${toolName.replace(/_/g, " ").replace("...", "")}`
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,20 +1,17 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { useEffect, useRef, useState } from "react";
|
import { useEffect, useRef } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useChatSession } from "./useChatSession";
|
import { useChatSession } from "./useChatSession";
|
||||||
import { useChatStream } from "./useChatStream";
|
import { useChatStream } from "./useChatStream";
|
||||||
|
|
||||||
interface UseChatArgs {
|
export function useChat() {
|
||||||
urlSessionId?: string | null;
|
const hasCreatedSessionRef = useRef(false);
|
||||||
}
|
|
||||||
|
|
||||||
export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
|
||||||
const hasClaimedSessionRef = useRef(false);
|
const hasClaimedSessionRef = useRef(false);
|
||||||
const { user } = useSupabase();
|
const { user } = useSupabase();
|
||||||
const { sendMessage: sendStreamMessage } = useChatStream();
|
const { sendMessage: sendStreamMessage } = useChatStream();
|
||||||
const [showLoader, setShowLoader] = useState(false);
|
|
||||||
const {
|
const {
|
||||||
session,
|
session,
|
||||||
sessionId: sessionIdFromHook,
|
sessionId: sessionIdFromHook,
|
||||||
@@ -22,16 +19,27 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
|||||||
isLoading,
|
isLoading,
|
||||||
isCreating,
|
isCreating,
|
||||||
error,
|
error,
|
||||||
isSessionNotFound,
|
|
||||||
createSession,
|
createSession,
|
||||||
claimSession,
|
claimSession,
|
||||||
clearSession: clearSessionBase,
|
clearSession: clearSessionBase,
|
||||||
loadSession,
|
loadSession,
|
||||||
} = useChatSession({
|
} = useChatSession({
|
||||||
urlSessionId,
|
urlSessionId: null,
|
||||||
autoCreate: false,
|
autoCreate: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
useEffect(
|
||||||
|
function autoCreateSession() {
|
||||||
|
if (!hasCreatedSessionRef.current && !isCreating && !sessionIdFromHook) {
|
||||||
|
hasCreatedSessionRef.current = true;
|
||||||
|
createSession().catch((_err) => {
|
||||||
|
hasCreatedSessionRef.current = false;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[isCreating, sessionIdFromHook, createSession],
|
||||||
|
);
|
||||||
|
|
||||||
useEffect(
|
useEffect(
|
||||||
function autoClaimSession() {
|
function autoClaimSession() {
|
||||||
if (
|
if (
|
||||||
@@ -67,17 +75,6 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
|||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (isLoading || isCreating) {
|
|
||||||
const timer = setTimeout(() => {
|
|
||||||
setShowLoader(true);
|
|
||||||
}, 300);
|
|
||||||
return () => clearTimeout(timer);
|
|
||||||
} else {
|
|
||||||
setShowLoader(false);
|
|
||||||
}
|
|
||||||
}, [isLoading, isCreating]);
|
|
||||||
|
|
||||||
useEffect(function monitorNetworkStatus() {
|
useEffect(function monitorNetworkStatus() {
|
||||||
function handleOnline() {
|
function handleOnline() {
|
||||||
toast.success("Connection restored", {
|
toast.success("Connection restored", {
|
||||||
@@ -102,6 +99,7 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
|||||||
|
|
||||||
function clearSession() {
|
function clearSession() {
|
||||||
clearSessionBase();
|
clearSessionBase();
|
||||||
|
hasCreatedSessionRef.current = false;
|
||||||
hasClaimedSessionRef.current = false;
|
hasClaimedSessionRef.current = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,11 +109,9 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) {
|
|||||||
isLoading,
|
isLoading,
|
||||||
isCreating,
|
isCreating,
|
||||||
error,
|
error,
|
||||||
isSessionNotFound,
|
|
||||||
createSession,
|
createSession,
|
||||||
clearSession,
|
clearSession,
|
||||||
loadSession,
|
loadSession,
|
||||||
sessionId: sessionIdFromHook,
|
sessionId: sessionIdFromHook,
|
||||||
showLoader,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,271 @@
|
|||||||
|
import {
|
||||||
|
getGetV2GetSessionQueryKey,
|
||||||
|
getGetV2GetSessionQueryOptions,
|
||||||
|
postV2CreateSession,
|
||||||
|
useGetV2GetSession,
|
||||||
|
usePatchV2SessionAssignUser,
|
||||||
|
usePostV2CreateSession,
|
||||||
|
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
|
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||||
|
import { okData } from "@/app/api/helpers";
|
||||||
|
import { isValidUUID } from "@/lib/utils";
|
||||||
|
import { Key, storage } from "@/services/storage/local-storage";
|
||||||
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
|
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
|
||||||
|
interface UseChatSessionArgs {
|
||||||
|
urlSessionId?: string | null;
|
||||||
|
autoCreate?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useChatSession({
|
||||||
|
urlSessionId,
|
||||||
|
autoCreate = false,
|
||||||
|
}: UseChatSessionArgs = {}) {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
const [sessionId, setSessionId] = useState<string | null>(null);
|
||||||
|
const [error, setError] = useState<Error | null>(null);
|
||||||
|
const justCreatedSessionIdRef = useRef<string | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (urlSessionId) {
|
||||||
|
if (!isValidUUID(urlSessionId)) {
|
||||||
|
console.error("Invalid session ID format:", urlSessionId);
|
||||||
|
toast.error("Invalid session ID", {
|
||||||
|
description:
|
||||||
|
"The session ID in the URL is not valid. Starting a new session...",
|
||||||
|
});
|
||||||
|
setSessionId(null);
|
||||||
|
storage.clean(Key.CHAT_SESSION_ID);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setSessionId(urlSessionId);
|
||||||
|
storage.set(Key.CHAT_SESSION_ID, urlSessionId);
|
||||||
|
} else {
|
||||||
|
const storedSessionId = storage.get(Key.CHAT_SESSION_ID);
|
||||||
|
if (storedSessionId) {
|
||||||
|
if (!isValidUUID(storedSessionId)) {
|
||||||
|
console.error("Invalid stored session ID:", storedSessionId);
|
||||||
|
storage.clean(Key.CHAT_SESSION_ID);
|
||||||
|
setSessionId(null);
|
||||||
|
} else {
|
||||||
|
setSessionId(storedSessionId);
|
||||||
|
}
|
||||||
|
} else if (autoCreate) {
|
||||||
|
setSessionId(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [urlSessionId, autoCreate]);
|
||||||
|
|
||||||
|
const {
|
||||||
|
mutateAsync: createSessionMutation,
|
||||||
|
isPending: isCreating,
|
||||||
|
error: createError,
|
||||||
|
} = usePostV2CreateSession();
|
||||||
|
|
||||||
|
const {
|
||||||
|
data: sessionData,
|
||||||
|
isLoading: isLoadingSession,
|
||||||
|
error: loadError,
|
||||||
|
refetch,
|
||||||
|
} = useGetV2GetSession(sessionId || "", {
|
||||||
|
query: {
|
||||||
|
enabled: !!sessionId,
|
||||||
|
select: okData,
|
||||||
|
staleTime: Infinity, // Never mark as stale
|
||||||
|
refetchOnMount: false, // Don't refetch on component mount
|
||||||
|
refetchOnWindowFocus: false, // Don't refetch when window regains focus
|
||||||
|
refetchOnReconnect: false, // Don't refetch when network reconnects
|
||||||
|
retry: 1,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const { mutateAsync: claimSessionMutation } = usePatchV2SessionAssignUser();
|
||||||
|
|
||||||
|
const session = useMemo(() => {
|
||||||
|
if (sessionData) return sessionData;
|
||||||
|
|
||||||
|
if (sessionId && justCreatedSessionIdRef.current === sessionId) {
|
||||||
|
return {
|
||||||
|
id: sessionId,
|
||||||
|
user_id: null,
|
||||||
|
messages: [],
|
||||||
|
created_at: new Date().toISOString(),
|
||||||
|
updated_at: new Date().toISOString(),
|
||||||
|
} as SessionDetailResponse;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}, [sessionData, sessionId]);
|
||||||
|
|
||||||
|
const messages = session?.messages || [];
|
||||||
|
const isLoading = isCreating || isLoadingSession;
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (createError) {
|
||||||
|
setError(
|
||||||
|
createError instanceof Error
|
||||||
|
? createError
|
||||||
|
: new Error("Failed to create session"),
|
||||||
|
);
|
||||||
|
} else if (loadError) {
|
||||||
|
setError(
|
||||||
|
loadError instanceof Error
|
||||||
|
? loadError
|
||||||
|
: new Error("Failed to load session"),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
setError(null);
|
||||||
|
}
|
||||||
|
}, [createError, loadError]);
|
||||||
|
|
||||||
|
const createSession = useCallback(
|
||||||
|
async function createSession() {
|
||||||
|
try {
|
||||||
|
setError(null);
|
||||||
|
const response = await postV2CreateSession({
|
||||||
|
body: JSON.stringify({}),
|
||||||
|
});
|
||||||
|
if (response.status !== 200) {
|
||||||
|
throw new Error("Failed to create session");
|
||||||
|
}
|
||||||
|
const newSessionId = response.data.id;
|
||||||
|
setSessionId(newSessionId);
|
||||||
|
storage.set(Key.CHAT_SESSION_ID, newSessionId);
|
||||||
|
justCreatedSessionIdRef.current = newSessionId;
|
||||||
|
setTimeout(() => {
|
||||||
|
if (justCreatedSessionIdRef.current === newSessionId) {
|
||||||
|
justCreatedSessionIdRef.current = null;
|
||||||
|
}
|
||||||
|
}, 10000);
|
||||||
|
return newSessionId;
|
||||||
|
} catch (err) {
|
||||||
|
const error =
|
||||||
|
err instanceof Error ? err : new Error("Failed to create session");
|
||||||
|
setError(error);
|
||||||
|
toast.error("Failed to create chat session", {
|
||||||
|
description: error.message,
|
||||||
|
});
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[createSessionMutation],
|
||||||
|
);
|
||||||
|
|
||||||
|
const loadSession = useCallback(
|
||||||
|
async function loadSession(id: string) {
|
||||||
|
try {
|
||||||
|
setError(null);
|
||||||
|
// Invalidate the query cache for this session to force a fresh fetch
|
||||||
|
await queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2GetSessionQueryKey(id),
|
||||||
|
});
|
||||||
|
// Set sessionId after invalidation to ensure the hook refetches
|
||||||
|
setSessionId(id);
|
||||||
|
storage.set(Key.CHAT_SESSION_ID, id);
|
||||||
|
// Force fetch with fresh data (bypass cache)
|
||||||
|
const queryOptions = getGetV2GetSessionQueryOptions(id, {
|
||||||
|
query: {
|
||||||
|
staleTime: 0, // Force fresh fetch
|
||||||
|
retry: 1,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const result = await queryClient.fetchQuery(queryOptions);
|
||||||
|
if (!result || ("status" in result && result.status !== 200)) {
|
||||||
|
console.warn("Session not found on server, clearing local state");
|
||||||
|
storage.clean(Key.CHAT_SESSION_ID);
|
||||||
|
setSessionId(null);
|
||||||
|
throw new Error("Session not found");
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
const error =
|
||||||
|
err instanceof Error ? err : new Error("Failed to load session");
|
||||||
|
setError(error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[queryClient],
|
||||||
|
);
|
||||||
|
|
||||||
|
const refreshSession = useCallback(
|
||||||
|
async function refreshSession() {
|
||||||
|
if (!sessionId) {
|
||||||
|
console.log("[refreshSession] Skipping - no session ID");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
setError(null);
|
||||||
|
await refetch();
|
||||||
|
} catch (err) {
|
||||||
|
const error =
|
||||||
|
err instanceof Error ? err : new Error("Failed to refresh session");
|
||||||
|
setError(error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[sessionId, refetch],
|
||||||
|
);
|
||||||
|
|
||||||
|
const claimSession = useCallback(
|
||||||
|
async function claimSession(id: string) {
|
||||||
|
try {
|
||||||
|
setError(null);
|
||||||
|
await claimSessionMutation({ sessionId: id });
|
||||||
|
if (justCreatedSessionIdRef.current === id) {
|
||||||
|
justCreatedSessionIdRef.current = null;
|
||||||
|
}
|
||||||
|
await queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2GetSessionQueryKey(id),
|
||||||
|
});
|
||||||
|
await refetch();
|
||||||
|
toast.success("Session claimed successfully", {
|
||||||
|
description: "Your chat history has been saved to your account",
|
||||||
|
});
|
||||||
|
} catch (err: unknown) {
|
||||||
|
const error =
|
||||||
|
err instanceof Error ? err : new Error("Failed to claim session");
|
||||||
|
const is404 =
|
||||||
|
(typeof err === "object" &&
|
||||||
|
err !== null &&
|
||||||
|
"status" in err &&
|
||||||
|
err.status === 404) ||
|
||||||
|
(typeof err === "object" &&
|
||||||
|
err !== null &&
|
||||||
|
"response" in err &&
|
||||||
|
typeof err.response === "object" &&
|
||||||
|
err.response !== null &&
|
||||||
|
"status" in err.response &&
|
||||||
|
err.response.status === 404);
|
||||||
|
if (!is404) {
|
||||||
|
setError(error);
|
||||||
|
toast.error("Failed to claim session", {
|
||||||
|
description: error.message || "Unable to claim session",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[claimSessionMutation, queryClient, refetch],
|
||||||
|
);
|
||||||
|
|
||||||
|
const clearSession = useCallback(function clearSession() {
|
||||||
|
setSessionId(null);
|
||||||
|
setError(null);
|
||||||
|
storage.clean(Key.CHAT_SESSION_ID);
|
||||||
|
justCreatedSessionIdRef.current = null;
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return {
|
||||||
|
session,
|
||||||
|
sessionId,
|
||||||
|
messages,
|
||||||
|
isLoading,
|
||||||
|
isCreating,
|
||||||
|
error,
|
||||||
|
createSession,
|
||||||
|
loadSession,
|
||||||
|
refreshSession,
|
||||||
|
claimSession,
|
||||||
|
clearSession,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -21,8 +21,6 @@ export interface StreamChunk {
|
|||||||
timestamp?: string;
|
timestamp?: string;
|
||||||
content?: string;
|
content?: string;
|
||||||
message?: string;
|
message?: string;
|
||||||
code?: string;
|
|
||||||
details?: Record<string, unknown>;
|
|
||||||
tool_id?: string;
|
tool_id?: string;
|
||||||
tool_name?: string;
|
tool_name?: string;
|
||||||
arguments?: ToolArguments;
|
arguments?: ToolArguments;
|
||||||
@@ -140,18 +138,8 @@ function normalizeStreamChunk(
|
|||||||
return { type: "stream_end" };
|
return { type: "stream_end" };
|
||||||
case "start":
|
case "start":
|
||||||
case "text-start":
|
case "text-start":
|
||||||
return null;
|
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
const toolInputStart = chunk as Extract<
|
return null;
|
||||||
VercelStreamChunk,
|
|
||||||
{ type: "tool-input-start" }
|
|
||||||
>;
|
|
||||||
return {
|
|
||||||
type: "tool_call_start",
|
|
||||||
tool_id: toolInputStart.toolCallId,
|
|
||||||
tool_name: toolInputStart.toolName,
|
|
||||||
arguments: {},
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,103 +149,22 @@ export function useChatStream() {
|
|||||||
const retryCountRef = useRef<number>(0);
|
const retryCountRef = useRef<number>(0);
|
||||||
const retryTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
const retryTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
||||||
const abortControllerRef = useRef<AbortController | null>(null);
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
const currentSessionIdRef = useRef<string | null>(null);
|
|
||||||
const requestStartTimeRef = useRef<number | null>(null);
|
|
||||||
|
|
||||||
const stopStreaming = useCallback(
|
const stopStreaming = useCallback(() => {
|
||||||
(sessionId?: string, force: boolean = false) => {
|
if (abortControllerRef.current) {
|
||||||
console.log("[useChatStream] stopStreaming called", {
|
abortControllerRef.current.abort();
|
||||||
hasAbortController: !!abortControllerRef.current,
|
abortControllerRef.current = null;
|
||||||
isAborted: abortControllerRef.current?.signal.aborted,
|
}
|
||||||
currentSessionId: currentSessionIdRef.current,
|
if (retryTimeoutRef.current) {
|
||||||
requestedSessionId: sessionId,
|
clearTimeout(retryTimeoutRef.current);
|
||||||
requestStartTime: requestStartTimeRef.current,
|
retryTimeoutRef.current = null;
|
||||||
timeSinceStart: requestStartTimeRef.current
|
}
|
||||||
? Date.now() - requestStartTimeRef.current
|
setIsStreaming(false);
|
||||||
: null,
|
}, []);
|
||||||
force,
|
|
||||||
stack: new Error().stack,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (
|
|
||||||
sessionId &&
|
|
||||||
currentSessionIdRef.current &&
|
|
||||||
currentSessionIdRef.current !== sessionId
|
|
||||||
) {
|
|
||||||
console.log(
|
|
||||||
"[useChatStream] Session changed, aborting previous stream",
|
|
||||||
{
|
|
||||||
oldSessionId: currentSessionIdRef.current,
|
|
||||||
newSessionId: sessionId,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
const controller = abortControllerRef.current;
|
|
||||||
if (controller) {
|
|
||||||
const timeSinceStart = requestStartTimeRef.current
|
|
||||||
? Date.now() - requestStartTimeRef.current
|
|
||||||
: null;
|
|
||||||
|
|
||||||
if (!force && timeSinceStart !== null && timeSinceStart < 100) {
|
|
||||||
console.log(
|
|
||||||
"[useChatStream] Request just started (<100ms), skipping abort to prevent race condition",
|
|
||||||
{
|
|
||||||
timeSinceStart,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const signal = controller.signal;
|
|
||||||
|
|
||||||
if (
|
|
||||||
signal &&
|
|
||||||
typeof signal.aborted === "boolean" &&
|
|
||||||
!signal.aborted
|
|
||||||
) {
|
|
||||||
console.log("[useChatStream] Aborting stream");
|
|
||||||
controller.abort();
|
|
||||||
} else {
|
|
||||||
console.log(
|
|
||||||
"[useChatStream] Stream already aborted or signal invalid",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
if (error instanceof Error && error.name === "AbortError") {
|
|
||||||
console.log(
|
|
||||||
"[useChatStream] AbortError caught (expected during cleanup)",
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
console.warn("[useChatStream] Error aborting stream:", error);
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
abortControllerRef.current = null;
|
|
||||||
requestStartTimeRef.current = null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (retryTimeoutRef.current) {
|
|
||||||
clearTimeout(retryTimeoutRef.current);
|
|
||||||
retryTimeoutRef.current = null;
|
|
||||||
}
|
|
||||||
setIsStreaming(false);
|
|
||||||
},
|
|
||||||
[],
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
console.log("[useChatStream] Component mounted");
|
|
||||||
return () => {
|
return () => {
|
||||||
const sessionIdAtUnmount = currentSessionIdRef.current;
|
stopStreaming();
|
||||||
console.log(
|
|
||||||
"[useChatStream] Component unmounting, calling stopStreaming",
|
|
||||||
{
|
|
||||||
sessionIdAtUnmount,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
stopStreaming(undefined, false);
|
|
||||||
currentSessionIdRef.current = null;
|
|
||||||
};
|
};
|
||||||
}, [stopStreaming]);
|
}, [stopStreaming]);
|
||||||
|
|
||||||
@@ -270,32 +177,12 @@ export function useChatStream() {
|
|||||||
context?: { url: string; content: string },
|
context?: { url: string; content: string },
|
||||||
isRetry: boolean = false,
|
isRetry: boolean = false,
|
||||||
) => {
|
) => {
|
||||||
console.log("[useChatStream] sendMessage called", {
|
stopStreaming();
|
||||||
sessionId,
|
|
||||||
message: message.substring(0, 50),
|
|
||||||
isUserMessage,
|
|
||||||
isRetry,
|
|
||||||
stack: new Error().stack,
|
|
||||||
});
|
|
||||||
|
|
||||||
const previousSessionId = currentSessionIdRef.current;
|
|
||||||
stopStreaming(sessionId, true);
|
|
||||||
currentSessionIdRef.current = sessionId;
|
|
||||||
|
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
abortControllerRef.current = abortController;
|
abortControllerRef.current = abortController;
|
||||||
requestStartTimeRef.current = Date.now();
|
|
||||||
console.log("[useChatStream] Created new AbortController", {
|
|
||||||
sessionId,
|
|
||||||
previousSessionId,
|
|
||||||
requestStartTime: requestStartTimeRef.current,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (abortController.signal.aborted) {
|
if (abortController.signal.aborted) {
|
||||||
console.warn(
|
|
||||||
"[useChatStream] AbortController was aborted before request started",
|
|
||||||
);
|
|
||||||
requestStartTimeRef.current = null;
|
|
||||||
return Promise.reject(new Error("Request aborted"));
|
return Promise.reject(new Error("Request aborted"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,34 +210,18 @@ export function useChatStream() {
|
|||||||
signal: abortController.signal,
|
signal: abortController.signal,
|
||||||
});
|
});
|
||||||
|
|
||||||
console.info("[useChatStream] Stream response", {
|
|
||||||
sessionId,
|
|
||||||
status: response.status,
|
|
||||||
ok: response.ok,
|
|
||||||
contentType: response.headers.get("content-type"),
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
const errorText = await response.text();
|
const errorText = await response.text();
|
||||||
console.warn("[useChatStream] Stream response error", {
|
|
||||||
sessionId,
|
|
||||||
status: response.status,
|
|
||||||
errorText,
|
|
||||||
});
|
|
||||||
throw new Error(errorText || `HTTP ${response.status}`);
|
throw new Error(errorText || `HTTP ${response.status}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!response.body) {
|
if (!response.body) {
|
||||||
console.warn("[useChatStream] Response body is null", { sessionId });
|
|
||||||
throw new Error("Response body is null");
|
throw new Error("Response body is null");
|
||||||
}
|
}
|
||||||
|
|
||||||
const reader = response.body.getReader();
|
const reader = response.body.getReader();
|
||||||
const decoder = new TextDecoder();
|
const decoder = new TextDecoder();
|
||||||
let buffer = "";
|
let buffer = "";
|
||||||
let receivedChunkCount = 0;
|
|
||||||
let firstChunkAt: number | null = null;
|
|
||||||
let loggedLineCount = 0;
|
|
||||||
|
|
||||||
return new Promise<void>((resolve, reject) => {
|
return new Promise<void>((resolve, reject) => {
|
||||||
let didDispatchStreamEnd = false;
|
let didDispatchStreamEnd = false;
|
||||||
@@ -374,13 +245,6 @@ export function useChatStream() {
|
|||||||
|
|
||||||
if (done) {
|
if (done) {
|
||||||
cleanup();
|
cleanup();
|
||||||
console.info("[useChatStream] Stream closed", {
|
|
||||||
sessionId,
|
|
||||||
receivedChunkCount,
|
|
||||||
timeSinceStart: requestStartTimeRef.current
|
|
||||||
? Date.now() - requestStartTimeRef.current
|
|
||||||
: null,
|
|
||||||
});
|
|
||||||
dispatchStreamEnd();
|
dispatchStreamEnd();
|
||||||
retryCountRef.current = 0;
|
retryCountRef.current = 0;
|
||||||
stopStreaming();
|
stopStreaming();
|
||||||
@@ -395,23 +259,8 @@ export function useChatStream() {
|
|||||||
for (const line of lines) {
|
for (const line of lines) {
|
||||||
if (line.startsWith("data: ")) {
|
if (line.startsWith("data: ")) {
|
||||||
const data = line.slice(6);
|
const data = line.slice(6);
|
||||||
if (loggedLineCount < 3) {
|
|
||||||
console.info("[useChatStream] Raw stream line", {
|
|
||||||
sessionId,
|
|
||||||
data:
|
|
||||||
data.length > 300 ? `${data.slice(0, 300)}...` : data,
|
|
||||||
});
|
|
||||||
loggedLineCount += 1;
|
|
||||||
}
|
|
||||||
if (data === "[DONE]") {
|
if (data === "[DONE]") {
|
||||||
cleanup();
|
cleanup();
|
||||||
console.info("[useChatStream] Stream done marker", {
|
|
||||||
sessionId,
|
|
||||||
receivedChunkCount,
|
|
||||||
timeSinceStart: requestStartTimeRef.current
|
|
||||||
? Date.now() - requestStartTimeRef.current
|
|
||||||
: null,
|
|
||||||
});
|
|
||||||
dispatchStreamEnd();
|
dispatchStreamEnd();
|
||||||
retryCountRef.current = 0;
|
retryCountRef.current = 0;
|
||||||
stopStreaming();
|
stopStreaming();
|
||||||
@@ -428,18 +277,6 @@ export function useChatStream() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!firstChunkAt) {
|
|
||||||
firstChunkAt = Date.now();
|
|
||||||
console.info("[useChatStream] First stream chunk", {
|
|
||||||
sessionId,
|
|
||||||
chunkType: chunk.type,
|
|
||||||
timeSinceStart: requestStartTimeRef.current
|
|
||||||
? firstChunkAt - requestStartTimeRef.current
|
|
||||||
: null,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
receivedChunkCount += 1;
|
|
||||||
|
|
||||||
// Call the chunk handler
|
// Call the chunk handler
|
||||||
onChunk(chunk);
|
onChunk(chunk);
|
||||||
|
|
||||||
@@ -447,13 +284,6 @@ export function useChatStream() {
|
|||||||
if (chunk.type === "stream_end") {
|
if (chunk.type === "stream_end") {
|
||||||
didDispatchStreamEnd = true;
|
didDispatchStreamEnd = true;
|
||||||
cleanup();
|
cleanup();
|
||||||
console.info("[useChatStream] Stream end chunk", {
|
|
||||||
sessionId,
|
|
||||||
receivedChunkCount,
|
|
||||||
timeSinceStart: requestStartTimeRef.current
|
|
||||||
? Date.now() - requestStartTimeRef.current
|
|
||||||
: null,
|
|
||||||
});
|
|
||||||
retryCountRef.current = 0;
|
retryCountRef.current = 0;
|
||||||
stopStreaming();
|
stopStreaming();
|
||||||
resolve();
|
resolve();
|
||||||
@@ -477,9 +307,6 @@ export function useChatStream() {
|
|||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof Error && err.name === "AbortError") {
|
if (err instanceof Error && err.name === "AbortError") {
|
||||||
cleanup();
|
cleanup();
|
||||||
dispatchStreamEnd();
|
|
||||||
stopStreaming();
|
|
||||||
resolve();
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -525,10 +352,6 @@ export function useChatStream() {
|
|||||||
readStream();
|
readStream();
|
||||||
});
|
});
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof Error && err.name === "AbortError") {
|
|
||||||
setIsStreaming(false);
|
|
||||||
return Promise.resolve();
|
|
||||||
}
|
|
||||||
const streamError =
|
const streamError =
|
||||||
err instanceof Error ? err : new Error("Failed to start stream");
|
err instanceof Error ? err : new Error("Failed to start stream");
|
||||||
setError(streamError);
|
setError(streamError);
|
||||||
27
autogpt_platform/frontend/src/app/(platform)/chat/page.tsx
Normal file
27
autogpt_platform/frontend/src/app/(platform)/chat/page.tsx
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
|
import { useRouter } from "next/navigation";
|
||||||
|
import { useEffect } from "react";
|
||||||
|
import { Chat } from "./components/Chat/Chat";
|
||||||
|
|
||||||
|
export default function ChatPage() {
|
||||||
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
|
const router = useRouter();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isChatEnabled === false) {
|
||||||
|
router.push("/marketplace");
|
||||||
|
}
|
||||||
|
}, [isChatEnabled, router]);
|
||||||
|
|
||||||
|
if (isChatEnabled === null || isChatEnabled === false) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex h-full flex-col">
|
||||||
|
<Chat className="flex-1" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
|
||||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
|
||||||
import type { ReactNode } from "react";
|
|
||||||
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
|
|
||||||
import { LoadingState } from "./components/LoadingState/LoadingState";
|
|
||||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
|
||||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
|
||||||
import { useCopilotShell } from "./useCopilotShell";
|
|
||||||
|
|
||||||
interface Props {
|
|
||||||
children: ReactNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function CopilotShell({ children }: Props) {
|
|
||||||
const {
|
|
||||||
isMobile,
|
|
||||||
isDrawerOpen,
|
|
||||||
isLoading,
|
|
||||||
isLoggedIn,
|
|
||||||
hasActiveSession,
|
|
||||||
sessions,
|
|
||||||
currentSessionId,
|
|
||||||
handleSelectSession,
|
|
||||||
handleOpenDrawer,
|
|
||||||
handleCloseDrawer,
|
|
||||||
handleDrawerOpenChange,
|
|
||||||
handleNewChat,
|
|
||||||
hasNextPage,
|
|
||||||
isFetchingNextPage,
|
|
||||||
fetchNextPage,
|
|
||||||
isReadyToShowContent,
|
|
||||||
} = useCopilotShell();
|
|
||||||
|
|
||||||
if (!isLoggedIn) {
|
|
||||||
return (
|
|
||||||
<div className="flex h-full items-center justify-center">
|
|
||||||
<LoadingSpinner size="large" />
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div
|
|
||||||
className="flex overflow-hidden bg-[#EFEFF0]"
|
|
||||||
style={{ height: `calc(100vh - ${NAVBAR_HEIGHT_PX}px)` }}
|
|
||||||
>
|
|
||||||
{!isMobile && (
|
|
||||||
<DesktopSidebar
|
|
||||||
sessions={sessions}
|
|
||||||
currentSessionId={currentSessionId}
|
|
||||||
isLoading={isLoading}
|
|
||||||
hasNextPage={hasNextPage}
|
|
||||||
isFetchingNextPage={isFetchingNextPage}
|
|
||||||
onSelectSession={handleSelectSession}
|
|
||||||
onFetchNextPage={fetchNextPage}
|
|
||||||
onNewChat={handleNewChat}
|
|
||||||
hasActiveSession={Boolean(hasActiveSession)}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<div className="relative flex min-h-0 flex-1 flex-col">
|
|
||||||
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
|
||||||
<div className="flex min-h-0 flex-1 flex-col">
|
|
||||||
{isReadyToShowContent ? children : <LoadingState />}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{isMobile && (
|
|
||||||
<MobileDrawer
|
|
||||||
isOpen={isDrawerOpen}
|
|
||||||
sessions={sessions}
|
|
||||||
currentSessionId={currentSessionId}
|
|
||||||
isLoading={isLoading}
|
|
||||||
hasNextPage={hasNextPage}
|
|
||||||
isFetchingNextPage={isFetchingNextPage}
|
|
||||||
onSelectSession={handleSelectSession}
|
|
||||||
onFetchNextPage={fetchNextPage}
|
|
||||||
onNewChat={handleNewChat}
|
|
||||||
onClose={handleCloseDrawer}
|
|
||||||
onOpenChange={handleDrawerOpenChange}
|
|
||||||
hasActiveSession={Boolean(hasActiveSession)}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { Plus } from "@phosphor-icons/react";
|
|
||||||
import { SessionsList } from "../SessionsList/SessionsList";
|
|
||||||
|
|
||||||
interface Props {
|
|
||||||
sessions: SessionSummaryResponse[];
|
|
||||||
currentSessionId: string | null;
|
|
||||||
isLoading: boolean;
|
|
||||||
hasNextPage: boolean;
|
|
||||||
isFetchingNextPage: boolean;
|
|
||||||
onSelectSession: (sessionId: string) => void;
|
|
||||||
onFetchNextPage: () => void;
|
|
||||||
onNewChat: () => void;
|
|
||||||
hasActiveSession: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function DesktopSidebar({
|
|
||||||
sessions,
|
|
||||||
currentSessionId,
|
|
||||||
isLoading,
|
|
||||||
hasNextPage,
|
|
||||||
isFetchingNextPage,
|
|
||||||
onSelectSession,
|
|
||||||
onFetchNextPage,
|
|
||||||
onNewChat,
|
|
||||||
hasActiveSession,
|
|
||||||
}: Props) {
|
|
||||||
return (
|
|
||||||
<aside className="flex h-full w-80 flex-col border-r border-zinc-100 bg-zinc-50">
|
|
||||||
<div className="shrink-0 px-6 py-4">
|
|
||||||
<Text variant="h3" size="body-medium">
|
|
||||||
Your chats
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
<div
|
|
||||||
className={cn(
|
|
||||||
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
|
|
||||||
scrollbarStyles,
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<SessionsList
|
|
||||||
sessions={sessions}
|
|
||||||
currentSessionId={currentSessionId}
|
|
||||||
isLoading={isLoading}
|
|
||||||
hasNextPage={hasNextPage}
|
|
||||||
isFetchingNextPage={isFetchingNextPage}
|
|
||||||
onSelectSession={onSelectSession}
|
|
||||||
onFetchNextPage={onFetchNextPage}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
{hasActiveSession && (
|
|
||||||
<div className="shrink-0 bg-zinc-50 p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
|
||||||
<Button
|
|
||||||
variant="primary"
|
|
||||||
size="small"
|
|
||||||
onClick={onNewChat}
|
|
||||||
className="w-full"
|
|
||||||
leftIcon={<Plus width="1rem" height="1rem" />}
|
|
||||||
>
|
|
||||||
New Chat
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</aside>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
|
||||||
|
|
||||||
export function LoadingState() {
|
|
||||||
return (
|
|
||||||
<div className="flex flex-1 items-center justify-center">
|
|
||||||
<div className="flex flex-col items-center gap-4">
|
|
||||||
<ChatLoader />
|
|
||||||
<Text variant="body" className="text-zinc-500">
|
|
||||||
Loading your chats...
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { PlusIcon, X } from "@phosphor-icons/react";
|
|
||||||
import { Drawer } from "vaul";
|
|
||||||
import { SessionsList } from "../SessionsList/SessionsList";
|
|
||||||
|
|
||||||
interface Props {
|
|
||||||
isOpen: boolean;
|
|
||||||
sessions: SessionSummaryResponse[];
|
|
||||||
currentSessionId: string | null;
|
|
||||||
isLoading: boolean;
|
|
||||||
hasNextPage: boolean;
|
|
||||||
isFetchingNextPage: boolean;
|
|
||||||
onSelectSession: (sessionId: string) => void;
|
|
||||||
onFetchNextPage: () => void;
|
|
||||||
onNewChat: () => void;
|
|
||||||
onClose: () => void;
|
|
||||||
onOpenChange: (open: boolean) => void;
|
|
||||||
hasActiveSession: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function MobileDrawer({
|
|
||||||
isOpen,
|
|
||||||
sessions,
|
|
||||||
currentSessionId,
|
|
||||||
isLoading,
|
|
||||||
hasNextPage,
|
|
||||||
isFetchingNextPage,
|
|
||||||
onSelectSession,
|
|
||||||
onFetchNextPage,
|
|
||||||
onNewChat,
|
|
||||||
onClose,
|
|
||||||
onOpenChange,
|
|
||||||
hasActiveSession,
|
|
||||||
}: Props) {
|
|
||||||
return (
|
|
||||||
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
|
|
||||||
<Drawer.Portal>
|
|
||||||
<Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
|
|
||||||
<Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
|
|
||||||
<div className="shrink-0 border-b border-zinc-200 p-4">
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<Drawer.Title className="text-lg font-semibold text-zinc-800">
|
|
||||||
Your chats
|
|
||||||
</Drawer.Title>
|
|
||||||
<Button
|
|
||||||
variant="icon"
|
|
||||||
size="icon"
|
|
||||||
aria-label="Close sessions"
|
|
||||||
onClick={onClose}
|
|
||||||
>
|
|
||||||
<X width="1.25rem" height="1.25rem" />
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div
|
|
||||||
className={cn(
|
|
||||||
"flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
|
|
||||||
scrollbarStyles,
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<SessionsList
|
|
||||||
sessions={sessions}
|
|
||||||
currentSessionId={currentSessionId}
|
|
||||||
isLoading={isLoading}
|
|
||||||
hasNextPage={hasNextPage}
|
|
||||||
isFetchingNextPage={isFetchingNextPage}
|
|
||||||
onSelectSession={onSelectSession}
|
|
||||||
onFetchNextPage={onFetchNextPage}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
{hasActiveSession && (
|
|
||||||
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
|
||||||
<Button
|
|
||||||
variant="primary"
|
|
||||||
size="small"
|
|
||||||
onClick={onNewChat}
|
|
||||||
className="w-full"
|
|
||||||
leftIcon={<PlusIcon width="1rem" height="1rem" />}
|
|
||||||
>
|
|
||||||
New Chat
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</Drawer.Content>
|
|
||||||
</Drawer.Portal>
|
|
||||||
</Drawer.Root>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
import { useState } from "react";
|
|
||||||
|
|
||||||
export function useMobileDrawer() {
|
|
||||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
|
||||||
|
|
||||||
function handleOpenDrawer() {
|
|
||||||
setIsDrawerOpen(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleCloseDrawer() {
|
|
||||||
setIsDrawerOpen(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleDrawerOpenChange(open: boolean) {
|
|
||||||
setIsDrawerOpen(open);
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
isDrawerOpen,
|
|
||||||
handleOpenDrawer,
|
|
||||||
handleCloseDrawer,
|
|
||||||
handleDrawerOpenChange,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
|
||||||
import { ListIcon } from "@phosphor-icons/react";
|
|
||||||
|
|
||||||
interface Props {
|
|
||||||
onOpenDrawer: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function MobileHeader({ onOpenDrawer }: Props) {
|
|
||||||
return (
|
|
||||||
<Button
|
|
||||||
variant="icon"
|
|
||||||
size="icon"
|
|
||||||
aria-label="Open sessions"
|
|
||||||
onClick={onOpenDrawer}
|
|
||||||
className="fixed z-50 bg-white shadow-md"
|
|
||||||
style={{ left: "1rem", top: `${NAVBAR_HEIGHT_PX + 20}px` }}
|
|
||||||
>
|
|
||||||
<ListIcon width="1.25rem" height="1.25rem" />
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { InfiniteList } from "@/components/molecules/InfiniteList/InfiniteList";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { getSessionTitle } from "../../helpers";
|
|
||||||
|
|
||||||
interface Props {
|
|
||||||
sessions: SessionSummaryResponse[];
|
|
||||||
currentSessionId: string | null;
|
|
||||||
isLoading: boolean;
|
|
||||||
hasNextPage: boolean;
|
|
||||||
isFetchingNextPage: boolean;
|
|
||||||
onSelectSession: (sessionId: string) => void;
|
|
||||||
onFetchNextPage: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function SessionsList({
|
|
||||||
sessions,
|
|
||||||
currentSessionId,
|
|
||||||
isLoading,
|
|
||||||
hasNextPage,
|
|
||||||
isFetchingNextPage,
|
|
||||||
onSelectSession,
|
|
||||||
onFetchNextPage,
|
|
||||||
}: Props) {
|
|
||||||
if (isLoading) {
|
|
||||||
return (
|
|
||||||
<div className="space-y-1">
|
|
||||||
{Array.from({ length: 5 }).map((_, i) => (
|
|
||||||
<div key={i} className="rounded-lg px-3 py-2.5">
|
|
||||||
<Skeleton className="h-5 w-full" />
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sessions.length === 0) {
|
|
||||||
return (
|
|
||||||
<div className="flex h-full items-center justify-center">
|
|
||||||
<Text variant="body" className="text-zinc-500">
|
|
||||||
You don't have previous chats
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<InfiniteList
|
|
||||||
items={sessions}
|
|
||||||
hasMore={hasNextPage}
|
|
||||||
isFetchingMore={isFetchingNextPage}
|
|
||||||
onEndReached={onFetchNextPage}
|
|
||||||
className="space-y-1"
|
|
||||||
renderItem={(session) => {
|
|
||||||
const isActive = session.id === currentSessionId;
|
|
||||||
return (
|
|
||||||
<button
|
|
||||||
onClick={() => onSelectSession(session.id)}
|
|
||||||
className={cn(
|
|
||||||
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
|
||||||
isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<Text
|
|
||||||
variant="body"
|
|
||||||
className={cn(
|
|
||||||
"font-normal",
|
|
||||||
isActive ? "text-zinc-600" : "text-zinc-800",
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
{getSessionTitle(session)}
|
|
||||||
</Text>
|
|
||||||
</button>
|
|
||||||
);
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
|
|
||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { okData } from "@/app/api/helpers";
|
|
||||||
import { useEffect, useMemo, useState } from "react";
|
|
||||||
|
|
||||||
const PAGE_SIZE = 50;
|
|
||||||
|
|
||||||
export interface UseSessionsPaginationArgs {
|
|
||||||
enabled: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
|
||||||
const [offset, setOffset] = useState(0);
|
|
||||||
const [accumulatedSessions, setAccumulatedSessions] = useState<
|
|
||||||
SessionSummaryResponse[]
|
|
||||||
>([]);
|
|
||||||
const [totalCount, setTotalCount] = useState<number | null>(null);
|
|
||||||
|
|
||||||
const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
|
|
||||||
{ limit: PAGE_SIZE, offset },
|
|
||||||
{
|
|
||||||
query: {
|
|
||||||
enabled: enabled && offset >= 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const responseData = okData(data);
|
|
||||||
if (responseData) {
|
|
||||||
const newSessions = responseData.sessions;
|
|
||||||
const total = responseData.total;
|
|
||||||
setTotalCount(total);
|
|
||||||
|
|
||||||
if (offset === 0) {
|
|
||||||
setAccumulatedSessions(newSessions);
|
|
||||||
} else {
|
|
||||||
setAccumulatedSessions((prev) => [...prev, ...newSessions]);
|
|
||||||
}
|
|
||||||
} else if (!enabled) {
|
|
||||||
setAccumulatedSessions([]);
|
|
||||||
setTotalCount(null);
|
|
||||||
}
|
|
||||||
}, [data, offset, enabled]);
|
|
||||||
|
|
||||||
const hasNextPage = useMemo(() => {
|
|
||||||
if (totalCount === null) return false;
|
|
||||||
return accumulatedSessions.length < totalCount;
|
|
||||||
}, [accumulatedSessions.length, totalCount]);
|
|
||||||
|
|
||||||
const areAllSessionsLoaded = useMemo(() => {
|
|
||||||
if (totalCount === null) return false;
|
|
||||||
return (
|
|
||||||
accumulatedSessions.length >= totalCount && !isFetching && !isLoading
|
|
||||||
);
|
|
||||||
}, [accumulatedSessions.length, totalCount, isFetching, isLoading]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (
|
|
||||||
hasNextPage &&
|
|
||||||
!isFetching &&
|
|
||||||
!isLoading &&
|
|
||||||
!isError &&
|
|
||||||
totalCount !== null
|
|
||||||
) {
|
|
||||||
setOffset((prev) => prev + PAGE_SIZE);
|
|
||||||
}
|
|
||||||
}, [hasNextPage, isFetching, isLoading, isError, totalCount]);
|
|
||||||
|
|
||||||
function fetchNextPage() {
|
|
||||||
if (hasNextPage && !isFetching) {
|
|
||||||
setOffset((prev) => prev + PAGE_SIZE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function reset() {
|
|
||||||
setOffset(0);
|
|
||||||
setAccumulatedSessions([]);
|
|
||||||
setTotalCount(null);
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
sessions: accumulatedSessions,
|
|
||||||
isLoading,
|
|
||||||
isFetching,
|
|
||||||
hasNextPage,
|
|
||||||
areAllSessionsLoaded,
|
|
||||||
totalCount,
|
|
||||||
fetchNextPage,
|
|
||||||
reset,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
|
||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { format, formatDistanceToNow, isToday } from "date-fns";
|
|
||||||
|
|
||||||
export function convertSessionDetailToSummary(
|
|
||||||
session: SessionDetailResponse,
|
|
||||||
): SessionSummaryResponse {
|
|
||||||
return {
|
|
||||||
id: session.id,
|
|
||||||
created_at: session.created_at,
|
|
||||||
updated_at: session.updated_at,
|
|
||||||
title: undefined,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
export function filterVisibleSessions(
|
|
||||||
sessions: SessionSummaryResponse[],
|
|
||||||
): SessionSummaryResponse[] {
|
|
||||||
return sessions.filter(
|
|
||||||
(session) => session.updated_at !== session.created_at,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getSessionTitle(session: SessionSummaryResponse): string {
|
|
||||||
if (session.title) return session.title;
|
|
||||||
const isNewSession = session.updated_at === session.created_at;
|
|
||||||
if (isNewSession) {
|
|
||||||
const createdDate = new Date(session.created_at);
|
|
||||||
if (isToday(createdDate)) {
|
|
||||||
return "Today";
|
|
||||||
}
|
|
||||||
return format(createdDate, "MMM d, yyyy");
|
|
||||||
}
|
|
||||||
return "Untitled Chat";
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getSessionUpdatedLabel(
|
|
||||||
session: SessionSummaryResponse,
|
|
||||||
): string {
|
|
||||||
if (!session.updated_at) return "";
|
|
||||||
return formatDistanceToNow(new Date(session.updated_at), { addSuffix: true });
|
|
||||||
}
|
|
||||||
|
|
||||||
export function mergeCurrentSessionIntoList(
|
|
||||||
accumulatedSessions: SessionSummaryResponse[],
|
|
||||||
currentSessionId: string | null,
|
|
||||||
currentSessionData: SessionDetailResponse | null | undefined,
|
|
||||||
): SessionSummaryResponse[] {
|
|
||||||
const filteredSessions: SessionSummaryResponse[] = [];
|
|
||||||
|
|
||||||
if (accumulatedSessions.length > 0) {
|
|
||||||
const visibleSessions = filterVisibleSessions(accumulatedSessions);
|
|
||||||
|
|
||||||
if (currentSessionId) {
|
|
||||||
const currentInAll = accumulatedSessions.find(
|
|
||||||
(s) => s.id === currentSessionId,
|
|
||||||
);
|
|
||||||
if (currentInAll) {
|
|
||||||
const isInVisible = visibleSessions.some(
|
|
||||||
(s) => s.id === currentSessionId,
|
|
||||||
);
|
|
||||||
if (!isInVisible) {
|
|
||||||
filteredSessions.push(currentInAll);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
filteredSessions.push(...visibleSessions);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (currentSessionId && currentSessionData) {
|
|
||||||
const isCurrentInList = filteredSessions.some(
|
|
||||||
(s) => s.id === currentSessionId,
|
|
||||||
);
|
|
||||||
if (!isCurrentInList) {
|
|
||||||
const summarySession = convertSessionDetailToSummary(currentSessionData);
|
|
||||||
filteredSessions.unshift(summarySession);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return filteredSessions;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getCurrentSessionId(
|
|
||||||
searchParams: URLSearchParams,
|
|
||||||
): string | null {
|
|
||||||
return searchParams.get("sessionId");
|
|
||||||
}
|
|
||||||
|
|
||||||
export function shouldAutoSelectSession(
|
|
||||||
areAllSessionsLoaded: boolean,
|
|
||||||
hasAutoSelectedSession: boolean,
|
|
||||||
paramSessionId: string | null,
|
|
||||||
visibleSessions: SessionSummaryResponse[],
|
|
||||||
accumulatedSessions: SessionSummaryResponse[],
|
|
||||||
isLoading: boolean,
|
|
||||||
totalCount: number | null,
|
|
||||||
): {
|
|
||||||
shouldSelect: boolean;
|
|
||||||
sessionIdToSelect: string | null;
|
|
||||||
shouldCreate: boolean;
|
|
||||||
} {
|
|
||||||
if (!areAllSessionsLoaded || hasAutoSelectedSession) {
|
|
||||||
return {
|
|
||||||
shouldSelect: false,
|
|
||||||
sessionIdToSelect: null,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (paramSessionId) {
|
|
||||||
return {
|
|
||||||
shouldSelect: false,
|
|
||||||
sessionIdToSelect: null,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (visibleSessions.length > 0) {
|
|
||||||
return {
|
|
||||||
shouldSelect: true,
|
|
||||||
sessionIdToSelect: visibleSessions[0].id,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
|
|
||||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (totalCount === 0) {
|
|
||||||
return {
|
|
||||||
shouldSelect: false,
|
|
||||||
sessionIdToSelect: null,
|
|
||||||
shouldCreate: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
|
|
||||||
}
|
|
||||||
|
|
||||||
export function checkReadyToShowContent(
|
|
||||||
areAllSessionsLoaded: boolean,
|
|
||||||
paramSessionId: string | null,
|
|
||||||
accumulatedSessions: SessionSummaryResponse[],
|
|
||||||
isCurrentSessionLoading: boolean,
|
|
||||||
currentSessionData: SessionDetailResponse | null | undefined,
|
|
||||||
hasAutoSelectedSession: boolean,
|
|
||||||
): boolean {
|
|
||||||
if (!areAllSessionsLoaded) return false;
|
|
||||||
|
|
||||||
if (paramSessionId) {
|
|
||||||
const sessionFound = accumulatedSessions.some(
|
|
||||||
(s) => s.id === paramSessionId,
|
|
||||||
);
|
|
||||||
return (
|
|
||||||
sessionFound ||
|
|
||||||
(!isCurrentSessionLoading &&
|
|
||||||
currentSessionData !== undefined &&
|
|
||||||
currentSessionData !== null)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return hasAutoSelectedSession;
|
|
||||||
}
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import {
|
|
||||||
getGetV2ListSessionsQueryKey,
|
|
||||||
useGetV2GetSession,
|
|
||||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
|
||||||
import { okData } from "@/app/api/helpers";
|
|
||||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
|
||||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
|
||||||
import { useEffect, useRef, useState } from "react";
|
|
||||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
|
||||||
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
|
||||||
import {
|
|
||||||
checkReadyToShowContent,
|
|
||||||
filterVisibleSessions,
|
|
||||||
getCurrentSessionId,
|
|
||||||
mergeCurrentSessionIntoList,
|
|
||||||
} from "./helpers";
|
|
||||||
|
|
||||||
export function useCopilotShell() {
|
|
||||||
const router = useRouter();
|
|
||||||
const pathname = usePathname();
|
|
||||||
const searchParams = useSearchParams();
|
|
||||||
const queryClient = useQueryClient();
|
|
||||||
const breakpoint = useBreakpoint();
|
|
||||||
const { isLoggedIn } = useSupabase();
|
|
||||||
const isMobile =
|
|
||||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
|
||||||
|
|
||||||
const isOnHomepage = pathname === "/copilot";
|
|
||||||
const paramSessionId = searchParams.get("sessionId");
|
|
||||||
|
|
||||||
const {
|
|
||||||
isDrawerOpen,
|
|
||||||
handleOpenDrawer,
|
|
||||||
handleCloseDrawer,
|
|
||||||
handleDrawerOpenChange,
|
|
||||||
} = useMobileDrawer();
|
|
||||||
|
|
||||||
const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
|
|
||||||
|
|
||||||
const {
|
|
||||||
sessions: accumulatedSessions,
|
|
||||||
isLoading: isSessionsLoading,
|
|
||||||
isFetching: isSessionsFetching,
|
|
||||||
hasNextPage,
|
|
||||||
areAllSessionsLoaded,
|
|
||||||
fetchNextPage,
|
|
||||||
reset: resetPagination,
|
|
||||||
} = useSessionsPagination({
|
|
||||||
enabled: paginationEnabled,
|
|
||||||
});
|
|
||||||
|
|
||||||
const currentSessionId = getCurrentSessionId(searchParams);
|
|
||||||
|
|
||||||
const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
|
|
||||||
useGetV2GetSession(currentSessionId || "", {
|
|
||||||
query: {
|
|
||||||
enabled: !!currentSessionId,
|
|
||||||
select: okData,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
|
|
||||||
const hasAutoSelectedRef = useRef(false);
|
|
||||||
|
|
||||||
// Mark as auto-selected when sessionId is in URL
|
|
||||||
useEffect(() => {
|
|
||||||
if (paramSessionId && !hasAutoSelectedRef.current) {
|
|
||||||
hasAutoSelectedRef.current = true;
|
|
||||||
setHasAutoSelectedSession(true);
|
|
||||||
}
|
|
||||||
}, [paramSessionId]);
|
|
||||||
|
|
||||||
// On homepage without sessionId, mark as ready immediately
|
|
||||||
useEffect(() => {
|
|
||||||
if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
|
|
||||||
hasAutoSelectedRef.current = true;
|
|
||||||
setHasAutoSelectedSession(true);
|
|
||||||
}
|
|
||||||
}, [isOnHomepage, paramSessionId]);
|
|
||||||
|
|
||||||
// Invalidate sessions list when navigating to homepage (to show newly created sessions)
|
|
||||||
useEffect(() => {
|
|
||||||
if (isOnHomepage && !paramSessionId) {
|
|
||||||
queryClient.invalidateQueries({
|
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [isOnHomepage, paramSessionId, queryClient]);
|
|
||||||
|
|
||||||
// Reset pagination when query becomes disabled
|
|
||||||
const prevPaginationEnabledRef = useRef(paginationEnabled);
|
|
||||||
useEffect(() => {
|
|
||||||
if (prevPaginationEnabledRef.current && !paginationEnabled) {
|
|
||||||
resetPagination();
|
|
||||||
resetAutoSelect();
|
|
||||||
}
|
|
||||||
prevPaginationEnabledRef.current = paginationEnabled;
|
|
||||||
}, [paginationEnabled, resetPagination]);
|
|
||||||
|
|
||||||
const sessions = mergeCurrentSessionIntoList(
|
|
||||||
accumulatedSessions,
|
|
||||||
currentSessionId,
|
|
||||||
currentSessionData,
|
|
||||||
);
|
|
||||||
|
|
||||||
const visibleSessions = filterVisibleSessions(sessions);
|
|
||||||
|
|
||||||
const sidebarSelectedSessionId =
|
|
||||||
isOnHomepage && !paramSessionId ? null : currentSessionId;
|
|
||||||
|
|
||||||
const isReadyToShowContent = isOnHomepage
|
|
||||||
? true
|
|
||||||
: checkReadyToShowContent(
|
|
||||||
areAllSessionsLoaded,
|
|
||||||
paramSessionId,
|
|
||||||
accumulatedSessions,
|
|
||||||
isCurrentSessionLoading,
|
|
||||||
currentSessionData,
|
|
||||||
hasAutoSelectedSession,
|
|
||||||
);
|
|
||||||
|
|
||||||
function handleSelectSession(sessionId: string) {
|
|
||||||
// Navigate using replaceState to avoid full page reload
|
|
||||||
window.history.replaceState(null, "", `/copilot?sessionId=${sessionId}`);
|
|
||||||
// Force a re-render by updating the URL through router
|
|
||||||
router.replace(`/copilot?sessionId=${sessionId}`);
|
|
||||||
if (isMobile) handleCloseDrawer();
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleNewChat() {
|
|
||||||
resetAutoSelect();
|
|
||||||
resetPagination();
|
|
||||||
// Invalidate and refetch sessions list to ensure newly created sessions appear
|
|
||||||
queryClient.invalidateQueries({
|
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
|
||||||
});
|
|
||||||
window.history.replaceState(null, "", "/copilot");
|
|
||||||
router.replace("/copilot");
|
|
||||||
if (isMobile) handleCloseDrawer();
|
|
||||||
}
|
|
||||||
|
|
||||||
function resetAutoSelect() {
|
|
||||||
hasAutoSelectedRef.current = false;
|
|
||||||
setHasAutoSelectedSession(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
isMobile,
|
|
||||||
isDrawerOpen,
|
|
||||||
isLoggedIn,
|
|
||||||
hasActiveSession:
|
|
||||||
Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
|
|
||||||
isLoading: isSessionsLoading || !areAllSessionsLoaded,
|
|
||||||
sessions: visibleSessions,
|
|
||||||
currentSessionId: sidebarSelectedSessionId,
|
|
||||||
handleSelectSession,
|
|
||||||
handleOpenDrawer,
|
|
||||||
handleCloseDrawer,
|
|
||||||
handleDrawerOpenChange,
|
|
||||||
handleNewChat,
|
|
||||||
hasNextPage,
|
|
||||||
isFetchingNextPage: isSessionsFetching,
|
|
||||||
fetchNextPage,
|
|
||||||
isReadyToShowContent,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
import type { User } from "@supabase/supabase-js";
|
|
||||||
|
|
||||||
export function getGreetingName(user?: User | null): string {
|
|
||||||
if (!user) return "there";
|
|
||||||
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
|
||||||
const fullName = metadata?.full_name;
|
|
||||||
const name = metadata?.name;
|
|
||||||
if (typeof fullName === "string" && fullName.trim()) {
|
|
||||||
return fullName.split(" ")[0];
|
|
||||||
}
|
|
||||||
if (typeof name === "string" && name.trim()) {
|
|
||||||
return name.split(" ")[0];
|
|
||||||
}
|
|
||||||
if (user.email) {
|
|
||||||
return user.email.split("@")[0];
|
|
||||||
}
|
|
||||||
return "there";
|
|
||||||
}
|
|
||||||
|
|
||||||
export function buildCopilotChatUrl(prompt: string): string {
|
|
||||||
const trimmed = prompt.trim();
|
|
||||||
if (!trimmed) return "/copilot/chat";
|
|
||||||
const encoded = encodeURIComponent(trimmed);
|
|
||||||
return `/copilot/chat?prompt=${encoded}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getQuickActions(): string[] {
|
|
||||||
return [
|
|
||||||
"Show me what I can automate",
|
|
||||||
"Design a custom workflow",
|
|
||||||
"Help me with content creation",
|
|
||||||
];
|
|
||||||
}
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
import type { ReactNode } from "react";
|
|
||||||
import { CopilotShell } from "./components/CopilotShell/CopilotShell";
|
|
||||||
|
|
||||||
export default function CopilotLayout({ children }: { children: ReactNode }) {
|
|
||||||
return <CopilotShell>{children}</CopilotShell>;
|
|
||||||
}
|
|
||||||
@@ -1,228 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { postV2CreateSession } from "@/app/api/__generated__/endpoints/chat/chat";
|
|
||||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
|
||||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
|
||||||
import { getHomepageRoute } from "@/lib/constants";
|
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
|
||||||
import {
|
|
||||||
Flag,
|
|
||||||
type FlagValues,
|
|
||||||
useGetFlag,
|
|
||||||
} from "@/services/feature-flags/use-get-flag";
|
|
||||||
import { useFlags } from "launchdarkly-react-client-sdk";
|
|
||||||
import { useRouter, useSearchParams } from "next/navigation";
|
|
||||||
import { useEffect, useMemo, useRef, useState } from "react";
|
|
||||||
import { getGreetingName, getQuickActions } from "./helpers";
|
|
||||||
|
|
||||||
type PageState =
|
|
||||||
| { type: "welcome" }
|
|
||||||
| { type: "creating"; prompt: string }
|
|
||||||
| { type: "chat"; sessionId: string; initialPrompt?: string };
|
|
||||||
|
|
||||||
export default function CopilotPage() {
|
|
||||||
const router = useRouter();
|
|
||||||
const searchParams = useSearchParams();
|
|
||||||
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
|
||||||
|
|
||||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
|
||||||
const flags = useFlags<FlagValues>();
|
|
||||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
|
||||||
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
|
||||||
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
|
||||||
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
|
||||||
const isFlagReady =
|
|
||||||
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
|
||||||
|
|
||||||
const [pageState, setPageState] = useState<PageState>({ type: "welcome" });
|
|
||||||
const initialPromptRef = useRef<Map<string, string>>(new Map());
|
|
||||||
|
|
||||||
const urlSessionId = searchParams.get("sessionId");
|
|
||||||
|
|
||||||
// Sync with URL sessionId (preserve initialPrompt from ref)
|
|
||||||
useEffect(
|
|
||||||
function syncSessionFromUrl() {
|
|
||||||
if (urlSessionId) {
|
|
||||||
// If we're already in chat state with this sessionId, don't overwrite
|
|
||||||
if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// Get initialPrompt from ref or current state
|
|
||||||
const storedInitialPrompt = initialPromptRef.current.get(urlSessionId);
|
|
||||||
const currentInitialPrompt =
|
|
||||||
storedInitialPrompt ||
|
|
||||||
(pageState.type === "creating"
|
|
||||||
? pageState.prompt
|
|
||||||
: pageState.type === "chat"
|
|
||||||
? pageState.initialPrompt
|
|
||||||
: undefined);
|
|
||||||
if (currentInitialPrompt) {
|
|
||||||
initialPromptRef.current.set(urlSessionId, currentInitialPrompt);
|
|
||||||
}
|
|
||||||
setPageState({
|
|
||||||
type: "chat",
|
|
||||||
sessionId: urlSessionId,
|
|
||||||
initialPrompt: currentInitialPrompt,
|
|
||||||
});
|
|
||||||
} else if (pageState.type === "chat") {
|
|
||||||
setPageState({ type: "welcome" });
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[urlSessionId],
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function ensureAccess() {
|
|
||||||
if (!isFlagReady) return;
|
|
||||||
if (isChatEnabled === false) {
|
|
||||||
router.replace(homepageRoute);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[homepageRoute, isChatEnabled, isFlagReady, router],
|
|
||||||
);
|
|
||||||
|
|
||||||
const greetingName = useMemo(
|
|
||||||
function getName() {
|
|
||||||
return getGreetingName(user);
|
|
||||||
},
|
|
||||||
[user],
|
|
||||||
);
|
|
||||||
|
|
||||||
const quickActions = useMemo(function getActions() {
|
|
||||||
return getQuickActions();
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
async function startChatWithPrompt(prompt: string) {
|
|
||||||
if (!prompt?.trim()) return;
|
|
||||||
if (pageState.type === "creating") return;
|
|
||||||
|
|
||||||
const trimmedPrompt = prompt.trim();
|
|
||||||
setPageState({ type: "creating", prompt: trimmedPrompt });
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Create session
|
|
||||||
const sessionResponse = await postV2CreateSession({
|
|
||||||
body: JSON.stringify({}),
|
|
||||||
});
|
|
||||||
|
|
||||||
if (sessionResponse.status !== 200 || !sessionResponse.data?.id) {
|
|
||||||
throw new Error("Failed to create session");
|
|
||||||
}
|
|
||||||
|
|
||||||
const sessionId = sessionResponse.data.id;
|
|
||||||
|
|
||||||
// Store initialPrompt in ref so it persists across re-renders
|
|
||||||
initialPromptRef.current.set(sessionId, trimmedPrompt);
|
|
||||||
|
|
||||||
// Update URL and show Chat with initial prompt
|
|
||||||
// Chat will handle sending the message and streaming
|
|
||||||
window.history.replaceState(null, "", `/copilot?sessionId=${sessionId}`);
|
|
||||||
setPageState({ type: "chat", sessionId, initialPrompt: trimmedPrompt });
|
|
||||||
} catch (error) {
|
|
||||||
console.error("[CopilotPage] Failed to start chat:", error);
|
|
||||||
setPageState({ type: "welcome" });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleQuickAction(action: string) {
|
|
||||||
startChatWithPrompt(action);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleSessionNotFound() {
|
|
||||||
router.replace("/copilot");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isFlagReady || isChatEnabled === false || !isLoggedIn) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Show Chat when we have an active session
|
|
||||||
if (pageState.type === "chat") {
|
|
||||||
return (
|
|
||||||
<div className="flex h-full flex-col">
|
|
||||||
<Chat
|
|
||||||
key={pageState.sessionId ?? "welcome"}
|
|
||||||
className="flex-1"
|
|
||||||
urlSessionId={pageState.sessionId}
|
|
||||||
initialPrompt={pageState.initialPrompt}
|
|
||||||
onSessionNotFound={handleSessionNotFound}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Show loading state while creating session and sending first message
|
|
||||||
if (pageState.type === "creating") {
|
|
||||||
return (
|
|
||||||
<div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9] px-6 py-10">
|
|
||||||
<LoadingSpinner size="large" />
|
|
||||||
<Text variant="body" className="mt-4 text-zinc-500">
|
|
||||||
Starting your chat...
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Show Welcome screen
|
|
||||||
const isLoading = isUserLoading;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
|
||||||
<div className="w-full text-center">
|
|
||||||
{isLoading ? (
|
|
||||||
<div className="mx-auto max-w-2xl">
|
|
||||||
<Skeleton className="mx-auto mb-3 h-8 w-64" />
|
|
||||||
<Skeleton className="mx-auto mb-8 h-6 w-80" />
|
|
||||||
<div className="mb-8">
|
|
||||||
<Skeleton className="mx-auto h-14 w-full rounded-lg" />
|
|
||||||
</div>
|
|
||||||
<div className="flex flex-wrap items-center justify-center gap-3">
|
|
||||||
{Array.from({ length: 4 }).map((_, i) => (
|
|
||||||
<Skeleton key={i} className="h-9 w-48 rounded-md" />
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<>
|
|
||||||
<div className="mx-auto max-w-2xl">
|
|
||||||
<Text
|
|
||||||
variant="h3"
|
|
||||||
className="mb-3 !text-[1.375rem] text-zinc-700"
|
|
||||||
>
|
|
||||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
|
||||||
</Text>
|
|
||||||
<Text variant="h3" className="mb-8 !font-normal">
|
|
||||||
What do you want to automate?
|
|
||||||
</Text>
|
|
||||||
|
|
||||||
<div className="mb-6">
|
|
||||||
<ChatInput
|
|
||||||
onSend={startChatWithPrompt}
|
|
||||||
placeholder='You can search or just ask - e.g. "create a blog post outline"'
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
|
||||||
{quickActions.map((action) => (
|
|
||||||
<Button
|
|
||||||
key={action}
|
|
||||||
type="button"
|
|
||||||
variant="outline"
|
|
||||||
size="small"
|
|
||||||
onClick={() => handleQuickAction(action)}
|
|
||||||
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
|
|
||||||
>
|
|
||||||
{action}
|
|
||||||
</Button>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||||
import { getHomepageRoute } from "@/lib/constants";
|
|
||||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
|
||||||
import { useSearchParams } from "next/navigation";
|
import { useSearchParams } from "next/navigation";
|
||||||
import { Suspense } from "react";
|
import { Suspense } from "react";
|
||||||
import { getErrorDetails } from "./helpers";
|
import { getErrorDetails } from "./helpers";
|
||||||
@@ -11,8 +9,6 @@ function ErrorPageContent() {
|
|||||||
const searchParams = useSearchParams();
|
const searchParams = useSearchParams();
|
||||||
const errorMessage = searchParams.get("message");
|
const errorMessage = searchParams.get("message");
|
||||||
const errorDetails = getErrorDetails(errorMessage);
|
const errorDetails = getErrorDetails(errorMessage);
|
||||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
|
||||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
|
||||||
|
|
||||||
function handleRetry() {
|
function handleRetry() {
|
||||||
// Auth-related errors should redirect to login
|
// Auth-related errors should redirect to login
|
||||||
@@ -29,8 +25,8 @@ function ErrorPageContent() {
|
|||||||
window.location.reload();
|
window.location.reload();
|
||||||
}, 2000);
|
}, 2000);
|
||||||
} else {
|
} else {
|
||||||
// For server/network errors, go to home
|
// For server/network errors, go to marketplace
|
||||||
window.location.href = homepageRoute;
|
window.location.href = "/marketplace";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ export function RunAgentModal({
|
|||||||
|
|
||||||
{/* Content */}
|
{/* Content */}
|
||||||
{hasAnySetupFields ? (
|
{hasAnySetupFields ? (
|
||||||
<div className="mt-4 pb-10">
|
<div className="mt-10 pb-32">
|
||||||
<RunAgentModalContextProvider
|
<RunAgentModalContextProvider
|
||||||
value={{
|
value={{
|
||||||
agent,
|
agent,
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ export function ModalHeader({ agent }: ModalHeaderProps) {
|
|||||||
<ShowMoreText
|
<ShowMoreText
|
||||||
previewLimit={400}
|
previewLimit={400}
|
||||||
variant="small"
|
variant="small"
|
||||||
className="mb-2 mt-4 !text-zinc-700"
|
className="mt-4 !text-zinc-700"
|
||||||
>
|
>
|
||||||
{agent.description}
|
{agent.description}
|
||||||
</ShowMoreText>
|
</ShowMoreText>
|
||||||
@@ -40,8 +40,6 @@ export function ModalHeader({ agent }: ModalHeaderProps) {
|
|||||||
<Text variant="lead-semibold" className="text-blue-600">
|
<Text variant="lead-semibold" className="text-blue-600">
|
||||||
Tip
|
Tip
|
||||||
</Text>
|
</Text>
|
||||||
<div className="h-px w-full bg-blue-100" />
|
|
||||||
|
|
||||||
<Text variant="body">
|
<Text variant="body">
|
||||||
For best results, run this agent{" "}
|
For best results, run this agent{" "}
|
||||||
{humanizeCronExpression(
|
{humanizeCronExpression(
|
||||||
@@ -52,7 +50,7 @@ export function ModalHeader({ agent }: ModalHeaderProps) {
|
|||||||
) : null}
|
) : null}
|
||||||
|
|
||||||
{agent.instructions ? (
|
{agent.instructions ? (
|
||||||
<div className="mt-4 flex flex-col gap-4 rounded-medium border border-purple-100 bg-[#f1ebfe80] p-4">
|
<div className="flex flex-col gap-4 rounded-medium border border-purple-100 bg-[#F1EBFE/5] p-4">
|
||||||
<Text variant="lead-semibold" className="text-purple-600">
|
<Text variant="lead-semibold" className="text-purple-600">
|
||||||
Instructions
|
Instructions
|
||||||
</Text>
|
</Text>
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/
|
|||||||
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
|
||||||
import { okData } from "@/app/api/helpers";
|
import { okData } from "@/app/api/helpers";
|
||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
import { isLogoutInProgress } from "@/lib/autogpt-server-api/helpers";
|
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
|
||||||
import { updateFavoriteInQueries } from "./helpers";
|
import { updateFavoriteInQueries } from "./helpers";
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
@@ -25,14 +23,10 @@ export function useLibraryAgentCard({ agent }: Props) {
|
|||||||
const { toast } = useToast();
|
const { toast } = useToast();
|
||||||
const queryClient = getQueryClient();
|
const queryClient = getQueryClient();
|
||||||
const { mutateAsync: updateLibraryAgent } = usePatchV2UpdateLibraryAgent();
|
const { mutateAsync: updateLibraryAgent } = usePatchV2UpdateLibraryAgent();
|
||||||
const { user, isLoggedIn } = useSupabase();
|
|
||||||
const logoutInProgress = isLogoutInProgress();
|
|
||||||
|
|
||||||
const { data: profile } = useGetV2GetUserProfile({
|
const { data: profile } = useGetV2GetUserProfile({
|
||||||
query: {
|
query: {
|
||||||
select: okData,
|
select: okData,
|
||||||
enabled: isLoggedIn && !!user && !logoutInProgress,
|
|
||||||
queryKey: ["/api/store/profile", user?.id],
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
import { getHomepageRoute } from "@/lib/constants";
|
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { environment } from "@/services/environment";
|
import { environment } from "@/services/environment";
|
||||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
|
||||||
import { loginFormSchema, LoginProvider } from "@/types/auth";
|
import { loginFormSchema, LoginProvider } from "@/types/auth";
|
||||||
import { zodResolver } from "@hookform/resolvers/zod";
|
import { zodResolver } from "@hookform/resolvers/zod";
|
||||||
import { useRouter, useSearchParams } from "next/navigation";
|
import { useRouter, useSearchParams } from "next/navigation";
|
||||||
@@ -22,17 +20,15 @@ export function useLoginPage() {
|
|||||||
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
||||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||||
const isCloudEnv = environment.isCloud();
|
const isCloudEnv = environment.isCloud();
|
||||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
|
||||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
|
||||||
|
|
||||||
// Get redirect destination from 'next' query parameter
|
// Get redirect destination from 'next' query parameter
|
||||||
const nextUrl = searchParams.get("next");
|
const nextUrl = searchParams.get("next");
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (isLoggedIn && !isLoggingIn) {
|
if (isLoggedIn && !isLoggingIn) {
|
||||||
router.push(nextUrl || homepageRoute);
|
router.push(nextUrl || "/marketplace");
|
||||||
}
|
}
|
||||||
}, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]);
|
}, [isLoggedIn, isLoggingIn, nextUrl, router]);
|
||||||
|
|
||||||
const form = useForm<z.infer<typeof loginFormSchema>>({
|
const form = useForm<z.infer<typeof loginFormSchema>>({
|
||||||
resolver: zodResolver(loginFormSchema),
|
resolver: zodResolver(loginFormSchema),
|
||||||
@@ -102,7 +98,7 @@ export function useLoginPage() {
|
|||||||
} else if (result.onboarding) {
|
} else if (result.onboarding) {
|
||||||
router.replace("/onboarding");
|
router.replace("/onboarding");
|
||||||
} else {
|
} else {
|
||||||
router.replace(homepageRoute);
|
router.replace("/marketplace");
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
toast({
|
toast({
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
import { expect, test } from "vitest";
|
|
||||||
import { render, screen } from "@/tests/integrations/test-utils";
|
|
||||||
import { MainMarkeplacePage } from "../MainMarketplacePage";
|
|
||||||
import { server } from "@/mocks/mock-server";
|
|
||||||
import { getDeleteV2DeleteStoreSubmissionMockHandler422 } from "@/app/api/__generated__/endpoints/store/store.msw";
|
|
||||||
|
|
||||||
// Only for CI testing purpose, will remove it in future PR
|
|
||||||
test("MainMarketplacePage", async () => {
|
|
||||||
server.use(getDeleteV2DeleteStoreSubmissionMockHandler422());
|
|
||||||
|
|
||||||
render(<MainMarkeplacePage />);
|
|
||||||
expect(
|
|
||||||
await screen.findByText("Featured agents", { exact: false }),
|
|
||||||
).toBeDefined();
|
|
||||||
});
|
|
||||||
@@ -3,14 +3,12 @@
|
|||||||
import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/store";
|
import { useGetV2GetUserProfile } from "@/app/api/__generated__/endpoints/store/store";
|
||||||
import { ProfileInfoForm } from "@/components/__legacy__/ProfileInfoForm";
|
import { ProfileInfoForm } from "@/components/__legacy__/ProfileInfoForm";
|
||||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||||
import { isLogoutInProgress } from "@/lib/autogpt-server-api/helpers";
|
|
||||||
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
|
import { ProfileDetails } from "@/lib/autogpt-server-api/types";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { ProfileLoading } from "./ProfileLoading";
|
import { ProfileLoading } from "./ProfileLoading";
|
||||||
|
|
||||||
export default function UserProfilePage() {
|
export default function UserProfilePage() {
|
||||||
const { user } = useSupabase();
|
const { user } = useSupabase();
|
||||||
const logoutInProgress = isLogoutInProgress();
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
data: profile,
|
data: profile,
|
||||||
@@ -20,7 +18,7 @@ export default function UserProfilePage() {
|
|||||||
refetch,
|
refetch,
|
||||||
} = useGetV2GetUserProfile<ProfileDetails | null>({
|
} = useGetV2GetUserProfile<ProfileDetails | null>({
|
||||||
query: {
|
query: {
|
||||||
enabled: !!user && !logoutInProgress,
|
enabled: !!user,
|
||||||
select: (res) => {
|
select: (res) => {
|
||||||
if (res.status === 200) {
|
if (res.status === 200) {
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"use server";
|
"use server";
|
||||||
|
|
||||||
import { getHomepageRoute } from "@/lib/constants";
|
|
||||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||||
import { signupFormSchema } from "@/types/auth";
|
import { signupFormSchema } from "@/types/auth";
|
||||||
import * as Sentry from "@sentry/nextjs";
|
import * as Sentry from "@sentry/nextjs";
|
||||||
@@ -12,7 +11,6 @@ export async function signup(
|
|||||||
password: string,
|
password: string,
|
||||||
confirmPassword: string,
|
confirmPassword: string,
|
||||||
agreeToTerms: boolean,
|
agreeToTerms: boolean,
|
||||||
isChatEnabled: boolean,
|
|
||||||
) {
|
) {
|
||||||
try {
|
try {
|
||||||
const parsed = signupFormSchema.safeParse({
|
const parsed = signupFormSchema.safeParse({
|
||||||
@@ -60,9 +58,7 @@ export async function signup(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const isOnboardingEnabled = await shouldShowOnboarding();
|
const isOnboardingEnabled = await shouldShowOnboarding();
|
||||||
const next = isOnboardingEnabled
|
const next = isOnboardingEnabled ? "/onboarding" : "/";
|
||||||
? "/onboarding"
|
|
||||||
: getHomepageRoute(isChatEnabled);
|
|
||||||
|
|
||||||
return { success: true, next };
|
return { success: true, next };
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
import { getHomepageRoute } from "@/lib/constants";
|
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { environment } from "@/services/environment";
|
import { environment } from "@/services/environment";
|
||||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
|
||||||
import { LoginProvider, signupFormSchema } from "@/types/auth";
|
import { LoginProvider, signupFormSchema } from "@/types/auth";
|
||||||
import { zodResolver } from "@hookform/resolvers/zod";
|
import { zodResolver } from "@hookform/resolvers/zod";
|
||||||
import { useRouter, useSearchParams } from "next/navigation";
|
import { useRouter, useSearchParams } from "next/navigation";
|
||||||
@@ -22,17 +20,15 @@ export function useSignupPage() {
|
|||||||
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
||||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||||
const isCloudEnv = environment.isCloud();
|
const isCloudEnv = environment.isCloud();
|
||||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
|
||||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
|
||||||
|
|
||||||
// Get redirect destination from 'next' query parameter
|
// Get redirect destination from 'next' query parameter
|
||||||
const nextUrl = searchParams.get("next");
|
const nextUrl = searchParams.get("next");
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (isLoggedIn && !isSigningUp) {
|
if (isLoggedIn && !isSigningUp) {
|
||||||
router.push(nextUrl || homepageRoute);
|
router.push(nextUrl || "/marketplace");
|
||||||
}
|
}
|
||||||
}, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]);
|
}, [isLoggedIn, isSigningUp, nextUrl, router]);
|
||||||
|
|
||||||
const form = useForm<z.infer<typeof signupFormSchema>>({
|
const form = useForm<z.infer<typeof signupFormSchema>>({
|
||||||
resolver: zodResolver(signupFormSchema),
|
resolver: zodResolver(signupFormSchema),
|
||||||
@@ -108,7 +104,6 @@ export function useSignupPage() {
|
|||||||
data.password,
|
data.password,
|
||||||
data.confirmPassword,
|
data.confirmPassword,
|
||||||
data.agreeToTerms,
|
data.agreeToTerms,
|
||||||
isChatEnabled === true,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
setIsLoading(false);
|
setIsLoading(false);
|
||||||
@@ -134,7 +129,7 @@ export function useSignupPage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Prefer the URL's next parameter, then result.next (for onboarding), then default
|
// Prefer the URL's next parameter, then result.next (for onboarding), then default
|
||||||
const redirectTo = nextUrl || result.next || homepageRoute;
|
const redirectTo = nextUrl || result.next || "/";
|
||||||
router.replace(redirectTo);
|
router.replace(redirectTo);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
setIsLoading(false);
|
setIsLoading(false);
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import {
|
|||||||
getServerAuthToken,
|
getServerAuthToken,
|
||||||
} from "@/lib/autogpt-server-api/helpers";
|
} from "@/lib/autogpt-server-api/helpers";
|
||||||
|
|
||||||
|
import { transformDates } from "./date-transformer";
|
||||||
|
import { environment } from "@/services/environment";
|
||||||
import {
|
import {
|
||||||
IMPERSONATION_HEADER_NAME,
|
IMPERSONATION_HEADER_NAME,
|
||||||
IMPERSONATION_STORAGE_KEY,
|
IMPERSONATION_STORAGE_KEY,
|
||||||
} from "@/lib/constants";
|
} from "@/lib/constants";
|
||||||
import { environment } from "@/services/environment";
|
|
||||||
import { transformDates } from "./date-transformer";
|
|
||||||
|
|
||||||
const FRONTEND_BASE_URL =
|
const FRONTEND_BASE_URL =
|
||||||
process.env.NEXT_PUBLIC_FRONTEND_BASE_URL || "http://localhost:3000";
|
process.env.NEXT_PUBLIC_FRONTEND_BASE_URL || "http://localhost:3000";
|
||||||
|
|||||||
@@ -1022,7 +1022,7 @@
|
|||||||
"get": {
|
"get": {
|
||||||
"tags": ["v2", "chat", "chat"],
|
"tags": ["v2", "chat", "chat"],
|
||||||
"summary": "Get Session",
|
"summary": "Get Session",
|
||||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
|
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.",
|
||||||
"operationId": "getV2GetSession",
|
"operationId": "getV2GetSession",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"parameters": [
|
"parameters": [
|
||||||
|
|||||||
@@ -141,6 +141,52 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@keyframes shimmer {
|
||||||
|
0% {
|
||||||
|
background-position: -200% 0;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
background-position: 200% 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes l3 {
|
||||||
|
25% {
|
||||||
|
background-position:
|
||||||
|
0 0,
|
||||||
|
100% 100%,
|
||||||
|
100% calc(100% - 5px);
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
background-position:
|
||||||
|
0 100%,
|
||||||
|
100% 100%,
|
||||||
|
0 calc(100% - 5px);
|
||||||
|
}
|
||||||
|
75% {
|
||||||
|
background-position:
|
||||||
|
0 100%,
|
||||||
|
100% 0,
|
||||||
|
100% 5px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.loader {
|
||||||
|
width: 80px;
|
||||||
|
height: 70px;
|
||||||
|
border: 5px solid rgb(241 245 249);
|
||||||
|
padding: 0 8px;
|
||||||
|
box-sizing: border-box;
|
||||||
|
background:
|
||||||
|
linear-gradient(rgb(15 23 42) 0 0) 0 0/8px 20px,
|
||||||
|
linear-gradient(rgb(15 23 42) 0 0) 100% 0/8px 20px,
|
||||||
|
radial-gradient(farthest-side, rgb(15 23 42) 90%, #0000) 0 5px/8px 8px
|
||||||
|
content-box,
|
||||||
|
transparent;
|
||||||
|
background-repeat: no-repeat;
|
||||||
|
animation: l3 2s infinite linear;
|
||||||
|
}
|
||||||
|
|
||||||
input[type="number"]::-webkit-outer-spin-button,
|
input[type="number"]::-webkit-outer-spin-button,
|
||||||
input[type="number"]::-webkit-inner-spin-button {
|
input[type="number"]::-webkit-inner-spin-button {
|
||||||
-webkit-appearance: none;
|
-webkit-appearance: none;
|
||||||
|
|||||||
@@ -1,27 +1,5 @@
|
|||||||
"use client";
|
import { redirect } from "next/navigation";
|
||||||
|
|
||||||
import { getHomepageRoute } from "@/lib/constants";
|
|
||||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
|
||||||
import { useRouter } from "next/navigation";
|
|
||||||
import { useEffect } from "react";
|
|
||||||
|
|
||||||
export default function Page() {
|
export default function Page() {
|
||||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
redirect("/marketplace");
|
||||||
const router = useRouter();
|
|
||||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
|
||||||
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
|
||||||
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
|
||||||
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
|
||||||
const isFlagReady =
|
|
||||||
!isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function redirectToHomepage() {
|
|
||||||
if (!isFlagReady) return;
|
|
||||||
router.replace(homepageRoute);
|
|
||||||
},
|
|
||||||
[homepageRoute, isFlagReady, router],
|
|
||||||
);
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,81 @@
|
|||||||
|
// import { render, screen } from "@testing-library/react";
|
||||||
|
// import { describe, expect, it } from "vitest";
|
||||||
|
// import { Badge } from "./Badge";
|
||||||
|
|
||||||
|
// describe("Badge Component", () => {
|
||||||
|
// it("renders badge with content", () => {
|
||||||
|
// render(<Badge variant="success">Success</Badge>);
|
||||||
|
|
||||||
|
// expect(screen.getByText("Success")).toBeInTheDocument();
|
||||||
|
// });
|
||||||
|
|
||||||
|
// it("applies correct variant styles", () => {
|
||||||
|
// const { rerender } = render(<Badge variant="success">Success</Badge>);
|
||||||
|
// let badge = screen.getByText("Success");
|
||||||
|
// expect(badge).toHaveClass("bg-green-100", "text-green-800");
|
||||||
|
|
||||||
|
// rerender(<Badge variant="error">Error</Badge>);
|
||||||
|
// badge = screen.getByText("Error");
|
||||||
|
// expect(badge).toHaveClass("bg-red-100", "text-red-800");
|
||||||
|
|
||||||
|
// rerender(<Badge variant="info">Info</Badge>);
|
||||||
|
// badge = screen.getByText("Info");
|
||||||
|
// expect(badge).toHaveClass("bg-slate-100", "text-slate-800");
|
||||||
|
// });
|
||||||
|
|
||||||
|
// it("applies custom className", () => {
|
||||||
|
// render(
|
||||||
|
// <Badge variant="success" className="custom-class">
|
||||||
|
// Success
|
||||||
|
// </Badge>,
|
||||||
|
// );
|
||||||
|
|
||||||
|
// const badge = screen.getByText("Success");
|
||||||
|
// expect(badge).toHaveClass("custom-class");
|
||||||
|
// });
|
||||||
|
|
||||||
|
// it("renders as span element", () => {
|
||||||
|
// render(<Badge variant="success">Success</Badge>);
|
||||||
|
|
||||||
|
// const badge = screen.getByText("Success");
|
||||||
|
// expect(badge.tagName).toBe("SPAN");
|
||||||
|
// });
|
||||||
|
|
||||||
|
// it("renders children correctly", () => {
|
||||||
|
// render(
|
||||||
|
// <Badge variant="success">
|
||||||
|
// <span>Custom</span> Content
|
||||||
|
// </Badge>,
|
||||||
|
// );
|
||||||
|
|
||||||
|
// expect(screen.getByText("Custom")).toBeInTheDocument();
|
||||||
|
// expect(screen.getByText("Content")).toBeInTheDocument();
|
||||||
|
// });
|
||||||
|
|
||||||
|
// it("supports all badge variants", () => {
|
||||||
|
// const variants = ["success", "error", "info"] as const;
|
||||||
|
|
||||||
|
// variants.forEach((variant) => {
|
||||||
|
// const { unmount } = render(
|
||||||
|
// <Badge variant={variant} data-testid={`badge-${variant}`}>
|
||||||
|
// {variant}
|
||||||
|
// </Badge>,
|
||||||
|
// );
|
||||||
|
|
||||||
|
// expect(screen.getByTestId(`badge-${variant}`)).toBeInTheDocument();
|
||||||
|
// unmount();
|
||||||
|
// });
|
||||||
|
// });
|
||||||
|
|
||||||
|
// it("handles long text content", () => {
|
||||||
|
// render(
|
||||||
|
// <Badge variant="info">
|
||||||
|
// Very long text that should be handled properly by the component
|
||||||
|
// </Badge>,
|
||||||
|
// );
|
||||||
|
|
||||||
|
// const badge = screen.getByText(/Very long text/);
|
||||||
|
// expect(badge).toBeInTheDocument();
|
||||||
|
// expect(badge).toHaveClass("overflow-hidden", "text-ellipsis");
|
||||||
|
// });
|
||||||
|
// });
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { useEffect, useRef } from "react";
|
|
||||||
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
|
||||||
import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState";
|
|
||||||
import { ChatLoader } from "./components/ChatLoader/ChatLoader";
|
|
||||||
import { useChat } from "./useChat";
|
|
||||||
|
|
||||||
export interface ChatProps {
|
|
||||||
className?: string;
|
|
||||||
urlSessionId?: string | null;
|
|
||||||
initialPrompt?: string;
|
|
||||||
onSessionNotFound?: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function Chat({
|
|
||||||
className,
|
|
||||||
urlSessionId,
|
|
||||||
initialPrompt,
|
|
||||||
onSessionNotFound,
|
|
||||||
}: ChatProps) {
|
|
||||||
const hasHandledNotFoundRef = useRef(false);
|
|
||||||
const {
|
|
||||||
messages,
|
|
||||||
isLoading,
|
|
||||||
isCreating,
|
|
||||||
error,
|
|
||||||
isSessionNotFound,
|
|
||||||
sessionId,
|
|
||||||
createSession,
|
|
||||||
showLoader,
|
|
||||||
} = useChat({ urlSessionId });
|
|
||||||
|
|
||||||
useEffect(
|
|
||||||
function handleMissingSession() {
|
|
||||||
if (!onSessionNotFound) return;
|
|
||||||
if (!urlSessionId) return;
|
|
||||||
if (!isSessionNotFound || isLoading || isCreating) return;
|
|
||||||
if (hasHandledNotFoundRef.current) return;
|
|
||||||
hasHandledNotFoundRef.current = true;
|
|
||||||
onSessionNotFound();
|
|
||||||
},
|
|
||||||
[onSessionNotFound, urlSessionId, isSessionNotFound, isLoading, isCreating],
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className={cn("flex h-full flex-col", className)}>
|
|
||||||
{/* Main Content */}
|
|
||||||
<main className="flex min-h-0 w-full flex-1 flex-col overflow-hidden bg-[#f8f8f9]">
|
|
||||||
{/* Loading State */}
|
|
||||||
{showLoader && (isLoading || isCreating) && (
|
|
||||||
<div className="flex flex-1 items-center justify-center">
|
|
||||||
<div className="flex flex-col items-center gap-4">
|
|
||||||
<ChatLoader />
|
|
||||||
<Text variant="body" className="text-zinc-500">
|
|
||||||
Loading your chats...
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Error State */}
|
|
||||||
{error && !isLoading && (
|
|
||||||
<ChatErrorState error={error} onRetry={createSession} />
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Session Content */}
|
|
||||||
{sessionId && !isLoading && !error && (
|
|
||||||
<ChatContainer
|
|
||||||
sessionId={sessionId}
|
|
||||||
initialMessages={messages}
|
|
||||||
initialPrompt={initialPrompt}
|
|
||||||
className="flex-1"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</main>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { ReactNode } from "react";
|
|
||||||
|
|
||||||
export interface AIChatBubbleProps {
|
|
||||||
children: ReactNode;
|
|
||||||
className?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function AIChatBubble({ children, className }: AIChatBubbleProps) {
|
|
||||||
return (
|
|
||||||
<div className={cn("text-left text-[1rem] leading-relaxed", className)}>
|
|
||||||
{children}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
|
||||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { ChatInput } from "../ChatInput/ChatInput";
|
|
||||||
import { MessageList } from "../MessageList/MessageList";
|
|
||||||
import { useChatContainer } from "./useChatContainer";
|
|
||||||
|
|
||||||
export interface ChatContainerProps {
|
|
||||||
sessionId: string | null;
|
|
||||||
initialMessages: SessionDetailResponse["messages"];
|
|
||||||
initialPrompt?: string;
|
|
||||||
className?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function ChatContainer({
|
|
||||||
sessionId,
|
|
||||||
initialMessages,
|
|
||||||
initialPrompt,
|
|
||||||
className,
|
|
||||||
}: ChatContainerProps) {
|
|
||||||
const {
|
|
||||||
messages,
|
|
||||||
streamingChunks,
|
|
||||||
isStreaming,
|
|
||||||
stopStreaming,
|
|
||||||
isRegionBlockedModalOpen,
|
|
||||||
sendMessageWithContext,
|
|
||||||
handleRegionModalOpenChange,
|
|
||||||
handleRegionModalClose,
|
|
||||||
} = useChatContainer({
|
|
||||||
sessionId,
|
|
||||||
initialMessages,
|
|
||||||
initialPrompt,
|
|
||||||
});
|
|
||||||
|
|
||||||
const breakpoint = useBreakpoint();
|
|
||||||
const isMobile =
|
|
||||||
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div
|
|
||||||
className={cn(
|
|
||||||
"mx-auto flex h-full min-h-0 w-full max-w-3xl flex-col bg-[#f8f8f9]",
|
|
||||||
className,
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<Dialog
|
|
||||||
title="Service unavailable"
|
|
||||||
controlled={{
|
|
||||||
isOpen: isRegionBlockedModalOpen,
|
|
||||||
set: handleRegionModalOpenChange,
|
|
||||||
}}
|
|
||||||
onClose={handleRegionModalClose}
|
|
||||||
>
|
|
||||||
<Dialog.Content>
|
|
||||||
<div className="flex flex-col gap-4">
|
|
||||||
<Text variant="body">
|
|
||||||
This model is not available in your region. Please connect via VPN
|
|
||||||
and try again.
|
|
||||||
</Text>
|
|
||||||
<div className="flex justify-end">
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
variant="primary"
|
|
||||||
onClick={handleRegionModalClose}
|
|
||||||
>
|
|
||||||
Got it
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</Dialog.Content>
|
|
||||||
</Dialog>
|
|
||||||
{/* Messages - Scrollable */}
|
|
||||||
<div className="relative flex min-h-0 flex-1 flex-col">
|
|
||||||
<div className="flex min-h-full flex-col justify-end">
|
|
||||||
<MessageList
|
|
||||||
messages={messages}
|
|
||||||
streamingChunks={streamingChunks}
|
|
||||||
isStreaming={isStreaming}
|
|
||||||
onSendMessage={sendMessageWithContext}
|
|
||||||
className="flex-1"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Input - Fixed at bottom */}
|
|
||||||
<div className="relative px-3 pb-6 pt-2">
|
|
||||||
<div className="pointer-events-none absolute top-[-18px] z-10 h-6 w-full bg-gradient-to-b from-transparent to-[#f8f8f9]" />
|
|
||||||
<ChatInput
|
|
||||||
onSend={sendMessageWithContext}
|
|
||||||
disabled={isStreaming || !sessionId}
|
|
||||||
isStreaming={isStreaming}
|
|
||||||
onStop={stopStreaming}
|
|
||||||
placeholder={
|
|
||||||
isMobile
|
|
||||||
? "You can search or just ask"
|
|
||||||
: 'You can search or just ask — e.g. "create a blog post outline"'
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { ArrowUpIcon, StopIcon } from "@phosphor-icons/react";
|
|
||||||
import { useChatInput } from "./useChatInput";
|
|
||||||
|
|
||||||
export interface Props {
|
|
||||||
onSend: (message: string) => void;
|
|
||||||
disabled?: boolean;
|
|
||||||
isStreaming?: boolean;
|
|
||||||
onStop?: () => void;
|
|
||||||
placeholder?: string;
|
|
||||||
className?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function ChatInput({
|
|
||||||
onSend,
|
|
||||||
disabled = false,
|
|
||||||
isStreaming = false,
|
|
||||||
onStop,
|
|
||||||
placeholder = "Type your message...",
|
|
||||||
className,
|
|
||||||
}: Props) {
|
|
||||||
const inputId = "chat-input";
|
|
||||||
const { value, setValue, handleKeyDown, handleSend, hasMultipleLines } =
|
|
||||||
useChatInput({
|
|
||||||
onSend,
|
|
||||||
disabled: disabled || isStreaming,
|
|
||||||
maxRows: 4,
|
|
||||||
inputId,
|
|
||||||
});
|
|
||||||
|
|
||||||
function handleSubmit(e: React.FormEvent<HTMLFormElement>) {
|
|
||||||
e.preventDefault();
|
|
||||||
handleSend();
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleChange(e: React.ChangeEvent<HTMLTextAreaElement>) {
|
|
||||||
setValue(e.target.value);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
|
|
||||||
<div className="relative">
|
|
||||||
<div
|
|
||||||
id={`${inputId}-wrapper`}
|
|
||||||
className={cn(
|
|
||||||
"relative overflow-hidden border border-neutral-200 bg-white shadow-sm",
|
|
||||||
"focus-within:border-zinc-400 focus-within:ring-1 focus-within:ring-zinc-400",
|
|
||||||
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<textarea
|
|
||||||
id={inputId}
|
|
||||||
aria-label="Chat message input"
|
|
||||||
value={value}
|
|
||||||
onChange={handleChange}
|
|
||||||
onKeyDown={handleKeyDown}
|
|
||||||
placeholder={placeholder}
|
|
||||||
disabled={disabled || isStreaming}
|
|
||||||
rows={1}
|
|
||||||
className={cn(
|
|
||||||
"w-full resize-none overflow-y-auto border-0 bg-transparent text-[1rem] leading-6 text-black",
|
|
||||||
"placeholder:text-zinc-400",
|
|
||||||
"focus:outline-none focus:ring-0",
|
|
||||||
"disabled:text-zinc-500",
|
|
||||||
hasMultipleLines ? "pb-6 pl-4 pr-4 pt-2" : "pb-4 pl-4 pr-14 pt-4",
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<span id="chat-input-hint" className="sr-only">
|
|
||||||
Press Enter to send, Shift+Enter for new line
|
|
||||||
</span>
|
|
||||||
|
|
||||||
{isStreaming ? (
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
variant="icon"
|
|
||||||
size="icon"
|
|
||||||
aria-label="Stop generating"
|
|
||||||
onClick={onStop}
|
|
||||||
className="absolute bottom-[7px] right-2 border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
|
|
||||||
>
|
|
||||||
<StopIcon className="h-4 w-4" weight="bold" />
|
|
||||||
</Button>
|
|
||||||
) : (
|
|
||||||
<Button
|
|
||||||
type="submit"
|
|
||||||
variant="icon"
|
|
||||||
size="icon"
|
|
||||||
aria-label="Send message"
|
|
||||||
className={cn(
|
|
||||||
"absolute bottom-[7px] right-2 border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
|
|
||||||
(disabled || !value.trim()) && "opacity-20",
|
|
||||||
)}
|
|
||||||
disabled={disabled || !value.trim()}
|
|
||||||
>
|
|
||||||
<ArrowUpIcon className="h-4 w-4" weight="bold" />
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
import { KeyboardEvent, useCallback, useEffect, useState } from "react";
|
|
||||||
|
|
||||||
interface UseChatInputArgs {
|
|
||||||
onSend: (message: string) => void;
|
|
||||||
disabled?: boolean;
|
|
||||||
maxRows?: number;
|
|
||||||
inputId?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useChatInput({
|
|
||||||
onSend,
|
|
||||||
disabled = false,
|
|
||||||
maxRows = 5,
|
|
||||||
inputId = "chat-input",
|
|
||||||
}: UseChatInputArgs) {
|
|
||||||
const [value, setValue] = useState("");
|
|
||||||
const [hasMultipleLines, setHasMultipleLines] = useState(false);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
|
||||||
const wrapper = document.getElementById(
|
|
||||||
`${inputId}-wrapper`,
|
|
||||||
) as HTMLDivElement;
|
|
||||||
if (!textarea || !wrapper) return;
|
|
||||||
|
|
||||||
const isEmpty = !value.trim();
|
|
||||||
const lines = value.split("\n").length;
|
|
||||||
const hasExplicitNewlines = lines > 1;
|
|
||||||
|
|
||||||
const computedStyle = window.getComputedStyle(textarea);
|
|
||||||
const lineHeight = parseInt(computedStyle.lineHeight, 10);
|
|
||||||
const paddingTop = parseInt(computedStyle.paddingTop, 10);
|
|
||||||
const paddingBottom = parseInt(computedStyle.paddingBottom, 10);
|
|
||||||
|
|
||||||
const singleLinePadding = paddingTop + paddingBottom;
|
|
||||||
|
|
||||||
textarea.style.height = "auto";
|
|
||||||
const scrollHeight = textarea.scrollHeight;
|
|
||||||
|
|
||||||
const singleLineHeight = lineHeight + singleLinePadding;
|
|
||||||
const isMultiLine =
|
|
||||||
hasExplicitNewlines || scrollHeight > singleLineHeight + 2;
|
|
||||||
setHasMultipleLines(isMultiLine);
|
|
||||||
|
|
||||||
if (isEmpty) {
|
|
||||||
wrapper.style.height = `${singleLineHeight}px`;
|
|
||||||
wrapper.style.maxHeight = "";
|
|
||||||
textarea.style.height = `${singleLineHeight}px`;
|
|
||||||
textarea.style.maxHeight = "";
|
|
||||||
textarea.style.overflowY = "hidden";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isMultiLine) {
|
|
||||||
const wrapperMaxHeight = 196;
|
|
||||||
const currentMultilinePadding = paddingTop + paddingBottom;
|
|
||||||
const contentMaxHeight = wrapperMaxHeight - currentMultilinePadding;
|
|
||||||
const minMultiLineHeight = lineHeight * 2 + currentMultilinePadding;
|
|
||||||
const contentHeight = scrollHeight;
|
|
||||||
const targetWrapperHeight = Math.min(
|
|
||||||
Math.max(contentHeight + currentMultilinePadding, minMultiLineHeight),
|
|
||||||
wrapperMaxHeight,
|
|
||||||
);
|
|
||||||
|
|
||||||
wrapper.style.height = `${targetWrapperHeight}px`;
|
|
||||||
wrapper.style.maxHeight = `${wrapperMaxHeight}px`;
|
|
||||||
textarea.style.height = `${contentHeight}px`;
|
|
||||||
textarea.style.maxHeight = `${contentMaxHeight}px`;
|
|
||||||
textarea.style.overflowY =
|
|
||||||
contentHeight > contentMaxHeight ? "auto" : "hidden";
|
|
||||||
} else {
|
|
||||||
wrapper.style.height = `${singleLineHeight}px`;
|
|
||||||
wrapper.style.maxHeight = "";
|
|
||||||
textarea.style.height = `${singleLineHeight}px`;
|
|
||||||
textarea.style.maxHeight = "";
|
|
||||||
textarea.style.overflowY = "hidden";
|
|
||||||
}
|
|
||||||
}, [value, maxRows, inputId]);
|
|
||||||
|
|
||||||
const handleSend = useCallback(() => {
|
|
||||||
if (disabled || !value.trim()) return;
|
|
||||||
onSend(value.trim());
|
|
||||||
setValue("");
|
|
||||||
setHasMultipleLines(false);
|
|
||||||
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
|
||||||
const wrapper = document.getElementById(
|
|
||||||
`${inputId}-wrapper`,
|
|
||||||
) as HTMLDivElement;
|
|
||||||
if (textarea) {
|
|
||||||
textarea.style.height = "auto";
|
|
||||||
}
|
|
||||||
if (wrapper) {
|
|
||||||
wrapper.style.height = "";
|
|
||||||
wrapper.style.maxHeight = "";
|
|
||||||
}
|
|
||||||
}, [value, onSend, disabled, inputId]);
|
|
||||||
|
|
||||||
const handleKeyDown = useCallback(
|
|
||||||
(event: KeyboardEvent<HTMLTextAreaElement>) => {
|
|
||||||
if (event.key === "Enter" && !event.shiftKey) {
|
|
||||||
event.preventDefault();
|
|
||||||
handleSend();
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[handleSend],
|
|
||||||
);
|
|
||||||
|
|
||||||
return {
|
|
||||||
value,
|
|
||||||
setValue,
|
|
||||||
handleKeyDown,
|
|
||||||
handleSend,
|
|
||||||
hasMultipleLines,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
|
|
||||||
export function ChatLoader() {
|
|
||||||
return (
|
|
||||||
<Text
|
|
||||||
variant="small"
|
|
||||||
className="bg-gradient-to-r from-neutral-600 via-neutral-500 to-neutral-600 bg-[length:200%_100%] bg-clip-text text-xs text-transparent [animation:shimmer_2s_ease-in-out_infinite]"
|
|
||||||
>
|
|
||||||
Taking a bit more time...
|
|
||||||
</Text>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
|
||||||
import { StreamingMessage } from "../StreamingMessage/StreamingMessage";
|
|
||||||
import { ThinkingMessage } from "../ThinkingMessage/ThinkingMessage";
|
|
||||||
import { LastToolResponse } from "./components/LastToolResponse/LastToolResponse";
|
|
||||||
import { MessageItem } from "./components/MessageItem/MessageItem";
|
|
||||||
import { findLastMessageIndex, shouldSkipAgentOutput } from "./helpers";
|
|
||||||
import { useMessageList } from "./useMessageList";
|
|
||||||
|
|
||||||
export interface MessageListProps {
|
|
||||||
messages: ChatMessageData[];
|
|
||||||
streamingChunks?: string[];
|
|
||||||
isStreaming?: boolean;
|
|
||||||
className?: string;
|
|
||||||
onStreamComplete?: () => void;
|
|
||||||
onSendMessage?: (content: string) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function MessageList({
|
|
||||||
messages,
|
|
||||||
streamingChunks = [],
|
|
||||||
isStreaming = false,
|
|
||||||
className,
|
|
||||||
onStreamComplete,
|
|
||||||
onSendMessage,
|
|
||||||
}: MessageListProps) {
|
|
||||||
const { messagesEndRef, messagesContainerRef } = useMessageList({
|
|
||||||
messageCount: messages.length,
|
|
||||||
isStreaming,
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Keeps this for debugging purposes 💆🏽
|
|
||||||
*/
|
|
||||||
console.log(messages);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="relative flex min-h-0 flex-1 flex-col">
|
|
||||||
{/* Top fade shadow */}
|
|
||||||
<div className="pointer-events-none absolute top-0 z-10 h-8 w-full bg-gradient-to-b from-[#f8f8f9] to-transparent" />
|
|
||||||
|
|
||||||
<div
|
|
||||||
ref={messagesContainerRef}
|
|
||||||
className={cn(
|
|
||||||
"flex-1 overflow-y-auto overflow-x-hidden",
|
|
||||||
"scrollbar-thin scrollbar-track-transparent scrollbar-thumb-zinc-300",
|
|
||||||
className,
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<div className="mx-auto flex min-w-0 flex-col hyphens-auto break-words py-4">
|
|
||||||
{/* Render all persisted messages */}
|
|
||||||
{(() => {
|
|
||||||
const lastAssistantMessageIndex = findLastMessageIndex(
|
|
||||||
messages,
|
|
||||||
(msg) => msg.type === "message" && msg.role === "assistant",
|
|
||||||
);
|
|
||||||
|
|
||||||
const lastToolResponseIndex = findLastMessageIndex(
|
|
||||||
messages,
|
|
||||||
(msg) => msg.type === "tool_response",
|
|
||||||
);
|
|
||||||
|
|
||||||
return messages.map((message, index) => {
|
|
||||||
// Skip agent_output tool_responses that should be rendered inside assistant messages
|
|
||||||
if (shouldSkipAgentOutput(message, messages[index - 1])) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Render last tool_response as AIChatBubble
|
|
||||||
if (
|
|
||||||
message.type === "tool_response" &&
|
|
||||||
index === lastToolResponseIndex
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<LastToolResponse
|
|
||||||
key={index}
|
|
||||||
message={message}
|
|
||||||
prevMessage={messages[index - 1]}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<MessageItem
|
|
||||||
key={index}
|
|
||||||
message={message}
|
|
||||||
messages={messages}
|
|
||||||
index={index}
|
|
||||||
lastAssistantMessageIndex={lastAssistantMessageIndex}
|
|
||||||
isStreaming={isStreaming}
|
|
||||||
onSendMessage={onSendMessage}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
})()}
|
|
||||||
|
|
||||||
{/* Render thinking message when streaming but no chunks yet */}
|
|
||||||
{isStreaming && streamingChunks.length === 0 && <ThinkingMessage />}
|
|
||||||
|
|
||||||
{/* Render streaming message if active */}
|
|
||||||
{isStreaming && streamingChunks.length > 0 && (
|
|
||||||
<StreamingMessage
|
|
||||||
chunks={streamingChunks}
|
|
||||||
onComplete={onStreamComplete}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Invisible div to scroll to */}
|
|
||||||
<div ref={messagesEndRef} />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Bottom fade shadow */}
|
|
||||||
<div className="pointer-events-none absolute bottom-0 z-10 h-8 w-full bg-gradient-to-t from-[#f8f8f9] to-transparent" />
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user