diff --git a/.branchlet.json b/.branchlet.json index cc13ff9f74..d02cd60e20 100644 --- a/.branchlet.json +++ b/.branchlet.json @@ -29,8 +29,7 @@ "postCreateCmd": [ "cd autogpt_platform/autogpt_libs && poetry install", "cd autogpt_platform/backend && poetry install && poetry run prisma generate", - "cd autogpt_platform/frontend && pnpm install", - "cd docs && pip install -r requirements.txt" + "cd autogpt_platform/frontend && pnpm install" ], "terminalCommand": "code .", "deleteBranchWithWorktree": false diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 870e6b4b0a..3c72eaae18 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -160,7 +160,7 @@ pnpm storybook # Start component development server **Backend Entry Points:** -- `backend/backend/server/server.py` - FastAPI application setup +- `backend/backend/api/rest_api.py` - FastAPI application setup - `backend/backend/data/` - Database models and user management - `backend/blocks/` - Agent execution blocks and logic @@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s ### API Development -1. Update routes in `/backend/backend/server/routers/` +1. Update routes in `/backend/backend/api/features/` 2. Add/update Pydantic models in same directory 3. Write tests alongside route files 4. For `data/*.py` changes, validate user ID checks @@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s ### Security Guidelines -**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`): +**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`): - Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private` - Uses allow list approach for cacheable paths (static assets, health checks, public pages) diff --git a/.gitignore b/.gitignore index dfce8ba810..1a2291b516 100644 --- a/.gitignore +++ b/.gitignore @@ -178,4 +178,5 @@ autogpt_platform/backend/settings.py *.ign.* .test-contents .claude/settings.local.json +CLAUDE.local.md /autogpt_platform/backend/logs diff --git a/AGENTS.md b/AGENTS.md index cd176f8a2d..202c4c6e02 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -16,7 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions. - Format Python code with `poetry run format`. - Format frontend code using `pnpm format`. - ## Frontend guidelines: See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: @@ -33,14 +32,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: 4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only 5. **Testing**: Add Storybook stories for new components, Playwright for E2E 6. **Code conventions**: Function declarations (not arrow functions) for components/handlers + - Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component - Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts) - Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible - Avoid large hooks, abstract logic into `helpers.ts` files when sensible - Use function declarations for components, arrow functions only for callbacks - No barrel files or `index.ts` re-exports -- Do not use `useCallback` or `useMemo` unless strictly needed - Avoid comments at all times unless the code is very complex +- Do not use `useCallback` or `useMemo` unless asked to optimise a given function +- Do not type hook returns, let Typescript infer as much as possible +- Never type with `any`, if not types available use `unknown` ## Testing @@ -49,22 +51,8 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: Always run the relevant linters and tests before committing. Use conventional commit messages for all commits (e.g. `feat(backend): add API`). - Types: - - feat - - fix - - refactor - - ci - - dx (developer experience) - Scopes: - - platform - - platform/library - - platform/marketplace - - backend - - backend/executor - - frontend - - frontend/library - - frontend/marketplace - - blocks +Types: - feat - fix - refactor - ci - dx (developer experience) +Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks ## Pull requests diff --git a/autogpt_platform/CLAUDE.md b/autogpt_platform/CLAUDE.md index 2c76e7db80..62adbdaefa 100644 --- a/autogpt_platform/CLAUDE.md +++ b/autogpt_platform/CLAUDE.md @@ -6,152 +6,30 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co AutoGPT Platform is a monorepo containing: -- **Backend** (`/backend`): Python FastAPI server with async support -- **Frontend** (`/frontend`): Next.js React application -- **Shared Libraries** (`/autogpt_libs`): Common Python utilities +- **Backend** (`backend`): Python FastAPI server with async support +- **Frontend** (`frontend`): Next.js React application +- **Shared Libraries** (`autogpt_libs`): Common Python utilities -## Essential Commands +## Component Documentation -### Backend Development +- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks +- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns -```bash -# Install dependencies -cd backend && poetry install - -# Run database migrations -poetry run prisma migrate dev - -# Start all services (database, redis, rabbitmq, clamav) -docker compose up -d - -# Run the backend server -poetry run serve - -# Run tests -poetry run test - -# Run specific test -poetry run pytest path/to/test_file.py::test_function_name - -# Run block tests (tests that validate all blocks work correctly) -poetry run pytest backend/blocks/test/test_block.py -xvs - -# Run tests for a specific block (e.g., GetCurrentTimeBlock) -poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs - -# Lint and format -# prefer format if you want to just "fix" it and only get the errors that can't be autofixed -poetry run format # Black + isort -poetry run lint # ruff -``` - -More details can be found in TESTING.md - -#### Creating/Updating Snapshots - -When you first write a test or when the expected output changes: - -```bash -poetry run pytest path/to/test.py --snapshot-update -``` - -⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected. - -### Frontend Development - -```bash -# Install dependencies -cd frontend && pnpm i - -# Generate API client from OpenAPI spec -pnpm generate:api - -# Start development server -pnpm dev - -# Run E2E tests -pnpm test - -# Run Storybook for component development -pnpm storybook - -# Build production -pnpm build - -# Format and lint -pnpm format - -# Type checking -pnpm types -``` - -**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns. - -**Key Frontend Conventions:** - -- Separate render logic from data/behavior in components -- Use generated API hooks from `@/app/api/__generated__/endpoints/` -- Use function declarations (not arrow functions) for components/handlers -- Use design system components from `src/components/` (atoms, molecules, organisms) -- Only use Phosphor Icons -- Never use `src/components/__legacy__/*` or deprecated `BackendAPI` - -## Architecture Overview - -### Backend Architecture - -- **API Layer**: FastAPI with REST and WebSocket endpoints -- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings -- **Queue System**: RabbitMQ for async task processing -- **Execution Engine**: Separate executor service processes agent workflows -- **Authentication**: JWT-based with Supabase integration -- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies - -### Frontend Architecture - -- **Framework**: Next.js 15 App Router (client-first approach) -- **Data Fetching**: Type-safe generated API hooks via Orval + React Query -- **State Management**: React Query for server state, co-located UI state in components/hooks -- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks) -- **Workflow Builder**: Visual graph editor using @xyflow/react -- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling -- **Icons**: Phosphor Icons only -- **Feature Flags**: LaunchDarkly integration -- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions -- **Testing**: Playwright for E2E, Storybook for component development - -### Key Concepts +## Key Concepts 1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend -2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks +2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks 3. **Integrations**: OAuth and API connections stored per user 4. **Store**: Marketplace for sharing agent templates 5. **Virus Scanning**: ClamAV integration for file upload security -### Testing Approach - -- Backend uses pytest with snapshot testing for API responses -- Test files are colocated with source files (`*_test.py`) -- Frontend uses Playwright for E2E tests -- Component testing via Storybook - -### Database Schema - -Key models (defined in `/backend/schema.prisma`): - -- `User`: Authentication and profile data -- `AgentGraph`: Workflow definitions with version control -- `AgentGraphExecution`: Execution history and results -- `AgentNode`: Individual nodes in a workflow -- `StoreListing`: Marketplace listings for sharing agents - ### Environment Configuration #### Configuration Files -- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides) -- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides) -- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides) +- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides) +- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides) +- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides) #### Docker Environment Loading Order @@ -167,83 +45,12 @@ Key models (defined in `/backend/schema.prisma`): - Backend/Frontend services use YAML anchors for consistent configuration - Supabase services (`db/docker/docker-compose.yml`) follow the same pattern -### Common Development Tasks - -**Adding a new block:** - -Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers: - -- Provider configuration with `ProviderBuilder` -- Block schema definition -- Authentication (API keys, OAuth, webhooks) -- Testing and validation -- File organization - -Quick steps: - -1. Create new file in `/backend/backend/blocks/` -2. Configure provider using `ProviderBuilder` in `_config.py` -3. Inherit from `Block` base class -4. Define input/output schemas using `BlockSchema` -5. Implement async `run` method -6. Generate unique block ID using `uuid.uuid4()` -7. Test with `poetry run pytest backend/blocks/test/test_block.py` - -Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively? -ex: do the inputs and outputs tie well together? - -If you get any pushback or hit complex block conditions check the new_blocks guide in the docs. - -**Modifying the API:** - -1. Update route in `/backend/backend/server/routers/` -2. Add/update Pydantic models in same directory -3. Write tests alongside the route file -4. Run `poetry run test` to verify - -### Frontend guidelines: - -See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: - -1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx` - - Add `usePageName.ts` hook for logic - - Put sub-components in local `components/` folder -2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts` - - Use design system components from `src/components/` (atoms, molecules, organisms) - - Never use `src/components/__legacy__/*` -3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/` - - Regenerate with `pnpm generate:api` - - Pattern: `use{Method}{Version}{OperationName}` -4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only -5. **Testing**: Add Storybook stories for new components, Playwright for E2E -6. **Code conventions**: Function declarations (not arrow functions) for components/handlers -- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component -- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts) -- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible -- Avoid large hooks, abstract logic into `helpers.ts` files when sensible -- Use function declarations for components, arrow functions only for callbacks -- No barrel files or `index.ts` re-exports -- Do not use `useCallback` or `useMemo` unless strictly needed -- Avoid comments at all times unless the code is very complex - -### Security Implementation - -**Cache Protection Middleware:** - -- Located in `/backend/backend/server/middleware/security.py` -- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private` -- Uses an allow list approach - only explicitly permitted paths can be cached -- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation -- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies -- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware -- Applied to both main API server and external API applications - ### Creating Pull Requests -- Create the PR aginst the `dev` branch of the repository. -- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/ -- Use conventional commit messages (see below)/ -- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/ +- Create the PR against the `dev` branch of the repository. +- Ensure the branch name is descriptive (e.g., `feature/add-new-block`) +- Use conventional commit messages (see below) +- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description - Run the github pre-commit hooks to ensure code quality. ### Reviewing/Revising Pull Requests diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index 3a98cdbbc7..fa52ba812a 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -179,5 +179,10 @@ AYRSHARE_JWT_KEY= SMARTLEAD_API_KEY= ZEROBOUNCE_API_KEY= +# PostHog Analytics +# Get API key from https://posthog.com - Project Settings > Project API Key +POSTHOG_API_KEY= +POSTHOG_HOST=https://eu.i.posthog.com + # Other Services AUTOMOD_API_KEY= diff --git a/autogpt_platform/backend/CLAUDE.md b/autogpt_platform/backend/CLAUDE.md new file mode 100644 index 0000000000..53d52bb4d3 --- /dev/null +++ b/autogpt_platform/backend/CLAUDE.md @@ -0,0 +1,170 @@ +# CLAUDE.md - Backend + +This file provides guidance to Claude Code when working with the backend. + +## Essential Commands + +To run something with Python package dependencies you MUST use `poetry run ...`. + +```bash +# Install dependencies +poetry install + +# Run database migrations +poetry run prisma migrate dev + +# Start all services (database, redis, rabbitmq, clamav) +docker compose up -d + +# Run the backend as a whole +poetry run app + +# Run tests +poetry run test + +# Run specific test +poetry run pytest path/to/test_file.py::test_function_name + +# Run block tests (tests that validate all blocks work correctly) +poetry run pytest backend/blocks/test/test_block.py -xvs + +# Run tests for a specific block (e.g., GetCurrentTimeBlock) +poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs + +# Lint and format +# prefer format if you want to just "fix" it and only get the errors that can't be autofixed +poetry run format # Black + isort +poetry run lint # ruff +``` + +More details can be found in @TESTING.md + +### Creating/Updating Snapshots + +When you first write a test or when the expected output changes: + +```bash +poetry run pytest path/to/test.py --snapshot-update +``` + +⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected. + +## Architecture + +- **API Layer**: FastAPI with REST and WebSocket endpoints +- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings +- **Queue System**: RabbitMQ for async task processing +- **Execution Engine**: Separate executor service processes agent workflows +- **Authentication**: JWT-based with Supabase integration +- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies + +## Testing Approach + +- Uses pytest with snapshot testing for API responses +- Test files are colocated with source files (`*_test.py`) + +## Database Schema + +Key models (defined in `schema.prisma`): + +- `User`: Authentication and profile data +- `AgentGraph`: Workflow definitions with version control +- `AgentGraphExecution`: Execution history and results +- `AgentNode`: Individual nodes in a workflow +- `StoreListing`: Marketplace listings for sharing agents + +## Environment Configuration + +- **Backend**: `.env.default` (defaults) → `.env` (user overrides) + +## Common Development Tasks + +### Adding a new block + +Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers: + +- Provider configuration with `ProviderBuilder` +- Block schema definition +- Authentication (API keys, OAuth, webhooks) +- Testing and validation +- File organization + +Quick steps: + +1. Create new file in `backend/blocks/` +2. Configure provider using `ProviderBuilder` in `_config.py` +3. Inherit from `Block` base class +4. Define input/output schemas using `BlockSchema` +5. Implement async `run` method +6. Generate unique block ID using `uuid.uuid4()` +7. Test with `poetry run pytest backend/blocks/test/test_block.py` + +Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively? +ex: do the inputs and outputs tie well together? + +If you get any pushback or hit complex block conditions check the new_blocks guide in the docs. + +#### Handling files in blocks with `store_media_file()` + +When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back: + +| Format | Use When | Returns | +|--------|----------|---------| +| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) | +| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) | +| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs | + +**Examples:** + +```python +# INPUT: Need to process file locally with ffmpeg +local_path = await store_media_file( + file=input_data.video, + execution_context=execution_context, + return_format="for_local_processing", +) +# local_path = "video.mp4" - use with Path/ffmpeg/etc + +# INPUT: Need to send to external API like Replicate +image_b64 = await store_media_file( + file=input_data.image, + execution_context=execution_context, + return_format="for_external_api", +) +# image_b64 = "data:image/png;base64,iVBORw0..." - send to API + +# OUTPUT: Returning result from block +result_url = await store_media_file( + file=generated_image_url, + execution_context=execution_context, + return_format="for_block_output", +) +yield "image_url", result_url +# In CoPilot: result_url = "workspace://abc123" +# In graphs: result_url = "data:image/png;base64,..." +``` + +**Key points:** + +- `for_block_output` is the ONLY format that auto-adapts to execution context +- Always use `for_block_output` for block outputs unless you have a specific reason not to +- Never hardcode workspace checks - let `for_block_output` handle it + +### Modifying the API + +1. Update route in `backend/api/features/` +2. Add/update Pydantic models in same directory +3. Write tests alongside the route file +4. Run `poetry run test` to verify + +## Security Implementation + +### Cache Protection Middleware + +- Located in `backend/api/middleware/security.py` +- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private` +- Uses an allow list approach - only explicitly permitted paths can be cached +- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation +- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies +- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware +- Applied to both main API server and external API applications diff --git a/autogpt_platform/backend/TESTING.md b/autogpt_platform/backend/TESTING.md index a3a5db68ef..2e09144485 100644 --- a/autogpt_platform/backend/TESTING.md +++ b/autogpt_platform/backend/TESTING.md @@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as #### Using Global Auth Fixtures -Two global auth fixtures are provided by `backend/server/conftest.py`: +Two global auth fixtures are provided by `backend/api/conftest.py`: - `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id") - `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id") diff --git a/autogpt_platform/backend/backend/api/external/v1/routes.py b/autogpt_platform/backend/backend/api/external/v1/routes.py index 58e15dc6a3..00933c1899 100644 --- a/autogpt_platform/backend/backend/api/external/v1/routes.py +++ b/autogpt_platform/backend/backend/api/external/v1/routes.py @@ -86,6 +86,8 @@ async def execute_graph_block( obj = backend.data.block.get_block(block_id) if not obj: raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.") + if obj.disabled: + raise HTTPException(status_code=403, detail=f"Block #{block_id} is disabled.") output = defaultdict(list) async for name, data in obj.execute(data): diff --git a/autogpt_platform/backend/backend/api/features/builder/routes.py b/autogpt_platform/backend/backend/api/features/builder/routes.py index 7fe9cab189..15b922178d 100644 --- a/autogpt_platform/backend/backend/api/features/builder/routes.py +++ b/autogpt_platform/backend/backend/api/features/builder/routes.py @@ -17,7 +17,7 @@ router = fastapi.APIRouter( ) -# Taken from backend/server/v2/store/db.py +# Taken from backend/api/features/store/db.py def sanitize_query(query: str | None) -> str | None: if query is None: return query diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py index 95aef7f2ed..dba7934877 100644 --- a/autogpt_platform/backend/backend/api/features/chat/config.py +++ b/autogpt_platform/backend/backend/api/features/chat/config.py @@ -33,9 +33,15 @@ class ChatConfig(BaseSettings): stream_timeout: int = Field(default=300, description="Stream timeout in seconds") max_retries: int = Field(default=3, description="Maximum number of retries") - max_agent_runs: int = Field(default=3, description="Maximum number of agent runs") + max_agent_runs: int = Field(default=30, description="Maximum number of agent runs") max_agent_schedules: int = Field( - default=3, description="Maximum number of agent schedules" + default=30, description="Maximum number of agent schedules" + ) + + # Long-running operation configuration + long_running_operation_ttl: int = Field( + default=600, + description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)", ) # Langfuse Prompt Management Configuration diff --git a/autogpt_platform/backend/backend/api/features/chat/db.py b/autogpt_platform/backend/backend/api/features/chat/db.py index 05a3553cc8..d34b4e5b07 100644 --- a/autogpt_platform/backend/backend/api/features/chat/db.py +++ b/autogpt_platform/backend/backend/api/features/chat/db.py @@ -247,3 +247,45 @@ async def get_chat_session_message_count(session_id: str) -> int: """Get the number of messages in a chat session.""" count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id}) return count + + +async def update_tool_message_content( + session_id: str, + tool_call_id: str, + new_content: str, +) -> bool: + """Update the content of a tool message in chat history. + + Used by background tasks to update pending operation messages with final results. + + Args: + session_id: The chat session ID. + tool_call_id: The tool call ID to find the message. + new_content: The new content to set. + + Returns: + True if a message was updated, False otherwise. + """ + try: + result = await PrismaChatMessage.prisma().update_many( + where={ + "sessionId": session_id, + "toolCallId": tool_call_id, + }, + data={ + "content": new_content, + }, + ) + if result == 0: + logger.warning( + f"No message found to update for session {session_id}, " + f"tool_call_id {tool_call_id}" + ) + return False + return True + except Exception as e: + logger.error( + f"Failed to update tool message for session {session_id}, " + f"tool_call_id {tool_call_id}: {e}" + ) + return False diff --git a/autogpt_platform/backend/backend/api/features/chat/model.py b/autogpt_platform/backend/backend/api/features/chat/model.py index 75bda11127..7318ef88d7 100644 --- a/autogpt_platform/backend/backend/api/features/chat/model.py +++ b/autogpt_platform/backend/backend/api/features/chat/model.py @@ -295,6 +295,21 @@ async def cache_chat_session(session: ChatSession) -> None: await _cache_session(session) +async def invalidate_session_cache(session_id: str) -> None: + """Invalidate a chat session from Redis cache. + + Used by background tasks to ensure fresh data is loaded on next access. + This is best-effort - Redis failures are logged but don't fail the operation. + """ + try: + redis_key = _get_session_cache_key(session_id) + async_redis = await get_redis_async() + await async_redis.delete(redis_key) + except Exception as e: + # Best-effort: log but don't fail - cache will expire naturally + logger.warning(f"Failed to invalidate session cache for {session_id}: {e}") + + async def _get_session_from_db(session_id: str) -> ChatSession | None: """Get a chat session from the database.""" prisma_session = await chat_db.get_chat_session(session_id) diff --git a/autogpt_platform/backend/backend/api/features/chat/response_model.py b/autogpt_platform/backend/backend/api/features/chat/response_model.py index 49a9b38e8f..53a8cf3a1f 100644 --- a/autogpt_platform/backend/backend/api/features/chat/response_model.py +++ b/autogpt_platform/backend/backend/api/features/chat/response_model.py @@ -31,6 +31,7 @@ class ResponseType(str, Enum): # Other ERROR = "error" USAGE = "usage" + HEARTBEAT = "heartbeat" class StreamBaseResponse(BaseModel): @@ -142,3 +143,20 @@ class StreamError(StreamBaseResponse): details: dict[str, Any] | None = Field( default=None, description="Additional error details" ) + + +class StreamHeartbeat(StreamBaseResponse): + """Heartbeat to keep SSE connection alive during long-running operations. + + Uses SSE comment format (: comment) which is ignored by clients but keeps + the connection alive through proxies and load balancers. + """ + + type: ResponseType = ResponseType.HEARTBEAT + toolCallId: str | None = Field( + default=None, description="Tool call ID if heartbeat is for a specific tool" + ) + + def to_sse(self) -> str: + """Convert to SSE comment format to keep connection alive.""" + return ": heartbeat\n\n" diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 3daf378f65..20216162b5 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -5,9 +5,9 @@ from asyncio import CancelledError from collections.abc import AsyncGenerator from typing import Any +import openai import orjson -from langfuse import get_client, propagate_attributes -from langfuse.openai import openai # type: ignore +from langfuse import get_client from openai import ( APIConnectionError, APIError, @@ -17,6 +17,7 @@ from openai import ( ) from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam +from backend.data.redis_client import get_redis_async from backend.data.understanding import ( format_understanding_for_prompt, get_business_understanding, @@ -24,6 +25,7 @@ from backend.data.understanding import ( from backend.util.exceptions import NotFoundError from backend.util.settings import Settings +from . import db as chat_db from .config import ChatConfig from .model import ( ChatMessage, @@ -31,6 +33,7 @@ from .model import ( Usage, cache_chat_session, get_chat_session, + invalidate_session_cache, update_session_title, upsert_chat_session, ) @@ -38,6 +41,7 @@ from .response_model import ( StreamBaseResponse, StreamError, StreamFinish, + StreamHeartbeat, StreamStart, StreamTextDelta, StreamTextEnd, @@ -47,7 +51,14 @@ from .response_model import ( StreamToolOutputAvailable, StreamUsage, ) -from .tools import execute_tool, tools +from .tools import execute_tool, get_tool, tools +from .tools.models import ( + ErrorResponse, + OperationInProgressResponse, + OperationPendingResponse, + OperationStartedResponse, +) +from .tracking import track_user_message logger = logging.getLogger(__name__) @@ -58,11 +69,126 @@ client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) langfuse = get_client() +# Redis key prefix for tracking running long-running operations +# Used for idempotency across Kubernetes pods - prevents duplicate executions on browser refresh +RUNNING_OPERATION_PREFIX = "chat:running_operation:" -class LangfuseNotConfiguredError(Exception): - """Raised when Langfuse is required but not configured.""" +# Default system prompt used when Langfuse is not configured +# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11) +DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations. - pass +Here is everything you know about the current user from previous interactions: + + +{users_information} + + +## YOUR CORE MANDATE + +You are action-oriented. Your success is measured by: +- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"? +- **Demonstrable Proof**: Show working automations, not descriptions of what's possible +- **Time Saved**: Focus on tangible efficiency gains +- **Quality Output**: Deliver results that meet or exceed expectations + +## YOUR WORKFLOW + +Adapt flexibly to the conversation context. Not every interaction requires all stages: + +1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations. + +2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task. + +3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.). + +4. **Discover or Create Agents**: + - **Always check the user's library first** with `find_library_agent` (these may be customized to their needs) + - Search the marketplace with `find_agent` for pre-built automations + - Find reusable components with `find_block` + - Create custom solutions with `create_agent` if nothing suitable exists + - Modify existing library agents with `edit_agent` + +5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`. + +6. **Show Results**: Display outputs using `agent_output`. + +## AVAILABLE TOOLS + +**Understanding & Discovery:** +- `add_understanding`: Create a memory about the user's business or use cases for future sessions +- `search_docs`: Search platform documentation for specific technical information +- `get_doc_page`: Retrieve full text of a specific documentation page + +**Agent Discovery:** +- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized) +- `find_agent`: Search the marketplace for pre-built automations +- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks) + +**Agent Creation & Editing:** +- `create_agent`: Create a new automation agent +- `edit_agent`: Modify an agent in the user's library + +**Execution & Output:** +- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger +- `run_block`: Test or run a specific block independently +- `agent_output`: View results from previous agent runs + +## BEHAVIORAL GUIDELINES + +**Be Concise:** +- Target 2-5 short lines maximum +- Make every word count—no repetition or filler +- Use lightweight structure for scannability (bullets, numbered lists, short prompts) +- Avoid jargon (blocks, slugs, cron) unless the user asks + +**Be Proactive:** +- Suggest next steps before being asked +- Anticipate needs based on conversation context and user information +- Look for opportunities to expand scope when relevant +- Reveal capabilities through action, not explanation + +**Use Tools Effectively:** +- Select the right tool for each task +- **Always check `find_library_agent` before searching the marketplace** +- Use `add_understanding` to capture valuable business context +- When tool calls fail, try alternative approaches + +## CRITICAL REMINDER + +You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation.""" + +# Module-level set to hold strong references to background tasks. +# This prevents asyncio from garbage collecting tasks before they complete. +# Tasks are automatically removed on completion via done_callback. +_background_tasks: set[asyncio.Task] = set() + + +async def _mark_operation_started(tool_call_id: str) -> bool: + """Mark a long-running operation as started (Redis-based). + + Returns True if successfully marked (operation was not already running), + False if operation was already running (lost race condition). + Raises exception if Redis is unavailable (fail-closed). + """ + redis = await get_redis_async() + key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}" + # SETNX with TTL - atomic "set if not exists" + result = await redis.set(key, "1", ex=config.long_running_operation_ttl, nx=True) + return result is not None + + +async def _mark_operation_completed(tool_call_id: str) -> None: + """Mark a long-running operation as completed (remove Redis key). + + This is best-effort - if Redis fails, the TTL will eventually clean up. + """ + try: + redis = await get_redis_async() + key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}" + await redis.delete(key) + except Exception as e: + # Non-critical: TTL will clean up eventually + logger.warning(f"Failed to delete running operation key {tool_call_id}: {e}") def _is_langfuse_configured() -> bool: @@ -72,6 +198,30 @@ def _is_langfuse_configured() -> bool: ) +async def _get_system_prompt_template(context: str) -> str: + """Get the system prompt, trying Langfuse first with fallback to default. + + Args: + context: The user context/information to compile into the prompt. + + Returns: + The compiled system prompt string. + """ + if _is_langfuse_configured(): + try: + # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt + # Use asyncio.to_thread to avoid blocking the event loop + prompt = await asyncio.to_thread( + langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0 + ) + return prompt.compile(users_information=context) + except Exception as e: + logger.warning(f"Failed to fetch prompt from Langfuse, using default: {e}") + + # Fallback to default prompt + return DEFAULT_SYSTEM_PROMPT.format(users_information=context) + + async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: """Build the full system prompt including business understanding if available. @@ -80,12 +230,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: If "default" and this is the user's first session, will use "onboarding" instead. Returns: - Tuple of (compiled prompt string, Langfuse prompt object for tracing) + Tuple of (compiled prompt string, business understanding object) """ - - # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt - prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0) - # If user is authenticated, try to fetch their business understanding understanding = None if user_id: @@ -94,25 +240,43 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: except Exception as e: logger.warning(f"Failed to fetch business understanding: {e}") understanding = None + if understanding: context = format_understanding_for_prompt(understanding) else: context = "This is the first time you are meeting the user. Greet them and introduce them to the platform" - compiled = prompt.compile(users_information=context) + compiled = await _get_system_prompt_template(context) return compiled, understanding -async def _generate_session_title(message: str) -> str | None: +async def _generate_session_title( + message: str, + user_id: str | None = None, + session_id: str | None = None, +) -> str | None: """Generate a concise title for a chat session based on the first message. Args: message: The first user message in the session + user_id: User ID for OpenRouter tracing (optional) + session_id: Session ID for OpenRouter tracing (optional) Returns: A short title (3-6 words) or None if generation fails """ try: + # Build extra_body for OpenRouter tracing and PostHog analytics + extra_body: dict[str, Any] = {} + if user_id: + extra_body["user"] = user_id[:128] # OpenRouter limit + extra_body["posthogDistinctId"] = user_id + if session_id: + extra_body["session_id"] = session_id[:128] # OpenRouter limit + extra_body["posthogProperties"] = { + "environment": settings.config.app_env.value, + } + response = await client.chat.completions.create( model=config.title_model, messages=[ @@ -127,6 +291,7 @@ async def _generate_session_title(message: str) -> str | None: {"role": "user", "content": message[:500]}, # Limit input length ], max_tokens=20, + extra_body=extra_body, ) title = response.choices[0].message.content if title: @@ -189,16 +354,6 @@ async def stream_chat_completion( f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}" ) - # Check if Langfuse is configured - required for chat functionality - if not _is_langfuse_configured(): - logger.error("Chat request failed: Langfuse is not configured") - yield StreamError( - errorText="Chat service is not available. Langfuse must be configured " - "with LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables." - ) - yield StreamFinish() - return - # Only fetch from Redis if session not provided (initial call) if session is None: session = await get_chat_session(session_id, user_id) @@ -218,18 +373,9 @@ async def stream_chat_completion( ) if message: - # Build message content with context if provided - message_content = message - if context and context.get("url") and context.get("content"): - context_text = f"Page URL: {context['url']}\n\nPage Content:\n{context['content']}\n\n---\n\nUser Message: {message}" - message_content = context_text - logger.info( - f"Including page context: URL={context['url']}, content_length={len(context['content'])}" - ) - session.messages.append( ChatMessage( - role="user" if is_user_message else "assistant", content=message_content + role="user" if is_user_message else "assistant", content=message ) ) logger.info( @@ -237,6 +383,14 @@ async def stream_chat_completion( f"new message_count={len(session.messages)}" ) + # Track user message in PostHog + if is_user_message: + track_user_message( + user_id=user_id, + session_id=session_id, + message_length=len(message), + ) + logger.info( f"Upserting session: {session.session_id} with user id {session.user_id}, " f"message_count={len(session.messages)}" @@ -256,10 +410,15 @@ async def stream_chat_completion( # stale data issues when the main flow modifies the session captured_session_id = session_id captured_message = message + captured_user_id = user_id async def _update_title(): try: - title = await _generate_session_title(captured_message) + title = await _generate_session_title( + captured_message, + user_id=captured_user_id, + session_id=captured_session_id, + ) if title: # Use dedicated title update function that doesn't # touch messages, avoiding race conditions @@ -276,347 +435,332 @@ async def stream_chat_completion( # Build system prompt with business understanding system_prompt, understanding = await _build_system_prompt(user_id) - # Create Langfuse trace for this LLM call (each call gets its own trace, grouped by session_id) - # Using v3 SDK: start_observation creates a root span, update_trace sets trace-level attributes - input = message - if not message and tool_call_response: - input = tool_call_response + # Initialize variables for streaming + assistant_response = ChatMessage( + role="assistant", + content="", + ) + accumulated_tool_calls: list[dict[str, Any]] = [] + has_saved_assistant_message = False + has_appended_streaming_message = False + last_cache_time = 0.0 + last_cache_content_len = 0 - langfuse = get_client() - with langfuse.start_as_current_observation( - as_type="span", - name="user-copilot-request", - input=input, - ) as span: - with propagate_attributes( - session_id=session_id, - user_id=user_id, - tags=["copilot"], - metadata={ - "users_information": format_understanding_for_prompt(understanding)[ - :200 - ] # langfuse only accepts upto to 200 chars - }, + has_yielded_end = False + has_yielded_error = False + has_done_tool_call = False + has_long_running_tool_call = False # Track if we had a long-running tool call + has_received_text = False + text_streaming_ended = False + tool_response_messages: list[ChatMessage] = [] + should_retry = False + + # Generate unique IDs for AI SDK protocol + import uuid as uuid_module + + message_id = str(uuid_module.uuid4()) + text_block_id = str(uuid_module.uuid4()) + + # Yield message start + yield StreamStart(messageId=message_id) + + try: + async for chunk in _stream_chat_chunks( + session=session, + tools=tools, + system_prompt=system_prompt, + text_block_id=text_block_id, ): - - # Initialize variables that will be used in finally block (must be defined before try) - assistant_response = ChatMessage( - role="assistant", - content="", - ) - accumulated_tool_calls: list[dict[str, Any]] = [] - has_saved_assistant_message = False - has_appended_streaming_message = False - last_cache_time = 0.0 - last_cache_content_len = 0 - - # Wrap main logic in try/finally to ensure Langfuse observations are always ended - has_yielded_end = False - has_yielded_error = False - has_done_tool_call = False - has_received_text = False - text_streaming_ended = False - tool_response_messages: list[ChatMessage] = [] - should_retry = False - - # Generate unique IDs for AI SDK protocol - import uuid as uuid_module - - message_id = str(uuid_module.uuid4()) - text_block_id = str(uuid_module.uuid4()) - - # Yield message start - yield StreamStart(messageId=message_id) - - try: - async for chunk in _stream_chat_chunks( - session=session, - tools=tools, - system_prompt=system_prompt, - text_block_id=text_block_id, + if isinstance(chunk, StreamTextStart): + # Emit text-start before first text delta + if not has_received_text: + yield chunk + elif isinstance(chunk, StreamTextDelta): + delta = chunk.delta or "" + assert assistant_response.content is not None + assistant_response.content += delta + has_received_text = True + if not has_appended_streaming_message: + session.messages.append(assistant_response) + has_appended_streaming_message = True + current_time = time.monotonic() + content_len = len(assistant_response.content) + if ( + current_time - last_cache_time >= 1.0 + and content_len > last_cache_content_len ): - - if isinstance(chunk, StreamTextStart): - # Emit text-start before first text delta - if not has_received_text: - yield chunk - elif isinstance(chunk, StreamTextDelta): - delta = chunk.delta or "" - assert assistant_response.content is not None - assistant_response.content += delta - has_received_text = True - if not has_appended_streaming_message: - session.messages.append(assistant_response) - has_appended_streaming_message = True - current_time = time.monotonic() - content_len = len(assistant_response.content) - if ( - current_time - last_cache_time >= 1.0 - and content_len > last_cache_content_len - ): - try: - await cache_chat_session(session) - except Exception as e: - logger.warning( - f"Failed to cache partial session {session.session_id}: {e}" - ) - last_cache_time = current_time - last_cache_content_len = content_len - yield chunk - elif isinstance(chunk, StreamTextEnd): - # Emit text-end after text completes - if has_received_text and not text_streaming_ended: - text_streaming_ended = True - if assistant_response.content: - logger.warn( - f"StreamTextEnd: Attempting to set output {assistant_response.content}" - ) - span.update_trace(output=assistant_response.content) - span.update(output=assistant_response.content) - yield chunk - elif isinstance(chunk, StreamToolInputStart): - # Emit text-end before first tool call, but only if we've received text - if has_received_text and not text_streaming_ended: - yield StreamTextEnd(id=text_block_id) - text_streaming_ended = True - yield chunk - elif isinstance(chunk, StreamToolInputAvailable): - # Accumulate tool calls in OpenAI format - accumulated_tool_calls.append( - { - "id": chunk.toolCallId, - "type": "function", - "function": { - "name": chunk.toolName, - "arguments": orjson.dumps(chunk.input).decode( - "utf-8" - ), - }, - } - ) - elif isinstance(chunk, StreamToolOutputAvailable): - result_content = ( - chunk.output - if isinstance(chunk.output, str) - else orjson.dumps(chunk.output).decode("utf-8") - ) - tool_response_messages.append( - ChatMessage( - role="tool", - content=result_content, - tool_call_id=chunk.toolCallId, - ) - ) - has_done_tool_call = True - # Track if any tool execution failed - if not chunk.success: - logger.warning( - f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed" - ) - yield chunk - elif isinstance(chunk, StreamFinish): - if not has_done_tool_call: - # Emit text-end before finish if we received text but haven't closed it - if has_received_text and not text_streaming_ended: - yield StreamTextEnd(id=text_block_id) - text_streaming_ended = True - - # Save assistant message before yielding finish to ensure it's persisted - # even if client disconnects immediately after receiving StreamFinish - if not has_saved_assistant_message: - messages_to_save_early: list[ChatMessage] = [] - if accumulated_tool_calls: - assistant_response.tool_calls = ( - accumulated_tool_calls - ) - if not has_appended_streaming_message and ( - assistant_response.content - or assistant_response.tool_calls - ): - messages_to_save_early.append(assistant_response) - messages_to_save_early.extend(tool_response_messages) - - if messages_to_save_early: - session.messages.extend(messages_to_save_early) - logger.info( - f"Saving assistant message before StreamFinish: " - f"content_len={len(assistant_response.content or '')}, " - f"tool_calls={len(assistant_response.tool_calls or [])}, " - f"tool_responses={len(tool_response_messages)}" - ) - if ( - messages_to_save_early - or has_appended_streaming_message - ): - await upsert_chat_session(session) - has_saved_assistant_message = True - - has_yielded_end = True - yield chunk - elif isinstance(chunk, StreamError): - has_yielded_error = True - yield chunk - elif isinstance(chunk, StreamUsage): - session.usage.append( - Usage( - prompt_tokens=chunk.promptTokens, - completion_tokens=chunk.completionTokens, - total_tokens=chunk.totalTokens, - ) - ) - else: - logger.error( - f"Unknown chunk type: {type(chunk)}", exc_info=True - ) - if assistant_response.content: - langfuse.update_current_trace(output=assistant_response.content) - langfuse.update_current_span(output=assistant_response.content) - elif tool_response_messages: - langfuse.update_current_trace(output=str(tool_response_messages)) - langfuse.update_current_span(output=str(tool_response_messages)) - - except CancelledError: - if not has_saved_assistant_message: - if accumulated_tool_calls: - assistant_response.tool_calls = accumulated_tool_calls - if assistant_response.content: - assistant_response.content = ( - f"{assistant_response.content}\n\n[interrupted]" - ) - else: - assistant_response.content = "[interrupted]" - if not has_appended_streaming_message: - session.messages.append(assistant_response) - if tool_response_messages: - session.messages.extend(tool_response_messages) try: - await upsert_chat_session(session) + await cache_chat_session(session) except Exception as e: logger.warning( - f"Failed to save interrupted session {session.session_id}: {e}" + f"Failed to cache partial session {session.session_id}: {e}" ) - raise - except Exception as e: - logger.error(f"Error during stream: {e!s}", exc_info=True) - - # Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.) - is_retryable = isinstance( - e, (orjson.JSONDecodeError, KeyError, TypeError) - ) - - if is_retryable and retry_count < config.max_retries: - logger.info( - f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}" - ) - should_retry = True - else: - # Non-retryable error or max retries exceeded - # Save any partial progress before reporting error - messages_to_save: list[ChatMessage] = [] - - # Add assistant message if it has content or tool calls - if accumulated_tool_calls: - assistant_response.tool_calls = accumulated_tool_calls - if not has_appended_streaming_message and ( - assistant_response.content or assistant_response.tool_calls - ): - messages_to_save.append(assistant_response) - - # Add tool response messages after assistant message - messages_to_save.extend(tool_response_messages) - - if not has_saved_assistant_message: - if messages_to_save: - session.messages.extend(messages_to_save) - if messages_to_save or has_appended_streaming_message: - await upsert_chat_session(session) - - if not has_yielded_error: - error_message = str(e) - if not is_retryable: - error_message = f"Non-retryable error: {error_message}" - elif retry_count >= config.max_retries: - error_message = f"Max retries ({config.max_retries}) exceeded: {error_message}" - - error_response = StreamError(errorText=error_message) - yield error_response - if not has_yielded_end: - yield StreamFinish() - return - - # Handle retry outside of exception handler to avoid nesting - if should_retry and retry_count < config.max_retries: - logger.info( - f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}" - ) - async for chunk in stream_chat_completion( - session_id=session.session_id, - user_id=user_id, - retry_count=retry_count + 1, - session=session, - context=context, - ): + last_cache_time = current_time + last_cache_content_len = content_len + yield chunk + elif isinstance(chunk, StreamTextEnd): + # Emit text-end after text completes + if has_received_text and not text_streaming_ended: + text_streaming_ended = True yield chunk - return # Exit after retry to avoid double-saving in finally block + elif isinstance(chunk, StreamToolInputStart): + # Emit text-end before first tool call, but only if we've received text + if has_received_text and not text_streaming_ended: + yield StreamTextEnd(id=text_block_id) + text_streaming_ended = True + yield chunk + elif isinstance(chunk, StreamToolInputAvailable): + # Accumulate tool calls in OpenAI format + accumulated_tool_calls.append( + { + "id": chunk.toolCallId, + "type": "function", + "function": { + "name": chunk.toolName, + "arguments": orjson.dumps(chunk.input).decode("utf-8"), + }, + } + ) + yield chunk + elif isinstance(chunk, StreamToolOutputAvailable): + result_content = ( + chunk.output + if isinstance(chunk.output, str) + else orjson.dumps(chunk.output).decode("utf-8") + ) + # Skip saving long-running operation responses - messages already saved in _yield_tool_call + # Use JSON parsing instead of substring matching to avoid false positives + is_long_running_response = False + try: + parsed = orjson.loads(result_content) + if isinstance(parsed, dict) and parsed.get("type") in ( + "operation_started", + "operation_in_progress", + ): + is_long_running_response = True + except (orjson.JSONDecodeError, TypeError): + pass # Not JSON or not a dict - treat as regular response + if is_long_running_response: + # Remove from accumulated_tool_calls since assistant message was already saved + accumulated_tool_calls[:] = [ + tc + for tc in accumulated_tool_calls + if tc["id"] != chunk.toolCallId + ] + has_long_running_tool_call = True + else: + tool_response_messages.append( + ChatMessage( + role="tool", + content=result_content, + tool_call_id=chunk.toolCallId, + ) + ) + has_done_tool_call = True + # Track if any tool execution failed + if not chunk.success: + logger.warning( + f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed" + ) + yield chunk + elif isinstance(chunk, StreamFinish): + if not has_done_tool_call: + # Emit text-end before finish if we received text but haven't closed it + if has_received_text and not text_streaming_ended: + yield StreamTextEnd(id=text_block_id) + text_streaming_ended = True + + # Save assistant message before yielding finish to ensure it's persisted + # even if client disconnects immediately after receiving StreamFinish + if not has_saved_assistant_message: + messages_to_save_early: list[ChatMessage] = [] + if accumulated_tool_calls: + assistant_response.tool_calls = accumulated_tool_calls + if not has_appended_streaming_message and ( + assistant_response.content or assistant_response.tool_calls + ): + messages_to_save_early.append(assistant_response) + messages_to_save_early.extend(tool_response_messages) + + if messages_to_save_early: + session.messages.extend(messages_to_save_early) + logger.info( + f"Saving assistant message before StreamFinish: " + f"content_len={len(assistant_response.content or '')}, " + f"tool_calls={len(assistant_response.tool_calls or [])}, " + f"tool_responses={len(tool_response_messages)}" + ) + if messages_to_save_early or has_appended_streaming_message: + await upsert_chat_session(session) + has_saved_assistant_message = True + + has_yielded_end = True + yield chunk + elif isinstance(chunk, StreamError): + has_yielded_error = True + yield chunk + elif isinstance(chunk, StreamUsage): + session.usage.append( + Usage( + prompt_tokens=chunk.promptTokens, + completion_tokens=chunk.completionTokens, + total_tokens=chunk.totalTokens, + ) + ) + else: + logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True) + + except CancelledError: + if not has_saved_assistant_message: + if accumulated_tool_calls: + assistant_response.tool_calls = accumulated_tool_calls + if assistant_response.content: + assistant_response.content = ( + f"{assistant_response.content}\n\n[interrupted]" + ) + else: + assistant_response.content = "[interrupted]" + if not has_appended_streaming_message: + session.messages.append(assistant_response) + if tool_response_messages: + session.messages.extend(tool_response_messages) + try: + await upsert_chat_session(session) + except Exception as e: + logger.warning( + f"Failed to save interrupted session {session.session_id}: {e}" + ) + raise + except Exception as e: + logger.error(f"Error during stream: {e!s}", exc_info=True) + + # Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.) + is_retryable = isinstance(e, (orjson.JSONDecodeError, KeyError, TypeError)) + + if is_retryable and retry_count < config.max_retries: + logger.info( + f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}" + ) + should_retry = True + else: + # Non-retryable error or max retries exceeded + # Save any partial progress before reporting error + messages_to_save: list[ChatMessage] = [] + + # Add assistant message if it has content or tool calls + if accumulated_tool_calls: + assistant_response.tool_calls = accumulated_tool_calls + if not has_appended_streaming_message and ( + assistant_response.content or assistant_response.tool_calls + ): + messages_to_save.append(assistant_response) + + # Add tool response messages after assistant message + messages_to_save.extend(tool_response_messages) - # Normal completion path - save session and handle tool call continuation - # Only save if we haven't already saved when StreamFinish was received if not has_saved_assistant_message: - logger.info( - f"Normal completion path: session={session.session_id}, " - f"current message_count={len(session.messages)}" - ) - - # Build the messages list in the correct order - messages_to_save: list[ChatMessage] = [] - - # Add assistant message with tool_calls if any - if accumulated_tool_calls: - assistant_response.tool_calls = accumulated_tool_calls - logger.info( - f"Added {len(accumulated_tool_calls)} tool calls to assistant message" - ) - if not has_appended_streaming_message and ( - assistant_response.content or assistant_response.tool_calls - ): - messages_to_save.append(assistant_response) - logger.info( - f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}" - ) - - # Add tool response messages after assistant message - messages_to_save.extend(tool_response_messages) - logger.info( - f"Saving {len(tool_response_messages)} tool response messages, " - f"total_to_save={len(messages_to_save)}" - ) - if messages_to_save: session.messages.extend(messages_to_save) - logger.info( - f"Extended session messages, new message_count={len(session.messages)}" - ) if messages_to_save or has_appended_streaming_message: await upsert_chat_session(session) - else: - logger.info( - "Assistant message already saved when StreamFinish was received, " - "skipping duplicate save" - ) - # If we did a tool call, stream the chat completion again to get the next response - if has_done_tool_call: - logger.info( - "Tool call executed, streaming chat completion again to get assistant response" - ) - async for chunk in stream_chat_completion( - session_id=session.session_id, - user_id=user_id, - session=session, # Pass session object to avoid Redis refetch - context=context, - tool_call_response=str(tool_response_messages), - ): - yield chunk + if not has_yielded_error: + error_message = str(e) + if not is_retryable: + error_message = f"Non-retryable error: {error_message}" + elif retry_count >= config.max_retries: + error_message = ( + f"Max retries ({config.max_retries}) exceeded: {error_message}" + ) + + error_response = StreamError(errorText=error_message) + yield error_response + if not has_yielded_end: + yield StreamFinish() + return + + # Handle retry outside of exception handler to avoid nesting + if should_retry and retry_count < config.max_retries: + logger.info( + f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}" + ) + async for chunk in stream_chat_completion( + session_id=session.session_id, + user_id=user_id, + retry_count=retry_count + 1, + session=session, + context=context, + ): + yield chunk + return # Exit after retry to avoid double-saving in finally block + + # Normal completion path - save session and handle tool call continuation + # Only save if we haven't already saved when StreamFinish was received + if not has_saved_assistant_message: + logger.info( + f"Normal completion path: session={session.session_id}, " + f"current message_count={len(session.messages)}" + ) + + # Build the messages list in the correct order + messages_to_save: list[ChatMessage] = [] + + # Add assistant message with tool_calls if any + if accumulated_tool_calls: + assistant_response.tool_calls = accumulated_tool_calls + logger.info( + f"Added {len(accumulated_tool_calls)} tool calls to assistant message" + ) + if not has_appended_streaming_message and ( + assistant_response.content or assistant_response.tool_calls + ): + messages_to_save.append(assistant_response) + logger.info( + f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}" + ) + + # Add tool response messages after assistant message + messages_to_save.extend(tool_response_messages) + logger.info( + f"Saving {len(tool_response_messages)} tool response messages, " + f"total_to_save={len(messages_to_save)}" + ) + + if messages_to_save: + session.messages.extend(messages_to_save) + logger.info( + f"Extended session messages, new message_count={len(session.messages)}" + ) + # Save if there are regular (non-long-running) tool responses or streaming message. + # Long-running tools save their own state, but we still need to save regular tools + # that may be in the same response. + has_regular_tool_responses = len(tool_response_messages) > 0 + if has_regular_tool_responses or ( + not has_long_running_tool_call + and (messages_to_save or has_appended_streaming_message) + ): + await upsert_chat_session(session) + else: + logger.info( + "Assistant message already saved when StreamFinish was received, " + "skipping duplicate save" + ) + + # If we did a tool call, stream the chat completion again to get the next response + # Skip only if ALL tools were long-running (they handle their own completion) + has_regular_tools = len(tool_response_messages) > 0 + if has_done_tool_call and (has_regular_tools or not has_long_running_tool_call): + logger.info( + "Tool call executed, streaming chat completion again to get assistant response" + ) + async for chunk in stream_chat_completion( + session_id=session.session_id, + user_id=user_id, + session=session, # Pass session object to avoid Redis refetch + context=context, + tool_call_response=str(tool_response_messages), + ): + yield chunk # Retry configuration for OpenAI API calls @@ -650,6 +794,209 @@ def _is_region_blocked_error(error: Exception) -> bool: return "not available in your region" in str(error).lower() +async def _summarize_messages( + messages: list, + model: str, + api_key: str | None = None, + base_url: str | None = None, + timeout: float = 30.0, +) -> str: + """Summarize a list of messages into concise context. + + Uses the same model as the chat for higher quality summaries. + + Args: + messages: List of message dicts to summarize + model: Model to use for summarization (same as chat model) + api_key: API key for OpenAI client + base_url: Base URL for OpenAI client + timeout: Request timeout in seconds (default: 30.0) + + Returns: + Summarized text + """ + # Format messages for summarization + conversation = [] + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + # Include user, assistant, and tool messages (tool outputs are important context) + if content and role in ("user", "assistant", "tool"): + conversation.append(f"{role.upper()}: {content}") + + conversation_text = "\n\n".join(conversation) + + # Handle empty conversation + if not conversation_text: + return "No conversation history available." + + # Truncate conversation to fit within summarization model's context + # gpt-4o-mini has 128k context, but we limit to ~25k tokens (~100k chars) for safety + MAX_CHARS = 100_000 + if len(conversation_text) > MAX_CHARS: + conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]" + + # Call LLM to summarize + import openai + + summarization_client = openai.AsyncOpenAI( + api_key=api_key, base_url=base_url, timeout=timeout + ) + + response = await summarization_client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": ( + "Create a detailed summary of the conversation so far. " + "This summary will be used as context when continuing the conversation.\n\n" + "Before writing the summary, analyze each message chronologically to identify:\n" + "- User requests and their explicit goals\n" + "- Your approach and key decisions made\n" + "- Technical specifics (file names, tool outputs, function signatures)\n" + "- Errors encountered and resolutions applied\n\n" + "You MUST include ALL of the following sections:\n\n" + "## 1. Primary Request and Intent\n" + "The user's explicit goals and what they are trying to accomplish.\n\n" + "## 2. Key Technical Concepts\n" + "Technologies, frameworks, tools, and patterns being used or discussed.\n\n" + "## 3. Files and Resources Involved\n" + "Specific files examined or modified, with relevant snippets and identifiers.\n\n" + "## 4. Errors and Fixes\n" + "Problems encountered, error messages, and their resolutions. " + "Include any user feedback on fixes.\n\n" + "## 5. Problem Solving\n" + "Issues that have been resolved and how they were addressed.\n\n" + "## 6. All User Messages\n" + "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n" + "## 7. Pending Tasks\n" + "Work items the user explicitly requested that have not yet been completed.\n\n" + "## 8. Current Work\n" + "Precise description of what was being worked on most recently, including relevant context.\n\n" + "## 9. Next Steps\n" + "What should happen next, aligned with the user's most recent requests. " + "Include verbatim quotes of recent instructions if relevant." + ), + }, + {"role": "user", "content": f"Summarize:\n\n{conversation_text}"}, + ], + max_tokens=1500, + temperature=0.3, + ) + + summary = response.choices[0].message.content + return summary or "No summary available." + + +def _ensure_tool_pairs_intact( + recent_messages: list[dict], + all_messages: list[dict], + start_index: int, +) -> list[dict]: + """ + Ensure tool_call/tool_response pairs stay together after slicing. + + When slicing messages for context compaction, a naive slice can separate + an assistant message containing tool_calls from its corresponding tool + response messages. This causes API validation errors (e.g., Anthropic's + "unexpected tool_use_id found in tool_result blocks"). + + This function checks for orphan tool responses in the slice and extends + backwards to include their corresponding assistant messages. + + Args: + recent_messages: The sliced messages to validate + all_messages: The complete message list (for looking up missing assistants) + start_index: The index in all_messages where recent_messages begins + + Returns: + A potentially extended list of messages with tool pairs intact + """ + if not recent_messages: + return recent_messages + + # Collect all tool_call_ids from assistant messages in the slice + available_tool_call_ids: set[str] = set() + for msg in recent_messages: + if msg.get("role") == "assistant" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + tc_id = tc.get("id") + if tc_id: + available_tool_call_ids.add(tc_id) + + # Find orphan tool responses (tool messages whose tool_call_id is missing) + orphan_tool_call_ids: set[str] = set() + for msg in recent_messages: + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id") + if tc_id and tc_id not in available_tool_call_ids: + orphan_tool_call_ids.add(tc_id) + + if not orphan_tool_call_ids: + # No orphans, slice is valid + return recent_messages + + # Find the assistant messages that contain the orphan tool_call_ids + # Search backwards from start_index in all_messages + messages_to_prepend: list[dict] = [] + for i in range(start_index - 1, -1, -1): + msg = all_messages[i] + if msg.get("role") == "assistant" and msg.get("tool_calls"): + msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")} + if msg_tool_ids & orphan_tool_call_ids: + # This assistant message has tool_calls we need + # Also collect its contiguous tool responses that follow it + assistant_and_responses: list[dict] = [msg] + + # Scan forward from this assistant to collect tool responses + for j in range(i + 1, start_index): + following_msg = all_messages[j] + if following_msg.get("role") == "tool": + tool_id = following_msg.get("tool_call_id") + if tool_id and tool_id in msg_tool_ids: + assistant_and_responses.append(following_msg) + else: + # Stop at first non-tool message + break + + # Prepend the assistant and its tool responses (maintain order) + messages_to_prepend = assistant_and_responses + messages_to_prepend + # Mark these as found + orphan_tool_call_ids -= msg_tool_ids + # Also add this assistant's tool_call_ids to available set + available_tool_call_ids |= msg_tool_ids + + if not orphan_tool_call_ids: + # Found all missing assistants + break + + if orphan_tool_call_ids: + # Some tool_call_ids couldn't be resolved - remove those tool responses + # This shouldn't happen in normal operation but handles edge cases + logger.warning( + f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. " + "Removing orphan tool responses." + ) + recent_messages = [ + msg + for msg in recent_messages + if not ( + msg.get("role") == "tool" + and msg.get("tool_call_id") in orphan_tool_call_ids + ) + ] + + if messages_to_prepend: + logger.info( + f"Extended recent messages by {len(messages_to_prepend)} to preserve " + f"tool_call/tool_response pairs" + ) + return messages_to_prepend + recent_messages + + return recent_messages + + async def _stream_chat_chunks( session: ChatSession, tools: list[ChatCompletionToolParam], @@ -686,6 +1033,316 @@ async def _stream_chat_chunks( ) messages = [system_message] + messages + # Apply context window management + token_count = 0 # Initialize for exception handler + try: + from backend.util.prompt import estimate_token_count + + # Convert to dict for token counting + # OpenAI message types are TypedDicts, so they're already dict-like + messages_dict = [] + for msg in messages: + # TypedDict objects are already dicts, just filter None values + if isinstance(msg, dict): + msg_dict = {k: v for k, v in msg.items() if v is not None} + else: + # Fallback for unexpected types + msg_dict = dict(msg) + messages_dict.append(msg_dict) + + # Estimate tokens using appropriate tokenizer + # Normalize model name for token counting (tiktoken only supports OpenAI models) + token_count_model = model + if "/" in model: + # Strip provider prefix (e.g., "anthropic/claude-opus-4.5" -> "claude-opus-4.5") + token_count_model = model.split("/")[-1] + + # For Claude and other non-OpenAI models, approximate with gpt-4o tokenizer + # Most modern LLMs have similar tokenization (~1 token per 4 chars) + if "claude" in token_count_model.lower() or not any( + known in token_count_model.lower() + for known in ["gpt", "o1", "chatgpt", "text-"] + ): + token_count_model = "gpt-4o" + + # Attempt token counting with error handling + try: + token_count = estimate_token_count(messages_dict, model=token_count_model) + except Exception as token_error: + # If token counting fails, use gpt-4o as fallback approximation + logger.warning( + f"Token counting failed for model {token_count_model}: {token_error}. " + "Using gpt-4o approximation." + ) + token_count = estimate_token_count(messages_dict, model="gpt-4o") + + # If over threshold, summarize old messages + if token_count > 120_000: + KEEP_RECENT = 15 + + # Check if we have a system prompt at the start + has_system_prompt = ( + len(messages) > 0 and messages[0].get("role") == "system" + ) + + # Always attempt mitigation when over limit, even with few messages + if messages: + # Split messages based on whether system prompt exists + # Calculate start index for the slice + slice_start = max(0, len(messages_dict) - KEEP_RECENT) + recent_messages = messages_dict[-KEEP_RECENT:] + + # Ensure tool_call/tool_response pairs stay together + # This prevents API errors from orphan tool responses + recent_messages = _ensure_tool_pairs_intact( + recent_messages, messages_dict, slice_start + ) + + if has_system_prompt: + # Keep system prompt separate, summarize everything between system and recent + system_msg = messages[0] + old_messages_dict = messages_dict[1:-KEEP_RECENT] + else: + # No system prompt, summarize everything except recent + system_msg = None + old_messages_dict = messages_dict[:-KEEP_RECENT] + + # Summarize any non-empty old messages (no minimum threshold) + # If we're over the token limit, we need to compress whatever we can + if old_messages_dict: + # Summarize old messages using the same model as chat + summary_text = await _summarize_messages( + old_messages_dict, + model=model, + api_key=config.api_key, + base_url=config.base_url, + ) + + # Build new message list + # Use assistant role (not system) to prevent privilege escalation + # of user-influenced content to instruction-level authority + from openai.types.chat import ChatCompletionAssistantMessageParam + + summary_msg = ChatCompletionAssistantMessageParam( + role="assistant", + content=( + "[Previous conversation summary — for context only]: " + f"{summary_text}" + ), + ) + + # Rebuild messages based on whether we have a system prompt + if has_system_prompt: + # system_prompt + summary + recent_messages + messages = [system_msg, summary_msg] + recent_messages + else: + # summary + recent_messages (no original system prompt) + messages = [summary_msg] + recent_messages + + logger.info( + f"Context summarized: {token_count} tokens, " + f"summarized {len(old_messages_dict)} old messages, " + f"kept last {KEEP_RECENT} messages" + ) + + # Fallback: If still over limit after summarization, progressively drop recent messages + # This handles edge cases where recent messages are extremely large + new_messages_dict = [] + for msg in messages: + if isinstance(msg, dict): + msg_dict = {k: v for k, v in msg.items() if v is not None} + else: + msg_dict = dict(msg) + new_messages_dict.append(msg_dict) + + new_token_count = estimate_token_count( + new_messages_dict, model=token_count_model + ) + + if new_token_count > 120_000: + # Still over limit - progressively reduce KEEP_RECENT + logger.warning( + f"Still over limit after summarization: {new_token_count} tokens. " + "Reducing number of recent messages kept." + ) + + for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]: + if keep_count == 0: + # Try with just system prompt + summary (no recent messages) + if has_system_prompt: + messages = [system_msg, summary_msg] + else: + messages = [summary_msg] + logger.info( + "Trying with 0 recent messages (system + summary only)" + ) + else: + # Slice from ORIGINAL recent_messages to avoid duplicating summary + reduced_recent = ( + recent_messages[-keep_count:] + if len(recent_messages) >= keep_count + else recent_messages + ) + # Ensure tool pairs stay intact in the reduced slice + reduced_slice_start = max( + 0, len(recent_messages) - keep_count + ) + reduced_recent = _ensure_tool_pairs_intact( + reduced_recent, recent_messages, reduced_slice_start + ) + if has_system_prompt: + messages = [ + system_msg, + summary_msg, + ] + reduced_recent + else: + messages = [summary_msg] + reduced_recent + + new_messages_dict = [] + for msg in messages: + if isinstance(msg, dict): + msg_dict = { + k: v for k, v in msg.items() if v is not None + } + else: + msg_dict = dict(msg) + new_messages_dict.append(msg_dict) + + new_token_count = estimate_token_count( + new_messages_dict, model=token_count_model + ) + + if new_token_count <= 120_000: + logger.info( + f"Reduced to {keep_count} recent messages, " + f"now {new_token_count} tokens" + ) + break + else: + logger.error( + f"Unable to reduce token count below threshold even with 0 messages. " + f"Final count: {new_token_count} tokens" + ) + # ABSOLUTE LAST RESORT: Drop system prompt + # This should only happen if summary itself is massive + if has_system_prompt and len(messages) > 1: + messages = messages[1:] # Drop system prompt + logger.critical( + "CRITICAL: Dropped system prompt as absolute last resort. " + "Behavioral consistency may be affected." + ) + # Yield error to user + yield StreamError( + errorText=( + "Warning: System prompt dropped due to size constraints. " + "Assistant behavior may be affected." + ) + ) + else: + # No old messages to summarize - all messages are "recent" + # Apply progressive truncation to reduce token count + logger.warning( + f"Token count {token_count} exceeds threshold but no old messages to summarize. " + f"Applying progressive truncation to recent messages." + ) + + # Create a base list excluding system prompt to avoid duplication + # This is the pool of messages we'll slice from in the loop + # Use messages_dict for type consistency with _ensure_tool_pairs_intact + base_msgs = ( + messages_dict[1:] if has_system_prompt else messages_dict + ) + + # Try progressively smaller keep counts + new_token_count = token_count # Initialize with current count + for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]: + if keep_count == 0: + # Try with just system prompt (no recent messages) + if has_system_prompt: + messages = [system_msg] + logger.info( + "Trying with 0 recent messages (system prompt only)" + ) + else: + # No system prompt and no recent messages = empty messages list + # This is invalid, skip this iteration + continue + else: + if len(base_msgs) < keep_count: + continue # Skip if we don't have enough messages + + # Slice from base_msgs to get recent messages (without system prompt) + recent_messages = base_msgs[-keep_count:] + + # Ensure tool pairs stay intact in the reduced slice + reduced_slice_start = max(0, len(base_msgs) - keep_count) + recent_messages = _ensure_tool_pairs_intact( + recent_messages, base_msgs, reduced_slice_start + ) + + if has_system_prompt: + messages = [system_msg] + recent_messages + else: + messages = recent_messages + + new_messages_dict = [] + for msg in messages: + if msg is None: + continue # Skip None messages (type safety) + if isinstance(msg, dict): + msg_dict = { + k: v for k, v in msg.items() if v is not None + } + else: + msg_dict = dict(msg) + new_messages_dict.append(msg_dict) + + new_token_count = estimate_token_count( + new_messages_dict, model=token_count_model + ) + + if new_token_count <= 120_000: + logger.info( + f"Reduced to {keep_count} recent messages, " + f"now {new_token_count} tokens" + ) + break + else: + # Even with 0 messages still over limit + logger.error( + f"Unable to reduce token count below threshold even with 0 messages. " + f"Final count: {new_token_count} tokens. Messages may be extremely large." + ) + # ABSOLUTE LAST RESORT: Drop system prompt + if has_system_prompt and len(messages) > 1: + messages = messages[1:] # Drop system prompt + logger.critical( + "CRITICAL: Dropped system prompt as absolute last resort. " + "Behavioral consistency may be affected." + ) + # Yield error to user + yield StreamError( + errorText=( + "Warning: System prompt dropped due to size constraints. " + "Assistant behavior may be affected." + ) + ) + + except Exception as e: + logger.error(f"Context summarization failed: {e}", exc_info=True) + # If we were over the token limit, yield error to user + # Don't silently continue with oversized messages that will fail + if token_count > 120_000: + yield StreamError( + errorText=( + f"Unable to manage context window (token limit exceeded: {token_count} tokens). " + "Context summarization failed. Please start a new conversation." + ) + ) + yield StreamFinish() + return + # Otherwise, continue with original messages (under limit) + # Loop to handle tool calls and continue conversation while True: retry_count = 0 @@ -698,14 +1355,36 @@ async def _stream_chat_chunks( f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}" ) + # Build extra_body for OpenRouter tracing and PostHog analytics + extra_body: dict[str, Any] = { + "posthogProperties": { + "environment": settings.config.app_env.value, + }, + } + if session.user_id: + extra_body["user"] = session.user_id[:128] # OpenRouter limit + extra_body["posthogDistinctId"] = session.user_id + if session.session_id: + extra_body["session_id"] = session.session_id[ + :128 + ] # OpenRouter limit + # Create the stream with proper types + from typing import cast + + from openai.types.chat import ( + ChatCompletionMessageParam, + ChatCompletionStreamOptionsParam, + ) + stream = await client.chat.completions.create( model=model, - messages=messages, + messages=cast(list[ChatCompletionMessageParam], messages), tools=tools, tool_choice="auto", stream=True, - stream_options={"include_usage": True}, + stream_options=ChatCompletionStreamOptionsParam(include_usage=True), + extra_body=extra_body, ) # Variables to accumulate tool calls @@ -877,14 +1556,19 @@ async def _yield_tool_call( """ Yield a tool call and its execution result. + For tools marked with `is_long_running=True` (like agent generation), spawns a + background task so the operation survives SSE disconnections. For other tools, + yields heartbeat events every 15 seconds to keep the SSE connection alive. + Raises: orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON KeyError: If expected tool call fields are missing TypeError: If tool call structure is invalid """ + import uuid as uuid_module + tool_name = tool_calls[yield_idx]["function"]["name"] tool_call_id = tool_calls[yield_idx]["id"] - logger.info(f"Yielding tool call: {tool_calls[yield_idx]}") # Parse tool call arguments - handle empty arguments gracefully raw_arguments = tool_calls[yield_idx]["function"]["arguments"] @@ -899,12 +1583,384 @@ async def _yield_tool_call( input=arguments, ) - tool_execution_response: StreamToolOutputAvailable = await execute_tool( - tool_name=tool_name, - parameters=arguments, - tool_call_id=tool_call_id, - user_id=session.user_id, - session=session, + # Check if this tool is long-running (survives SSE disconnection) + tool = get_tool(tool_name) + if tool and tool.is_long_running: + # Atomic check-and-set: returns False if operation already running (lost race) + if not await _mark_operation_started(tool_call_id): + logger.info( + f"Tool call {tool_call_id} already in progress, returning status" + ) + # Build dynamic message based on tool name + if tool_name == "create_agent": + in_progress_msg = "Agent creation already in progress. Please wait..." + elif tool_name == "edit_agent": + in_progress_msg = "Agent edit already in progress. Please wait..." + else: + in_progress_msg = f"{tool_name} already in progress. Please wait..." + + yield StreamToolOutputAvailable( + toolCallId=tool_call_id, + toolName=tool_name, + output=OperationInProgressResponse( + message=in_progress_msg, + tool_call_id=tool_call_id, + ).model_dump_json(), + success=True, + ) + return + + # Generate operation ID + operation_id = str(uuid_module.uuid4()) + + # Build a user-friendly message based on tool and arguments + if tool_name == "create_agent": + agent_desc = arguments.get("description", "") + # Truncate long descriptions for the message + desc_preview = ( + (agent_desc[:100] + "...") if len(agent_desc) > 100 else agent_desc + ) + pending_msg = ( + f"Creating your agent: {desc_preview}" + if desc_preview + else "Creating agent... This may take a few minutes." + ) + started_msg = ( + "Agent creation started. You can close this tab - " + "check your library in a few minutes." + ) + elif tool_name == "edit_agent": + changes = arguments.get("changes", "") + changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes + pending_msg = ( + f"Editing agent: {changes_preview}" + if changes_preview + else "Editing agent... This may take a few minutes." + ) + started_msg = ( + "Agent edit started. You can close this tab - " + "check your library in a few minutes." + ) + else: + pending_msg = f"Running {tool_name}... This may take a few minutes." + started_msg = ( + f"{tool_name} started. You can close this tab - " + "check back in a few minutes." + ) + + # Track appended messages for rollback on failure + assistant_message: ChatMessage | None = None + pending_message: ChatMessage | None = None + + # Wrap session save and task creation in try-except to release lock on failure + try: + # Save assistant message with tool_call FIRST (required by LLM) + assistant_message = ChatMessage( + role="assistant", + content="", + tool_calls=[tool_calls[yield_idx]], + ) + session.messages.append(assistant_message) + + # Then save pending tool result + pending_message = ChatMessage( + role="tool", + content=OperationPendingResponse( + message=pending_msg, + operation_id=operation_id, + tool_name=tool_name, + ).model_dump_json(), + tool_call_id=tool_call_id, + ) + session.messages.append(pending_message) + await upsert_chat_session(session) + logger.info( + f"Saved pending operation {operation_id} for tool {tool_name} " + f"in session {session.session_id}" + ) + + # Store task reference in module-level set to prevent GC before completion + task = asyncio.create_task( + _execute_long_running_tool( + tool_name=tool_name, + parameters=arguments, + tool_call_id=tool_call_id, + operation_id=operation_id, + session_id=session.session_id, + user_id=session.user_id, + ) + ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + except Exception as e: + # Roll back appended messages to prevent data corruption on subsequent saves + if ( + pending_message + and session.messages + and session.messages[-1] == pending_message + ): + session.messages.pop() + if ( + assistant_message + and session.messages + and session.messages[-1] == assistant_message + ): + session.messages.pop() + + # Release the Redis lock since the background task won't be spawned + await _mark_operation_completed(tool_call_id) + logger.error( + f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True + ) + raise + + # Return immediately - don't wait for completion + yield StreamToolOutputAvailable( + toolCallId=tool_call_id, + toolName=tool_name, + output=OperationStartedResponse( + message=started_msg, + operation_id=operation_id, + tool_name=tool_name, + ).model_dump_json(), + success=True, + ) + return + + # Normal flow: Run tool execution in background task with heartbeats + tool_task = asyncio.create_task( + execute_tool( + tool_name=tool_name, + parameters=arguments, + tool_call_id=tool_call_id, + user_id=session.user_id, + session=session, + ) ) + # Yield heartbeats every 15 seconds while waiting for tool to complete + heartbeat_interval = 15.0 # seconds + while not tool_task.done(): + try: + # Wait for either the task to complete or the heartbeat interval + await asyncio.wait_for( + asyncio.shield(tool_task), timeout=heartbeat_interval + ) + except asyncio.TimeoutError: + # Task still running, send heartbeat to keep connection alive + logger.debug(f"Sending heartbeat for tool {tool_name} ({tool_call_id})") + yield StreamHeartbeat(toolCallId=tool_call_id) + except CancelledError: + # Task was cancelled, clean up and propagate + tool_task.cancel() + logger.warning(f"Tool execution cancelled: {tool_name} ({tool_call_id})") + raise + + # Get the result - handle any exceptions that occurred during execution + try: + tool_execution_response: StreamToolOutputAvailable = await tool_task + except Exception as e: + # Task raised an exception - ensure we send an error response to the frontend + logger.error( + f"Tool execution failed: {tool_name} ({tool_call_id}): {e}", exc_info=True + ) + error_response = ErrorResponse( + message=f"Tool execution failed: {e!s}", + error=type(e).__name__, + session_id=session.session_id, + ) + tool_execution_response = StreamToolOutputAvailable( + toolCallId=tool_call_id, + toolName=tool_name, + output=error_response.model_dump_json(), + success=False, + ) + yield tool_execution_response + + +async def _execute_long_running_tool( + tool_name: str, + parameters: dict[str, Any], + tool_call_id: str, + operation_id: str, + session_id: str, + user_id: str | None, +) -> None: + """Execute a long-running tool in background and update chat history with result. + + This function runs independently of the SSE connection, so the operation + survives if the user closes their browser tab. + """ + try: + # Load fresh session (not stale reference) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for background tool") + return + + # Execute the actual tool + result = await execute_tool( + tool_name=tool_name, + parameters=parameters, + tool_call_id=tool_call_id, + user_id=user_id, + session=session, + ) + + # Update the pending message with result + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=( + result.output + if isinstance(result.output, str) + else orjson.dumps(result.output).decode("utf-8") + ), + ) + + logger.info(f"Background tool {tool_name} completed for session {session_id}") + + # Generate LLM continuation so user sees response when they poll/refresh + await _generate_llm_continuation(session_id=session_id, user_id=user_id) + + except Exception as e: + logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True) + error_response = ErrorResponse( + message=f"Tool {tool_name} failed: {str(e)}", + ) + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=error_response.model_dump_json(), + ) + finally: + await _mark_operation_completed(tool_call_id) + + +async def _update_pending_operation( + session_id: str, + tool_call_id: str, + result: str, +) -> None: + """Update the pending tool message with final result. + + This is called by background tasks when long-running operations complete. + """ + # Update the message in database + updated = await chat_db.update_tool_message_content( + session_id=session_id, + tool_call_id=tool_call_id, + new_content=result, + ) + + if updated: + # Invalidate Redis cache so next load gets fresh data + # Wrap in try/except to prevent cache failures from triggering error handling + # that would overwrite our successful DB update + try: + await invalidate_session_cache(session_id) + except Exception as e: + # Non-critical: cache will eventually be refreshed on next load + logger.warning(f"Failed to invalidate cache for session {session_id}: {e}") + logger.info( + f"Updated pending operation for tool_call_id {tool_call_id} " + f"in session {session_id}" + ) + else: + logger.warning( + f"Failed to update pending operation for tool_call_id {tool_call_id} " + f"in session {session_id}" + ) + + +async def _generate_llm_continuation( + session_id: str, + user_id: str | None, +) -> None: + """Generate an LLM response after a long-running tool completes. + + This is called by background tasks to continue the conversation + after a tool result is saved. The response is saved to the database + so users see it when they refresh or poll. + """ + try: + # Load fresh session from DB (bypass cache to get the updated tool result) + await invalidate_session_cache(session_id) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for LLM continuation") + return + + # Build system prompt + system_prompt, _ = await _build_system_prompt(user_id) + + # Build messages in OpenAI format + messages = session.to_openai_messages() + if system_prompt: + from openai.types.chat import ChatCompletionSystemMessageParam + + system_message = ChatCompletionSystemMessageParam( + role="system", + content=system_prompt, + ) + messages = [system_message] + messages + + # Build extra_body for tracing + extra_body: dict[str, Any] = { + "posthogProperties": { + "environment": settings.config.app_env.value, + }, + } + if user_id: + extra_body["user"] = user_id[:128] + extra_body["posthogDistinctId"] = user_id + if session_id: + extra_body["session_id"] = session_id[:128] + + # Make non-streaming LLM call (no tools - just text response) + from typing import cast + + from openai.types.chat import ChatCompletionMessageParam + + # No tools parameter = text-only response (no tool calls) + response = await client.chat.completions.create( + model=config.model, + messages=cast(list[ChatCompletionMessageParam], messages), + extra_body=extra_body, + ) + + if response.choices and response.choices[0].message.content: + assistant_content = response.choices[0].message.content + + # Reload session from DB to avoid race condition with user messages + # that may have been sent while we were generating the LLM response + fresh_session = await get_chat_session(session_id, user_id) + if not fresh_session: + logger.error( + f"Session {session_id} disappeared during LLM continuation" + ) + return + + # Save assistant message to database + assistant_message = ChatMessage( + role="assistant", + content=assistant_content, + ) + fresh_session.messages.append(assistant_message) + + # Save to database (not cache) to persist the response + await upsert_chat_session(fresh_session) + + # Invalidate cache so next poll/refresh gets fresh data + await invalidate_session_cache(session_id) + + logger.info( + f"Generated LLM continuation for session {session_id}, " + f"response length: {len(assistant_content)}" + ) + else: + logger.warning(f"LLM continuation returned empty response for {session_id}") + + except Exception as e: + logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/IDEAS.md b/autogpt_platform/backend/backend/api/features/chat/tools/IDEAS.md new file mode 100644 index 0000000000..656aac61c4 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/IDEAS.md @@ -0,0 +1,79 @@ +# CoPilot Tools - Future Ideas + +## Multimodal Image Support for CoPilot + +**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality). + +**Backend Solution:** +When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks: + +```python +# Before sending to LLM, scan for workspace image references +# and inject them as image content parts + +# Example message transformation: +# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"} +# TO: {"role": "assistant", "content": [ +# {"type": "text", "text": "Generated image: workspace://abc123"}, +# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}} +# ]} +``` + +**Where to implement:** +- In the chat stream handler before calling the LLM +- Or in a message preprocessing step +- Need to fetch image from workspace, convert to base64, add as image content + +**Considerations:** +- Only do this for image MIME types (image/png, image/jpeg, etc.) +- May want a size limit (don't pass 10MB images) +- Track which images were "shown" to the AI for frontend indicator +- Cost implications - vision API calls are more expensive + +**Frontend Solution:** +Show visual indicator on workspace files in chat: +- If AI saw the image: normal display +- If AI didn't see it: overlay icon saying "AI can't see this image" + +Requires response metadata indicating which `workspace://` refs were passed to the model. + +--- + +## Output Post-Processing Layer for run_block + +**Problem:** Many blocks produce large outputs that: +- Consume massive context (100KB base64 image = ~133KB tokens) +- Can't fit in conversation +- Break things and cause high LLM costs + +**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot. + +**Benefits:** +1. **Centralized** - one place to handle all output processing +2. **Future-proof** - new blocks automatically get output processing +3. **Keeps blocks pure** - they don't need to know about context constraints +4. **Handles all large outputs** - not just images + +**Processing Rules:** +- Detect base64 data URIs → save to workspace, return `workspace://` reference +- Truncate very long strings (>N chars) with truncation note +- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]") +- Handle nested large outputs in dicts recursively +- Cap total output size + +**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse` + +**Example:** +```python +def _process_outputs_for_context( + outputs: dict[str, list[Any]], + workspace_manager: WorkspaceManager, + max_string_length: int = 10000, + max_array_preview: int = 5, +) -> dict[str, list[Any]]: + """Process block outputs to prevent context bloat.""" + processed = {} + for name, values in outputs.items(): + processed[name] = [_process_value(v, workspace_manager) for v in values] + return processed +``` diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py index 82ce5cfd6f..d078860c3a 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py @@ -1,8 +1,10 @@ +import logging from typing import TYPE_CHECKING, Any from openai.types.chat import ChatCompletionToolParam from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tracking import track_tool_called from .add_understanding import AddUnderstandingTool from .agent_output import AgentOutputTool @@ -16,10 +18,18 @@ from .get_doc_page import GetDocPageTool from .run_agent import RunAgentTool from .run_block import RunBlockTool from .search_docs import SearchDocsTool +from .workspace_files import ( + DeleteWorkspaceFileTool, + ListWorkspaceFilesTool, + ReadWorkspaceFileTool, + WriteWorkspaceFileTool, +) if TYPE_CHECKING: from backend.api.features.chat.response_model import StreamToolOutputAvailable +logger = logging.getLogger(__name__) + # Single source of truth for all tools TOOL_REGISTRY: dict[str, BaseTool] = { "add_understanding": AddUnderstandingTool(), @@ -33,6 +43,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "view_agent_output": AgentOutputTool(), "search_docs": SearchDocsTool(), "get_doc_page": GetDocPageTool(), + # Workspace tools for CoPilot file operations + "list_workspace_files": ListWorkspaceFilesTool(), + "read_workspace_file": ReadWorkspaceFileTool(), + "write_workspace_file": WriteWorkspaceFileTool(), + "delete_workspace_file": DeleteWorkspaceFileTool(), } # Export individual tool instances for backwards compatibility @@ -45,6 +60,11 @@ tools: list[ChatCompletionToolParam] = [ ] +def get_tool(tool_name: str) -> BaseTool | None: + """Get a tool instance by name.""" + return TOOL_REGISTRY.get(tool_name) + + async def execute_tool( tool_name: str, parameters: dict[str, Any], @@ -53,7 +73,20 @@ async def execute_tool( tool_call_id: str, ) -> "StreamToolOutputAvailable": """Execute a tool by name.""" - tool = TOOL_REGISTRY.get(tool_name) + tool = get_tool(tool_name) if not tool: raise ValueError(f"Tool {tool_name} not found") + + # Track tool call in PostHog + logger.info( + f"Tracking tool call: tool={tool_name}, user={user_id}, " + f"session={session.session_id}, call_id={tool_call_id}" + ) + track_tool_called( + user_id=user_id, + session_id=session.session_id, + tool_name=tool_name, + tool_call_id=tool_call_id, + ) + return await tool.execute(user_id, session, tool_call_id, **parameters) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/add_understanding.py b/autogpt_platform/backend/backend/api/features/chat/tools/add_understanding.py index bd93f0e2a6..fe3d5e8984 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/add_understanding.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/add_understanding.py @@ -3,8 +3,6 @@ import logging from typing import Any -from langfuse import observe - from backend.api.features.chat.model import ChatSession from backend.data.understanding import ( BusinessUnderstandingInput, @@ -61,7 +59,6 @@ and automations for the user's specific needs.""" """Requires authentication to store user-specific data.""" return True - @observe(as_type="tool", name="add_understanding") async def _execute( self, user_id: str | None, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py index d4df2564a8..499025b7dc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py @@ -1,29 +1,31 @@ """Agent generator package - Creates agents from natural language.""" from .core import ( - apply_agent_patch, + AgentGeneratorNotConfiguredError, decompose_goal, generate_agent, generate_agent_patch, get_agent_as_json, + json_to_graph, save_agent_to_library, ) -from .fixer import apply_all_fixes -from .utils import get_blocks_info -from .validator import validate_agent +from .errors import get_user_message_for_error +from .service import health_check as check_external_service_health +from .service import is_external_service_configured __all__ = [ # Core functions "decompose_goal", "generate_agent", "generate_agent_patch", - "apply_agent_patch", "save_agent_to_library", "get_agent_as_json", - # Fixer - "apply_all_fixes", - # Validator - "validate_agent", - # Utils - "get_blocks_info", + "json_to_graph", + # Exceptions + "AgentGeneratorNotConfiguredError", + # Service + "is_external_service_configured", + "check_external_service_health", + # Error handling + "get_user_message_for_error", ] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/client.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/client.py deleted file mode 100644 index 4450fa9d75..0000000000 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/client.py +++ /dev/null @@ -1,25 +0,0 @@ -"""OpenRouter client configuration for agent generation.""" - -import os - -from openai import AsyncOpenAI - -# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py -OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY") -AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5") - -# OpenRouter client (OpenAI-compatible API) -_client: AsyncOpenAI | None = None - - -def get_client() -> AsyncOpenAI: - """Get or create the OpenRouter client.""" - global _client - if _client is None: - if not OPENROUTER_API_KEY: - raise ValueError("OPENROUTER_API_KEY environment variable is required") - _client = AsyncOpenAI( - base_url="https://openrouter.ai/api/v1", - api_key=OPENROUTER_API_KEY, - ) - return _client diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py index 0f94135a41..d56e33cbb0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py @@ -1,7 +1,5 @@ """Core agent generation functions.""" -import copy -import json import logging import uuid from typing import Any @@ -9,13 +7,35 @@ from typing import Any from backend.api.features.library import db as library_db from backend.data.graph import Graph, Link, Node, create_graph -from .client import AGENT_GENERATOR_MODEL, get_client -from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT -from .utils import get_block_summaries, parse_json_from_llm +from .service import ( + decompose_goal_external, + generate_agent_external, + generate_agent_patch_external, + is_external_service_configured, +) logger = logging.getLogger(__name__) +class AgentGeneratorNotConfiguredError(Exception): + """Raised when the external Agent Generator service is not configured.""" + + pass + + +def _check_service_configured() -> None: + """Check if the external Agent Generator service is configured. + + Raises: + AgentGeneratorNotConfiguredError: If the service is not configured. + """ + if not is_external_service_configured(): + raise AgentGeneratorNotConfiguredError( + "Agent Generator service is not configured. " + "Set AGENTGENERATOR_HOST environment variable to enable agent generation." + ) + + async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None: """Break down a goal into steps or return clarifying questions. @@ -28,40 +48,13 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any] - {"type": "clarifying_questions", "questions": [...]} - {"type": "instructions", "steps": [...]} Or None on error + + Raises: + AgentGeneratorNotConfiguredError: If the external service is not configured. """ - client = get_client() - prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries()) - - full_description = description - if context: - full_description = f"{description}\n\nAdditional context:\n{context}" - - try: - response = await client.chat.completions.create( - model=AGENT_GENERATOR_MODEL, - messages=[ - {"role": "system", "content": prompt}, - {"role": "user", "content": full_description}, - ], - temperature=0, - ) - - content = response.choices[0].message.content - if content is None: - logger.error("LLM returned empty content for decomposition") - return None - - result = parse_json_from_llm(content) - - if result is None: - logger.error(f"Failed to parse decomposition response: {content[:200]}") - return None - - return result - - except Exception as e: - logger.error(f"Error decomposing goal: {e}") - return None + _check_service_configured() + logger.info("Calling external Agent Generator service for decompose_goal") + return await decompose_goal_external(description, context) async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None: @@ -71,45 +64,26 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None: instructions: Structured instructions from decompose_goal Returns: - Agent JSON dict or None on error + Agent JSON dict, error dict {"type": "error", ...}, or None on error + + Raises: + AgentGeneratorNotConfiguredError: If the external service is not configured. """ - client = get_client() - prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries()) - - try: - response = await client.chat.completions.create( - model=AGENT_GENERATOR_MODEL, - messages=[ - {"role": "system", "content": prompt}, - {"role": "user", "content": json.dumps(instructions, indent=2)}, - ], - temperature=0, - ) - - content = response.choices[0].message.content - if content is None: - logger.error("LLM returned empty content for agent generation") - return None - - result = parse_json_from_llm(content) - - if result is None: - logger.error(f"Failed to parse agent JSON: {content[:200]}") - return None - - # Ensure required fields + _check_service_configured() + logger.info("Calling external Agent Generator service for generate_agent") + result = await generate_agent_external(instructions) + if result: + # Check if it's an error response - pass through as-is + if isinstance(result, dict) and result.get("type") == "error": + return result + # Ensure required fields for successful agent generation if "id" not in result: result["id"] = str(uuid.uuid4()) if "version" not in result: result["version"] = 1 if "is_active" not in result: result["is_active"] = True - - return result - - except Exception as e: - logger.error(f"Error generating agent: {e}") - return None + return result def json_to_graph(agent_json: dict[str, Any]) -> Graph: @@ -284,108 +258,24 @@ async def get_agent_as_json( async def generate_agent_patch( update_request: str, current_agent: dict[str, Any] ) -> dict[str, Any] | None: - """Generate a patch to update an existing agent. + """Update an existing agent using natural language. + + The external Agent Generator service handles: + - Generating the patch + - Applying the patch + - Fixing and validating the result Args: update_request: Natural language description of changes current_agent: Current agent JSON Returns: - Patch dict or clarifying questions, or None on error + Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, + error dict {"type": "error", ...}, or None on unexpected error + + Raises: + AgentGeneratorNotConfiguredError: If the external service is not configured. """ - client = get_client() - prompt = PATCH_PROMPT.format( - current_agent=json.dumps(current_agent, indent=2), - block_summaries=get_block_summaries(), - ) - - try: - response = await client.chat.completions.create( - model=AGENT_GENERATOR_MODEL, - messages=[ - {"role": "system", "content": prompt}, - {"role": "user", "content": update_request}, - ], - temperature=0, - ) - - content = response.choices[0].message.content - if content is None: - logger.error("LLM returned empty content for patch generation") - return None - - return parse_json_from_llm(content) - - except Exception as e: - logger.error(f"Error generating patch: {e}") - return None - - -def apply_agent_patch( - current_agent: dict[str, Any], patch: dict[str, Any] -) -> dict[str, Any]: - """Apply a patch to an existing agent. - - Args: - current_agent: Current agent JSON - patch: Patch dict with operations - - Returns: - Updated agent JSON - """ - agent = copy.deepcopy(current_agent) - patches = patch.get("patches", []) - - for p in patches: - patch_type = p.get("type") - - if patch_type == "modify": - node_id = p.get("node_id") - changes = p.get("changes", {}) - - for node in agent.get("nodes", []): - if node["id"] == node_id: - _deep_update(node, changes) - logger.debug(f"Modified node {node_id}") - break - - elif patch_type == "add": - new_nodes = p.get("new_nodes", []) - new_links = p.get("new_links", []) - - agent["nodes"] = agent.get("nodes", []) + new_nodes - agent["links"] = agent.get("links", []) + new_links - logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links") - - elif patch_type == "remove": - node_ids_to_remove = set(p.get("node_ids", [])) - link_ids_to_remove = set(p.get("link_ids", [])) - - # Remove nodes - agent["nodes"] = [ - n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove - ] - - # Remove links (both explicit and those referencing removed nodes) - agent["links"] = [ - link - for link in agent.get("links", []) - if link["id"] not in link_ids_to_remove - and link["source_id"] not in node_ids_to_remove - and link["sink_id"] not in node_ids_to_remove - ] - - logger.debug( - f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links" - ) - - return agent - - -def _deep_update(target: dict, source: dict) -> None: - """Recursively update a dict with another dict.""" - for key, value in source.items(): - if key in target and isinstance(target[key], dict) and isinstance(value, dict): - _deep_update(target[key], value) - else: - target[key] = value + _check_service_configured() + logger.info("Calling external Agent Generator service for generate_agent_patch") + return await generate_agent_patch_external(update_request, current_agent) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py new file mode 100644 index 0000000000..bf71a95df9 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py @@ -0,0 +1,43 @@ +"""Error handling utilities for agent generator.""" + + +def get_user_message_for_error( + error_type: str, + operation: str = "process the request", + llm_parse_message: str | None = None, + validation_message: str | None = None, +) -> str: + """Get a user-friendly error message based on error type. + + This function maps internal error types to user-friendly messages, + providing a consistent experience across different agent operations. + + Args: + error_type: The error type from the external service + (e.g., "llm_parse_error", "timeout", "rate_limit") + operation: Description of what operation failed, used in the default + message (e.g., "analyze the goal", "generate the agent") + llm_parse_message: Custom message for llm_parse_error type + validation_message: Custom message for validation_error type + + Returns: + User-friendly error message suitable for display to the user + """ + if error_type == "llm_parse_error": + return ( + llm_parse_message + or "The AI had trouble processing this request. Please try again." + ) + elif error_type == "validation_error": + return ( + validation_message + or "The request failed validation. Please try rephrasing." + ) + elif error_type == "patch_error": + return "Failed to apply the changes. Please try a different approach." + elif error_type in ("timeout", "llm_timeout"): + return "The request took too long. Please try again." + elif error_type in ("rate_limit", "llm_rate_limit"): + return "The service is currently busy. Please try again in a moment." + else: + return f"Failed to {operation}. Please try again." diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/fixer.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/fixer.py deleted file mode 100644 index 1e25e0cbed..0000000000 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/fixer.py +++ /dev/null @@ -1,606 +0,0 @@ -"""Agent fixer - Fixes common LLM generation errors.""" - -import logging -import re -import uuid -from typing import Any - -from .utils import ( - ADDTODICTIONARY_BLOCK_ID, - ADDTOLIST_BLOCK_ID, - CODE_EXECUTION_BLOCK_ID, - CONDITION_BLOCK_ID, - CREATEDICT_BLOCK_ID, - CREATELIST_BLOCK_ID, - DATA_SAMPLING_BLOCK_ID, - DOUBLE_CURLY_BRACES_BLOCK_IDS, - GET_CURRENT_DATE_BLOCK_ID, - STORE_VALUE_BLOCK_ID, - UNIVERSAL_TYPE_CONVERTER_BLOCK_ID, - get_blocks_info, - is_valid_uuid, -) - -logger = logging.getLogger(__name__) - - -def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]: - """Fix invalid UUIDs in agent and link IDs.""" - # Fix agent ID - if not is_valid_uuid(agent.get("id", "")): - agent["id"] = str(uuid.uuid4()) - logger.debug(f"Fixed agent ID: {agent['id']}") - - # Fix node IDs - id_mapping = {} # Old ID -> New ID - for node in agent.get("nodes", []): - if not is_valid_uuid(node.get("id", "")): - old_id = node.get("id", "") - new_id = str(uuid.uuid4()) - id_mapping[old_id] = new_id - node["id"] = new_id - logger.debug(f"Fixed node ID: {old_id} -> {new_id}") - - # Fix link IDs and update references - for link in agent.get("links", []): - if not is_valid_uuid(link.get("id", "")): - link["id"] = str(uuid.uuid4()) - logger.debug(f"Fixed link ID: {link['id']}") - - # Update source/sink IDs if they were remapped - if link.get("source_id") in id_mapping: - link["source_id"] = id_mapping[link["source_id"]] - if link.get("sink_id") in id_mapping: - link["sink_id"] = id_mapping[link["sink_id"]] - - return agent - - -def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]: - """Fix single curly braces to double in template blocks.""" - for node in agent.get("nodes", []): - if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS: - continue - - input_data = node.get("input_default", {}) - for key in ("prompt", "format"): - if key in input_data and isinstance(input_data[key], str): - original = input_data[key] - # Fix simple variable references: {var} -> {{var}} - fixed = re.sub( - r"(? dict[str, Any]: - """Add StoreValueBlock before ConditionBlock if needed for value2.""" - nodes = agent.get("nodes", []) - links = agent.get("links", []) - - # Find all ConditionBlock nodes - condition_node_ids = { - node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID - } - - if not condition_node_ids: - return agent - - new_nodes = [] - new_links = [] - processed_conditions = set() - - for link in links: - sink_id = link.get("sink_id") - sink_name = link.get("sink_name") - - # Check if this link goes to a ConditionBlock's value2 - if sink_id in condition_node_ids and sink_name == "value2": - source_node = next( - (n for n in nodes if n["id"] == link.get("source_id")), None - ) - - # Skip if source is already a StoreValueBlock - if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID: - continue - - # Skip if we already processed this condition - if sink_id in processed_conditions: - continue - - processed_conditions.add(sink_id) - - # Create StoreValueBlock - store_node_id = str(uuid.uuid4()) - store_node = { - "id": store_node_id, - "block_id": STORE_VALUE_BLOCK_ID, - "input_default": {"data": None}, - "metadata": {"position": {"x": 0, "y": -100}}, - } - new_nodes.append(store_node) - - # Create link: original source -> StoreValueBlock - new_links.append( - { - "id": str(uuid.uuid4()), - "source_id": link["source_id"], - "source_name": link["source_name"], - "sink_id": store_node_id, - "sink_name": "input", - "is_static": False, - } - ) - - # Update original link: StoreValueBlock -> ConditionBlock - link["source_id"] = store_node_id - link["source_name"] = "output" - - logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}") - - if new_nodes: - agent["nodes"] = nodes + new_nodes - - return agent - - -def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]: - """Fix AddToList blocks by adding prerequisite empty AddToList block. - - When an AddToList block is found: - 1. Checks if there's a CreateListBlock before it - 2. Removes CreateListBlock if linked directly to AddToList - 3. Adds an empty AddToList block before the original - 4. Ensures the original has a self-referencing link - """ - nodes = agent.get("nodes", []) - links = agent.get("links", []) - new_nodes = [] - original_addtolist_ids = set() - nodes_to_remove = set() - links_to_remove = [] - - # First pass: identify CreateListBlock nodes to remove - for link in links: - source_node = next( - (n for n in nodes if n.get("id") == link.get("source_id")), None - ) - sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None) - - if ( - source_node - and sink_node - and source_node.get("block_id") == CREATELIST_BLOCK_ID - and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID - ): - nodes_to_remove.add(source_node.get("id")) - links_to_remove.append(link) - logger.debug(f"Removing CreateListBlock {source_node.get('id')}") - - # Second pass: process AddToList blocks - filtered_nodes = [] - for node in nodes: - if node.get("id") in nodes_to_remove: - continue - - if node.get("block_id") == ADDTOLIST_BLOCK_ID: - original_addtolist_ids.add(node.get("id")) - node_id = node.get("id") - pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0}) - - # Check if already has prerequisite - has_prereq = any( - link.get("sink_id") == node_id - and link.get("sink_name") == "list" - and link.get("source_name") == "updated_list" - for link in links - ) - - if not has_prereq: - # Remove links to "list" input (except self-reference) - for link in links: - if ( - link.get("sink_id") == node_id - and link.get("sink_name") == "list" - and link.get("source_id") != node_id - and link not in links_to_remove - ): - links_to_remove.append(link) - - # Create prerequisite AddToList block - prereq_id = str(uuid.uuid4()) - prereq_node = { - "id": prereq_id, - "block_id": ADDTOLIST_BLOCK_ID, - "input_default": {"list": [], "entry": None, "entries": []}, - "metadata": { - "position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)} - }, - } - new_nodes.append(prereq_node) - - # Link prerequisite to original - links.append( - { - "id": str(uuid.uuid4()), - "source_id": prereq_id, - "source_name": "updated_list", - "sink_id": node_id, - "sink_name": "list", - "is_static": False, - } - ) - logger.debug(f"Added prerequisite AddToList block for {node_id}") - - filtered_nodes.append(node) - - # Remove marked links - filtered_links = [link for link in links if link not in links_to_remove] - - # Add self-referencing links for original AddToList blocks - for node in filtered_nodes + new_nodes: - if ( - node.get("block_id") == ADDTOLIST_BLOCK_ID - and node.get("id") in original_addtolist_ids - ): - node_id = node.get("id") - has_self_ref = any( - link["source_id"] == node_id - and link["sink_id"] == node_id - and link["source_name"] == "updated_list" - and link["sink_name"] == "list" - for link in filtered_links - ) - if not has_self_ref: - filtered_links.append( - { - "id": str(uuid.uuid4()), - "source_id": node_id, - "source_name": "updated_list", - "sink_id": node_id, - "sink_name": "list", - "is_static": False, - } - ) - logger.debug(f"Added self-reference for AddToList {node_id}") - - agent["nodes"] = filtered_nodes + new_nodes - agent["links"] = filtered_links - return agent - - -def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]: - """Fix AddToDictionary blocks by removing empty CreateDictionary nodes.""" - nodes = agent.get("nodes", []) - links = agent.get("links", []) - nodes_to_remove = set() - links_to_remove = [] - - for link in links: - source_node = next( - (n for n in nodes if n.get("id") == link.get("source_id")), None - ) - sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None) - - if ( - source_node - and sink_node - and source_node.get("block_id") == CREATEDICT_BLOCK_ID - and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID - ): - nodes_to_remove.add(source_node.get("id")) - links_to_remove.append(link) - logger.debug(f"Removing CreateDictionary {source_node.get('id')}") - - agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove] - agent["links"] = [link for link in links if link not in links_to_remove] - return agent - - -def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]: - """Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'.""" - nodes = agent.get("nodes", []) - links = agent.get("links", []) - - for link in links: - source_node = next( - (n for n in nodes if n.get("id") == link.get("source_id")), None - ) - if ( - source_node - and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID - and link.get("source_name") == "response" - ): - link["source_name"] = "stdout_logs" - logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs") - - return agent - - -def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]: - """Fix DataSamplingBlock by setting sample_size to 1 as default.""" - nodes = agent.get("nodes", []) - links = agent.get("links", []) - links_to_remove = [] - - for node in nodes: - if node.get("block_id") == DATA_SAMPLING_BLOCK_ID: - node_id = node.get("id") - input_default = node.get("input_default", {}) - - # Remove links to sample_size - for link in links: - if ( - link.get("sink_id") == node_id - and link.get("sink_name") == "sample_size" - ): - links_to_remove.append(link) - - # Set default - input_default["sample_size"] = 1 - node["input_default"] = input_default - logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1") - - if links_to_remove: - agent["links"] = [link for link in links if link not in links_to_remove] - - return agent - - -def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]: - """Fix node x-coordinates to ensure 800+ unit spacing between linked nodes.""" - nodes = agent.get("nodes", []) - links = agent.get("links", []) - node_lookup = {n.get("id"): n for n in nodes} - - for link in links: - source_id = link.get("source_id") - sink_id = link.get("sink_id") - - source_node = node_lookup.get(source_id) - sink_node = node_lookup.get(sink_id) - - if not source_node or not sink_node: - continue - - source_pos = source_node.get("metadata", {}).get("position", {}) - sink_pos = sink_node.get("metadata", {}).get("position", {}) - - source_x = source_pos.get("x", 0) - sink_x = sink_pos.get("x", 0) - - if abs(sink_x - source_x) < 800: - new_x = source_x + 800 - if "metadata" not in sink_node: - sink_node["metadata"] = {} - if "position" not in sink_node["metadata"]: - sink_node["metadata"]["position"] = {} - sink_node["metadata"]["position"]["x"] = new_x - logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}") - - return agent - - -def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]: - """Fix GetCurrentDateBlock offset to ensure it's positive.""" - for node in agent.get("nodes", []): - if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID: - input_default = node.get("input_default", {}) - if "offset" in input_default: - offset = input_default["offset"] - if isinstance(offset, (int, float)) and offset < 0: - input_default["offset"] = abs(offset) - logger.debug(f"Fixed offset: {offset} -> {abs(offset)}") - - return agent - - -def fix_ai_model_parameter( - agent: dict[str, Any], - blocks_info: list[dict[str, Any]], - default_model: str = "gpt-4o", -) -> dict[str, Any]: - """Add default model parameter to AI blocks if missing.""" - block_map = {b.get("id"): b for b in blocks_info} - - for node in agent.get("nodes", []): - block_id = node.get("block_id") - block = block_map.get(block_id) - - if not block: - continue - - # Check if block has AI category - categories = block.get("categories", []) - is_ai_block = any( - cat.get("category") == "AI" for cat in categories if isinstance(cat, dict) - ) - - if is_ai_block: - input_default = node.get("input_default", {}) - if "model" not in input_default: - input_default["model"] = default_model - node["input_default"] = input_default - logger.debug( - f"Added model '{default_model}' to AI block {node.get('id')}" - ) - - return agent - - -def fix_link_static_properties( - agent: dict[str, Any], blocks_info: list[dict[str, Any]] -) -> dict[str, Any]: - """Fix is_static property based on source block's staticOutput.""" - block_map = {b.get("id"): b for b in blocks_info} - node_lookup = {n.get("id"): n for n in agent.get("nodes", [])} - - for link in agent.get("links", []): - source_node = node_lookup.get(link.get("source_id")) - if not source_node: - continue - - source_block = block_map.get(source_node.get("block_id")) - if not source_block: - continue - - static_output = source_block.get("staticOutput", False) - if link.get("is_static") != static_output: - link["is_static"] = static_output - logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}") - - return agent - - -def fix_data_type_mismatch( - agent: dict[str, Any], blocks_info: list[dict[str, Any]] -) -> dict[str, Any]: - """Fix data type mismatches by inserting UniversalTypeConverterBlock.""" - nodes = agent.get("nodes", []) - links = agent.get("links", []) - block_map = {b.get("id"): b for b in blocks_info} - node_lookup = {n.get("id"): n for n in nodes} - - def get_property_type(schema: dict, name: str) -> str | None: - if "_#_" in name: - parent, child = name.split("_#_", 1) - parent_schema = schema.get(parent, {}) - if "properties" in parent_schema: - return parent_schema["properties"].get(child, {}).get("type") - return None - return schema.get(name, {}).get("type") - - def are_types_compatible(src: str, sink: str) -> bool: - if {src, sink} <= {"integer", "number"}: - return True - return src == sink - - type_mapping = { - "string": "string", - "text": "string", - "integer": "number", - "number": "number", - "float": "number", - "boolean": "boolean", - "bool": "boolean", - "array": "list", - "list": "list", - "object": "dictionary", - "dict": "dictionary", - "dictionary": "dictionary", - } - - new_links = [] - nodes_to_add = [] - - for link in links: - source_node = node_lookup.get(link.get("source_id")) - sink_node = node_lookup.get(link.get("sink_id")) - - if not source_node or not sink_node: - new_links.append(link) - continue - - source_block = block_map.get(source_node.get("block_id")) - sink_block = block_map.get(sink_node.get("block_id")) - - if not source_block or not sink_block: - new_links.append(link) - continue - - source_outputs = source_block.get("outputSchema", {}).get("properties", {}) - sink_inputs = sink_block.get("inputSchema", {}).get("properties", {}) - - source_type = get_property_type(source_outputs, link.get("source_name", "")) - sink_type = get_property_type(sink_inputs, link.get("sink_name", "")) - - if ( - source_type - and sink_type - and not are_types_compatible(source_type, sink_type) - ): - # Insert type converter - converter_id = str(uuid.uuid4()) - target_type = type_mapping.get(sink_type, sink_type) - - converter_node = { - "id": converter_id, - "block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID, - "input_default": {"type": target_type}, - "metadata": {"position": {"x": 0, "y": 100}}, - } - nodes_to_add.append(converter_node) - - # source -> converter - new_links.append( - { - "id": str(uuid.uuid4()), - "source_id": link["source_id"], - "source_name": link["source_name"], - "sink_id": converter_id, - "sink_name": "value", - "is_static": False, - } - ) - - # converter -> sink - new_links.append( - { - "id": str(uuid.uuid4()), - "source_id": converter_id, - "source_name": "value", - "sink_id": link["sink_id"], - "sink_name": link["sink_name"], - "is_static": False, - } - ) - - logger.debug(f"Inserted type converter: {source_type} -> {target_type}") - else: - new_links.append(link) - - if nodes_to_add: - agent["nodes"] = nodes + nodes_to_add - agent["links"] = new_links - - return agent - - -def apply_all_fixes( - agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None -) -> dict[str, Any]: - """Apply all fixes to an agent JSON. - - Args: - agent: Agent JSON dict - blocks_info: Optional list of block info dicts for advanced fixes - - Returns: - Fixed agent JSON - """ - # Basic fixes (no block info needed) - agent = fix_agent_ids(agent) - agent = fix_double_curly_braces(agent) - agent = fix_storevalue_before_condition(agent) - agent = fix_addtolist_blocks(agent) - agent = fix_addtodictionary_blocks(agent) - agent = fix_code_execution_output(agent) - agent = fix_data_sampling_sample_size(agent) - agent = fix_node_x_coordinates(agent) - agent = fix_getcurrentdate_offset(agent) - - # Advanced fixes (require block info) - if blocks_info is None: - blocks_info = get_blocks_info() - - agent = fix_ai_model_parameter(agent, blocks_info) - agent = fix_link_static_properties(agent, blocks_info) - agent = fix_data_type_mismatch(agent, blocks_info) - - return agent diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/prompts.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/prompts.py deleted file mode 100644 index 228bba8c8a..0000000000 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/prompts.py +++ /dev/null @@ -1,225 +0,0 @@ -"""Prompt templates for agent generation.""" - -DECOMPOSITION_PROMPT = """ -You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks. - -Each step should represent a distinct, automatable action suitable for execution by an AI automation system. - ---- - -FIRST: Analyze the user's goal and determine: -1) Design-time configuration (fixed settings that won't change per run) -2) Runtime inputs (values the agent's end-user will provide each time it runs) - -For anything that can vary per run (email addresses, names, dates, search terms, etc.): -- DO NOT ask for the actual value -- Instead, define it as an Agent Input with a clear name, type, and description - -Only ask clarifying questions about design-time config that affects how you build the workflow: -- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs") -- Required formats or structures (e.g., "CSV, JSON, or PDF output?") -- Business rules that must be hard-coded - -IMPORTANT CLARIFICATIONS POLICY: -- Ask no more than five essential questions -- Do not ask for concrete values that can be provided at runtime as Agent Inputs -- Do not ask for API keys or credentials; the platform handles those directly -- If there is enough information to infer reasonable defaults, prefer to propose defaults - ---- - -GUIDELINES: -1. List each step as a numbered item -2. Describe the action clearly and specify inputs/outputs -3. Ensure steps are in logical, sequential order -4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...") -5. Help the user reach their goal efficiently - ---- - -RULES: -1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both -2. USE ONLY THE BLOCKS PROVIDED -3. ALL required_input fields must be provided -4. Data types of linked properties must match -5. Write expert-level prompts for AI-related blocks - ---- - -CRITICAL BLOCK RESTRICTIONS: -1. AddToListBlock: Outputs updated list EVERY addition, not after all additions -2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type -3. ConditionBlock: value2 is reference, value1 is contrast -4. CodeExecutionBlock: DO NOT USE - use AI blocks instead -5. ReadCsvBlock: Only use the 'rows' output, not 'row' - ---- - -OUTPUT FORMAT: - -If more information is needed: -```json -{{ - "type": "clarifying_questions", - "questions": [ - {{ - "question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)", - "keyword": "email_provider", - "example": "Gmail" - }} - ] -}} -``` - -If ready to proceed: -```json -{{ - "type": "instructions", - "steps": [ - {{ - "step_number": 1, - "block_name": "AgentShortTextInputBlock", - "description": "Get the URL of the content to analyze.", - "inputs": [{{"name": "name", "value": "URL"}}], - "outputs": [{{"name": "result", "description": "The URL entered by user"}}] - }} - ] -}} -``` - ---- - -AVAILABLE BLOCKS: -{block_summaries} -""" - -GENERATION_PROMPT = """ -You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions. - ---- - -NODES: -Each node must include: -- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`) -- `block_id`: The block identifier (must match an Allowed Block) -- `input_default`: Dict of inputs (can be empty if no static inputs needed) -- `metadata`: Must contain: - - `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X - - `customized_name`: Clear name describing this block's purpose in the workflow - ---- - -LINKS: -Each link connects a source node's output to a sink node's input: -- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.) -- `source_id`: ID of the source node -- `source_name`: Output field name from the source block -- `sink_id`: ID of the sink node -- `sink_name`: Input field name on the sink block -- `is_static`: true only if source block has static_output: true - -CRITICAL: All IDs must be valid UUID v4 format! - ---- - -AGENT (GRAPH): -Wrap nodes and links in: -- `id`: UUID of the agent -- `name`: Short, generic name (avoid specific company names, URLs) -- `description`: Short, generic description -- `nodes`: List of all nodes -- `links`: List of all links -- `version`: 1 -- `is_active`: true - ---- - -TIPS: -- All required_input fields must be provided via input_default or a valid link -- Ensure consistent source_id and sink_id references -- Avoid dangling links -- Input/output pins must match block schemas -- Do not invent unknown block_ids - ---- - -ALLOWED BLOCKS: -{block_summaries} - ---- - -Generate the complete agent JSON. Output ONLY valid JSON, no explanation. -""" - -PATCH_PROMPT = """ -You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent. - -CURRENT AGENT: -{current_agent} - -AVAILABLE BLOCKS: -{block_summaries} - ---- - -PATCH FORMAT: -Return a JSON object with the following structure: - -```json -{{ - "type": "patch", - "intent": "Brief description of what the patch does", - "patches": [ - {{ - "type": "modify", - "node_id": "uuid-of-node-to-modify", - "changes": {{ - "input_default": {{"field": "new_value"}}, - "metadata": {{"customized_name": "New Name"}} - }} - }}, - {{ - "type": "add", - "new_nodes": [ - {{ - "id": "new-uuid", - "block_id": "block-uuid", - "input_default": {{}}, - "metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}} - }} - ], - "new_links": [ - {{ - "id": "link-uuid", - "source_id": "source-node-id", - "source_name": "output_field", - "sink_id": "sink-node-id", - "sink_name": "input_field" - }} - ] - }}, - {{ - "type": "remove", - "node_ids": ["uuid-of-node-to-remove"], - "link_ids": ["uuid-of-link-to-remove"] - }} - ] -}} -``` - -If you need more information, return: -```json -{{ - "type": "clarifying_questions", - "questions": [ - {{ - "question": "What specific change do you want?", - "keyword": "change_type", - "example": "Add error handling" - }} - ] -}} -``` - -Generate the minimal patch needed. Output ONLY valid JSON. -""" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py new file mode 100644 index 0000000000..1df1faaaef --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py @@ -0,0 +1,374 @@ +"""External Agent Generator service client. + +This module provides a client for communicating with the external Agent Generator +microservice. When AGENTGENERATOR_HOST is configured, the agent generation functions +will delegate to the external service instead of using the built-in LLM-based implementation. +""" + +import logging +from typing import Any + +import httpx + +from backend.util.settings import Settings + +logger = logging.getLogger(__name__) + + +def _create_error_response( + error_message: str, + error_type: str = "unknown", + details: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Create a standardized error response dict. + + Args: + error_message: Human-readable error message + error_type: Machine-readable error type + details: Optional additional error details + + Returns: + Error dict with type="error" and error details + """ + response: dict[str, Any] = { + "type": "error", + "error": error_message, + "error_type": error_type, + } + if details: + response["details"] = details + return response + + +def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]: + """Classify an HTTP error into error_type and message. + + Args: + e: The HTTP status error + + Returns: + Tuple of (error_type, error_message) + """ + status = e.response.status_code + if status == 429: + return "rate_limit", f"Agent Generator rate limited: {e}" + elif status == 503: + return "service_unavailable", f"Agent Generator unavailable: {e}" + elif status == 504 or status == 408: + return "timeout", f"Agent Generator timed out: {e}" + else: + return "http_error", f"HTTP error calling Agent Generator: {e}" + + +def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]: + """Classify a request error into error_type and message. + + Args: + e: The request error + + Returns: + Tuple of (error_type, error_message) + """ + error_str = str(e).lower() + if "timeout" in error_str or "timed out" in error_str: + return "timeout", f"Agent Generator request timed out: {e}" + elif "connect" in error_str: + return "connection_error", f"Could not connect to Agent Generator: {e}" + else: + return "request_error", f"Request error calling Agent Generator: {e}" + + +_client: httpx.AsyncClient | None = None +_settings: Settings | None = None + + +def _get_settings() -> Settings: + """Get or create settings singleton.""" + global _settings + if _settings is None: + _settings = Settings() + return _settings + + +def is_external_service_configured() -> bool: + """Check if external Agent Generator service is configured.""" + settings = _get_settings() + return bool(settings.config.agentgenerator_host) + + +def _get_base_url() -> str: + """Get the base URL for the external service.""" + settings = _get_settings() + host = settings.config.agentgenerator_host + port = settings.config.agentgenerator_port + return f"http://{host}:{port}" + + +def _get_client() -> httpx.AsyncClient: + """Get or create the HTTP client for the external service.""" + global _client + if _client is None: + settings = _get_settings() + _client = httpx.AsyncClient( + base_url=_get_base_url(), + timeout=httpx.Timeout(settings.config.agentgenerator_timeout), + ) + return _client + + +async def decompose_goal_external( + description: str, context: str = "" +) -> dict[str, Any] | None: + """Call the external service to decompose a goal. + + Args: + description: Natural language goal description + context: Additional context (e.g., answers to previous questions) + + Returns: + Dict with either: + - {"type": "clarifying_questions", "questions": [...]} + - {"type": "instructions", "steps": [...]} + - {"type": "unachievable_goal", ...} + - {"type": "vague_goal", ...} + - {"type": "error", "error": "...", "error_type": "..."} on error + Or None on unexpected error + """ + client = _get_client() + + # Build the request payload + payload: dict[str, Any] = {"description": description} + if context: + # The external service uses user_instruction for additional context + payload["user_instruction"] = context + + try: + response = await client.post("/api/decompose-description", json=payload) + response.raise_for_status() + data = response.json() + + if not data.get("success"): + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator decomposition failed: {error_msg} " + f"(type: {error_type})" + ) + return _create_error_response(error_msg, error_type) + + # Map the response to the expected format + response_type = data.get("type") + if response_type == "instructions": + return {"type": "instructions", "steps": data.get("steps", [])} + elif response_type == "clarifying_questions": + return { + "type": "clarifying_questions", + "questions": data.get("questions", []), + } + elif response_type == "unachievable_goal": + return { + "type": "unachievable_goal", + "reason": data.get("reason"), + "suggested_goal": data.get("suggested_goal"), + } + elif response_type == "vague_goal": + return { + "type": "vague_goal", + "suggested_goal": data.get("suggested_goal"), + } + elif response_type == "error": + # Pass through error from the service + return _create_error_response( + data.get("error", "Unknown error"), + data.get("error_type", "unknown"), + ) + else: + logger.error( + f"Unknown response type from external service: {response_type}" + ) + return _create_error_response( + f"Unknown response type from Agent Generator: {response_type}", + "invalid_response", + ) + + except httpx.HTTPStatusError as e: + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except httpx.RequestError as e: + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except Exception as e: + error_msg = f"Unexpected error calling Agent Generator: {e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") + + +async def generate_agent_external( + instructions: dict[str, Any], +) -> dict[str, Any] | None: + """Call the external service to generate an agent from instructions. + + Args: + instructions: Structured instructions from decompose_goal + + Returns: + Agent JSON dict on success, or error dict {"type": "error", ...} on error + """ + client = _get_client() + + try: + response = await client.post( + "/api/generate-agent", json={"instructions": instructions} + ) + response.raise_for_status() + data = response.json() + + if not data.get("success"): + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator generation failed: {error_msg} " + f"(type: {error_type})" + ) + return _create_error_response(error_msg, error_type) + + return data.get("agent_json") + + except httpx.HTTPStatusError as e: + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except httpx.RequestError as e: + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except Exception as e: + error_msg = f"Unexpected error calling Agent Generator: {e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") + + +async def generate_agent_patch_external( + update_request: str, current_agent: dict[str, Any] +) -> dict[str, Any] | None: + """Call the external service to generate a patch for an existing agent. + + Args: + update_request: Natural language description of changes + current_agent: Current agent JSON + + Returns: + Updated agent JSON, clarifying questions dict, or error dict on error + """ + client = _get_client() + + try: + response = await client.post( + "/api/update-agent", + json={ + "update_request": update_request, + "current_agent_json": current_agent, + }, + ) + response.raise_for_status() + data = response.json() + + if not data.get("success"): + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator patch generation failed: {error_msg} " + f"(type: {error_type})" + ) + return _create_error_response(error_msg, error_type) + + # Check if it's clarifying questions + if data.get("type") == "clarifying_questions": + return { + "type": "clarifying_questions", + "questions": data.get("questions", []), + } + + # Check if it's an error passed through + if data.get("type") == "error": + return _create_error_response( + data.get("error", "Unknown error"), + data.get("error_type", "unknown"), + ) + + # Otherwise return the updated agent JSON + return data.get("agent_json") + + except httpx.HTTPStatusError as e: + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except httpx.RequestError as e: + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except Exception as e: + error_msg = f"Unexpected error calling Agent Generator: {e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") + + +async def get_blocks_external() -> list[dict[str, Any]] | None: + """Get available blocks from the external service. + + Returns: + List of block info dicts or None on error + """ + client = _get_client() + + try: + response = await client.get("/api/blocks") + response.raise_for_status() + data = response.json() + + if not data.get("success"): + logger.error("External service returned error getting blocks") + return None + + return data.get("blocks", []) + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error getting blocks from external service: {e}") + return None + except httpx.RequestError as e: + logger.error(f"Request error getting blocks from external service: {e}") + return None + except Exception as e: + logger.error(f"Unexpected error getting blocks from external service: {e}") + return None + + +async def health_check() -> bool: + """Check if the external service is healthy. + + Returns: + True if healthy, False otherwise + """ + if not is_external_service_configured(): + return False + + client = _get_client() + + try: + response = await client.get("/health") + response.raise_for_status() + data = response.json() + return data.get("status") == "healthy" and data.get("blocks_loaded", False) + except Exception as e: + logger.warning(f"External agent generator health check failed: {e}") + return False + + +async def close_client() -> None: + """Close the HTTP client.""" + global _client + if _client is not None: + await _client.aclose() + _client = None diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/utils.py deleted file mode 100644 index 9c3c866c7f..0000000000 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/utils.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Utilities for agent generation.""" - -import json -import re -from typing import Any - -from backend.data.block import get_blocks - -# UUID validation regex -UUID_REGEX = re.compile( - r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$" -) - -# Block IDs for various fixes -STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9" -CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6" -ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822" -ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1" -CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4" -CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91" -CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712" -DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87" -UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b" -GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1" - -DOUBLE_CURLY_BRACES_BLOCK_IDS = [ - "44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock - "6ab085e2-20b3-4055-bc3e-08036e01eca6", - "90f8c45e-e983-4644-aa0b-b4ebe2f531bc", - "363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock - "3b191d9f-356f-482d-8238-ba04b6d18381", - "db7d8f02-2f44-4c55-ab7a-eae0941f0c30", - "3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e", - "ed1ae7a0-b770-4089-b520-1f0005fad19a", - "a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa", - "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1", - "716a67b3-6760-42e7-86dc-18645c6e00fc", - "530cf046-2ce0-4854-ae2c-659db17c7a46", - "ed55ac19-356e-4243-a6cb-bc599e9b716f", - "1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks - "32a87eab-381e-4dd4-bdb8-4c47151be35a", -] - - -def is_valid_uuid(value: str) -> bool: - """Check if a string is a valid UUID v4.""" - return isinstance(value, str) and UUID_REGEX.match(value) is not None - - -def _compact_schema(schema: dict) -> dict[str, str]: - """Extract compact type info from a JSON schema properties dict. - - Returns a dict of {field_name: type_string} for essential info only. - """ - props = schema.get("properties", {}) - result = {} - - for name, prop in props.items(): - # Skip internal/complex fields - if name.startswith("_"): - continue - - # Get type string - type_str = prop.get("type", "any") - - # Handle anyOf/oneOf (optional types) - if "anyOf" in prop: - types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")] - type_str = "|".join(types) if types else "any" - elif "allOf" in prop: - type_str = "object" - - # Add array item type if present - if type_str == "array" and "items" in prop: - items = prop["items"] - if isinstance(items, dict): - item_type = items.get("type", "any") - type_str = f"array[{item_type}]" - - result[name] = type_str - - return result - - -def get_block_summaries(include_schemas: bool = True) -> str: - """Generate compact block summaries for prompts. - - Args: - include_schemas: Whether to include input/output type info - - Returns: - Formatted string of block summaries (compact format) - """ - blocks = get_blocks() - summaries = [] - - for block_id, block_cls in blocks.items(): - block = block_cls() - name = block.name - desc = getattr(block, "description", "") or "" - - # Truncate description - if len(desc) > 150: - desc = desc[:147] + "..." - - if not include_schemas: - summaries.append(f"- {name} (id: {block_id}): {desc}") - else: - # Compact format with type info only - inputs = {} - outputs = {} - required = [] - - if hasattr(block, "input_schema"): - try: - schema = block.input_schema.jsonschema() - inputs = _compact_schema(schema) - required = schema.get("required", []) - except Exception: - pass - - if hasattr(block, "output_schema"): - try: - schema = block.output_schema.jsonschema() - outputs = _compact_schema(schema) - except Exception: - pass - - # Build compact line format - # Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type} - in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items()) - out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items()) - req_str = f" req=[{','.join(required)}]" if required else "" - - static = " [static]" if getattr(block, "static_output", False) else "" - - line = f"- {name} (id: {block_id}): {desc}" - if in_str: - line += f"\n in: {{{in_str}}}{req_str}" - if out_str: - line += f"\n out: {{{out_str}}}{static}" - - summaries.append(line) - - return "\n".join(summaries) - - -def get_blocks_info() -> list[dict[str, Any]]: - """Get block information with schemas for validation and fixing.""" - blocks = get_blocks() - blocks_info = [] - for block_id, block_cls in blocks.items(): - block = block_cls() - blocks_info.append( - { - "id": block_id, - "name": block.name, - "description": getattr(block, "description", ""), - "categories": getattr(block, "categories", []), - "staticOutput": getattr(block, "static_output", False), - "inputSchema": ( - block.input_schema.jsonschema() - if hasattr(block, "input_schema") - else {} - ), - "outputSchema": ( - block.output_schema.jsonschema() - if hasattr(block, "output_schema") - else {} - ), - } - ) - return blocks_info - - -def parse_json_from_llm(text: str) -> dict[str, Any] | None: - """Extract JSON from LLM response (handles markdown code blocks).""" - if not text: - return None - - # Try fenced code block - match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE) - if match: - try: - return json.loads(match.group(1).strip()) - except json.JSONDecodeError: - pass - - # Try raw text - try: - return json.loads(text.strip()) - except json.JSONDecodeError: - pass - - # Try finding {...} span - start = text.find("{") - end = text.rfind("}") - if start != -1 and end > start: - try: - return json.loads(text[start : end + 1]) - except json.JSONDecodeError: - pass - - # Try finding [...] span - start = text.find("[") - end = text.rfind("]") - if start != -1 and end > start: - try: - return json.loads(text[start : end + 1]) - except json.JSONDecodeError: - pass - - return None diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/validator.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/validator.py deleted file mode 100644 index c913e92bfd..0000000000 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/validator.py +++ /dev/null @@ -1,279 +0,0 @@ -"""Agent validator - Validates agent structure and connections.""" - -import logging -import re -from typing import Any - -from .utils import get_blocks_info - -logger = logging.getLogger(__name__) - - -class AgentValidator: - """Validator for AutoGPT agents with detailed error reporting.""" - - def __init__(self): - self.errors: list[str] = [] - - def add_error(self, error: str) -> None: - """Add an error message.""" - self.errors.append(error) - - def validate_block_existence( - self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] - ) -> bool: - """Validate all block IDs exist in the blocks library.""" - valid = True - valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")} - - for node in agent.get("nodes", []): - block_id = node.get("block_id") - node_id = node.get("id") - - if not block_id: - self.add_error(f"Node '{node_id}' is missing 'block_id' field.") - valid = False - continue - - if block_id not in valid_block_ids: - self.add_error( - f"Node '{node_id}' references block_id '{block_id}' which does not exist." - ) - valid = False - - return valid - - def validate_link_node_references(self, agent: dict[str, Any]) -> bool: - """Validate all node IDs referenced in links exist.""" - valid = True - valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")} - - for link in agent.get("links", []): - link_id = link.get("id", "Unknown") - source_id = link.get("source_id") - sink_id = link.get("sink_id") - - if not source_id: - self.add_error(f"Link '{link_id}' is missing 'source_id'.") - valid = False - elif source_id not in valid_node_ids: - self.add_error( - f"Link '{link_id}' references non-existent source_id '{source_id}'." - ) - valid = False - - if not sink_id: - self.add_error(f"Link '{link_id}' is missing 'sink_id'.") - valid = False - elif sink_id not in valid_node_ids: - self.add_error( - f"Link '{link_id}' references non-existent sink_id '{sink_id}'." - ) - valid = False - - return valid - - def validate_required_inputs( - self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] - ) -> bool: - """Validate required inputs are provided.""" - valid = True - block_map = {b.get("id"): b for b in blocks_info} - - for node in agent.get("nodes", []): - block_id = node.get("block_id") - block = block_map.get(block_id) - - if not block: - continue - - required_inputs = block.get("inputSchema", {}).get("required", []) - input_defaults = node.get("input_default", {}) - node_id = node.get("id") - - # Get linked inputs - linked_inputs = { - link["sink_name"] - for link in agent.get("links", []) - if link.get("sink_id") == node_id - } - - for req_input in required_inputs: - if ( - req_input not in input_defaults - and req_input not in linked_inputs - and req_input != "credentials" - ): - block_name = block.get("name", "Unknown Block") - self.add_error( - f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'." - ) - valid = False - - return valid - - def validate_data_type_compatibility( - self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] - ) -> bool: - """Validate linked data types are compatible.""" - valid = True - block_map = {b.get("id"): b for b in blocks_info} - node_lookup = {n.get("id"): n for n in agent.get("nodes", [])} - - def get_type(schema: dict, name: str) -> str | None: - if "_#_" in name: - parent, child = name.split("_#_", 1) - parent_schema = schema.get(parent, {}) - if "properties" in parent_schema: - return parent_schema["properties"].get(child, {}).get("type") - return None - return schema.get(name, {}).get("type") - - def are_compatible(src: str, sink: str) -> bool: - if {src, sink} <= {"integer", "number"}: - return True - return src == sink - - for link in agent.get("links", []): - source_node = node_lookup.get(link.get("source_id")) - sink_node = node_lookup.get(link.get("sink_id")) - - if not source_node or not sink_node: - continue - - source_block = block_map.get(source_node.get("block_id")) - sink_block = block_map.get(sink_node.get("block_id")) - - if not source_block or not sink_block: - continue - - source_outputs = source_block.get("outputSchema", {}).get("properties", {}) - sink_inputs = sink_block.get("inputSchema", {}).get("properties", {}) - - source_type = get_type(source_outputs, link.get("source_name", "")) - sink_type = get_type(sink_inputs, link.get("sink_name", "")) - - if source_type and sink_type and not are_compatible(source_type, sink_type): - self.add_error( - f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' " - f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})." - ) - valid = False - - return valid - - def validate_nested_sink_links( - self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] - ) -> bool: - """Validate nested sink links (with _#_ notation).""" - valid = True - block_map = {b.get("id"): b for b in blocks_info} - node_lookup = {n.get("id"): n for n in agent.get("nodes", [])} - - for link in agent.get("links", []): - sink_name = link.get("sink_name", "") - - if "_#_" in sink_name: - parent, child = sink_name.split("_#_", 1) - - sink_node = node_lookup.get(link.get("sink_id")) - if not sink_node: - continue - - block = block_map.get(sink_node.get("block_id")) - if not block: - continue - - input_props = block.get("inputSchema", {}).get("properties", {}) - parent_schema = input_props.get(parent) - - if not parent_schema: - self.add_error( - f"Invalid nested link '{sink_name}': parent '{parent}' not found." - ) - valid = False - continue - - if not parent_schema.get("additionalProperties"): - if not ( - isinstance(parent_schema, dict) - and "properties" in parent_schema - and child in parent_schema.get("properties", {}) - ): - self.add_error( - f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'." - ) - valid = False - - return valid - - def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool: - """Validate prompts don't have spaces in template variables.""" - valid = True - - for node in agent.get("nodes", []): - input_default = node.get("input_default", {}) - prompt = input_default.get("prompt", "") - - if not isinstance(prompt, str): - continue - - # Find {{...}} with spaces - matches = re.finditer(r"\{\{([^}]+)\}\}", prompt) - for match in matches: - content = match.group(1) - if " " in content: - self.add_error( - f"Node '{node.get('id')}' has spaces in template variable: " - f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'." - ) - valid = False - - return valid - - def validate( - self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None - ) -> tuple[bool, str | None]: - """Run all validations. - - Returns: - Tuple of (is_valid, error_message) - """ - self.errors = [] - - if blocks_info is None: - blocks_info = get_blocks_info() - - checks = [ - self.validate_block_existence(agent, blocks_info), - self.validate_link_node_references(agent), - self.validate_required_inputs(agent, blocks_info), - self.validate_data_type_compatibility(agent, blocks_info), - self.validate_nested_sink_links(agent, blocks_info), - self.validate_prompt_spaces(agent), - ] - - all_passed = all(checks) - - if all_passed: - logger.info("Agent validation successful") - return True, None - - error_message = "Agent validation failed:\n" - for i, error in enumerate(self.errors, 1): - error_message += f"{i}. {error}\n" - - logger.warning(f"Agent validation failed with {len(self.errors)} errors") - return False, error_message - - -def validate_agent( - agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None -) -> tuple[bool, str | None]: - """Convenience function to validate an agent. - - Returns: - Tuple of (is_valid, error_message) - """ - validator = AgentValidator() - return validator.validate(agent, blocks_info) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py index 00c6d8499b..457e4a4f9b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py @@ -5,7 +5,6 @@ import re from datetime import datetime, timedelta, timezone from typing import Any -from langfuse import observe from pydantic import BaseModel, field_validator from backend.api.features.chat.model import ChatSession @@ -329,7 +328,6 @@ class AgentOutputTool(BaseTool): total_executions=len(available_executions) if available_executions else 1, ) - @observe(as_type="tool", name="view_agent_output") async def _execute( self, user_id: str | None, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/base.py b/autogpt_platform/backend/backend/api/features/chat/tools/base.py index 1dc40c18c7..809e06632b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/base.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/base.py @@ -36,6 +36,16 @@ class BaseTool: """Whether this tool requires authentication.""" return False + @property + def is_long_running(self) -> bool: + """Whether this tool is long-running and should execute in background. + + Long-running tools (like agent generation) are executed via background + tasks to survive SSE disconnections. The result is persisted to chat + history and visible when the user refreshes. + """ + return False + def as_openai_tool(self) -> ChatCompletionToolParam: """Convert to OpenAI tool format.""" return ChatCompletionToolParam( diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py index 26c980c6c5..74011c7e95 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py @@ -3,17 +3,14 @@ import logging from typing import Any -from langfuse import observe - from backend.api.features.chat.model import ChatSession from .agent_generator import ( - apply_all_fixes, + AgentGeneratorNotConfiguredError, decompose_goal, generate_agent, - get_blocks_info, + get_user_message_for_error, save_agent_to_library, - validate_agent, ) from .base import BaseTool from .models import ( @@ -27,9 +24,6 @@ from .models import ( logger = logging.getLogger(__name__) -# Maximum retries for agent generation with validation feedback -MAX_GENERATION_RETRIES = 2 - class CreateAgentTool(BaseTool): """Tool for creating agents from natural language descriptions.""" @@ -49,6 +43,10 @@ class CreateAgentTool(BaseTool): def requires_auth(self) -> bool: return True + @property + def is_long_running(self) -> bool: + return True + @property def parameters(self) -> dict[str, Any]: return { @@ -80,7 +78,6 @@ class CreateAgentTool(BaseTool): "required": ["description"], } - @observe(as_type="tool", name="create_agent") async def _execute( self, user_id: str | None, @@ -91,9 +88,8 @@ class CreateAgentTool(BaseTool): Flow: 1. Decompose the description into steps (may return clarifying questions) - 2. Generate agent JSON from the steps - 3. Apply fixes to correct common LLM errors - 4. Preview or save based on the save parameter + 2. Generate agent JSON (external service handles fixing and validation) + 3. Preview or save based on the save parameter """ description = kwargs.get("description", "").strip() context = kwargs.get("context", "") @@ -110,18 +106,41 @@ class CreateAgentTool(BaseTool): # Step 1: Decompose goal into steps try: decomposition_result = await decompose_goal(description, context) - except ValueError as e: - # Handle missing API key or configuration errors + except AgentGeneratorNotConfiguredError: return ErrorResponse( - message=f"Agent generation is not configured: {str(e)}", - error="configuration_error", + message=( + "Agent generation is not available. " + "The Agent Generator service is not configured." + ), + error="service_not_configured", session_id=session_id, ) if decomposition_result is None: return ErrorResponse( - message="Failed to analyze the goal. Please try rephrasing.", - error="Decomposition failed", + message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.", + error="decomposition_failed", + details={"description": description[:100]}, + session_id=session_id, + ) + + # Check if the result is an error from the external service + if decomposition_result.get("type") == "error": + error_msg = decomposition_result.get("error", "Unknown error") + error_type = decomposition_result.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="analyze the goal", + llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.", + ) + return ErrorResponse( + message=user_message, + error=f"decomposition_failed:{error_type}", + details={ + "description": description[:100], + "service_error": error_msg, + "error_type": error_type, + }, session_id=session_id, ) @@ -171,72 +190,54 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Step 2: Generate agent JSON with retry on validation failure - blocks_info = get_blocks_info() - agent_json = None - validation_errors = None - - for attempt in range(MAX_GENERATION_RETRIES + 1): - # Generate agent (include validation errors from previous attempt) - if attempt == 0: - agent_json = await generate_agent(decomposition_result) - else: - # Retry with validation error feedback - logger.info( - f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback" - ) - retry_instructions = { - **decomposition_result, - "previous_errors": validation_errors, - "retry_instructions": ( - "The previous generation had validation errors. " - "Please fix these issues in the new generation:\n" - f"{validation_errors}" - ), - } - agent_json = await generate_agent(retry_instructions) - - if agent_json is None: - if attempt == MAX_GENERATION_RETRIES: - return ErrorResponse( - message="Failed to generate the agent. Please try again.", - error="Generation failed", - session_id=session_id, - ) - continue - - # Step 3: Apply fixes to correct common errors - agent_json = apply_all_fixes(agent_json, blocks_info) - - # Step 4: Validate the agent - is_valid, validation_errors = validate_agent(agent_json, blocks_info) - - if is_valid: - logger.info(f"Agent generated successfully on attempt {attempt + 1}") - break - - logger.warning( - f"Validation failed on attempt {attempt + 1}: {validation_errors}" + # Step 2: Generate agent JSON (external service handles fixing and validation) + try: + agent_json = await generate_agent(decomposition_result) + except AgentGeneratorNotConfiguredError: + return ErrorResponse( + message=( + "Agent generation is not available. " + "The Agent Generator service is not configured." + ), + error="service_not_configured", + session_id=session_id, ) - if attempt == MAX_GENERATION_RETRIES: - # Return error with validation details - return ErrorResponse( - message=( - f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. " - f"Please try rephrasing your request or simplify the workflow." - ), - error="validation_failed", - details={"validation_errors": validation_errors}, - session_id=session_id, - ) + if agent_json is None: + return ErrorResponse( + message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.", + error="generation_failed", + details={"description": description[:100]}, + session_id=session_id, + ) + + # Check if the result is an error from the external service + if isinstance(agent_json, dict) and agent_json.get("type") == "error": + error_msg = agent_json.get("error", "Unknown error") + error_type = agent_json.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="generate the agent", + llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.", + validation_message="The generated agent failed validation. Please try rephrasing your goal.", + ) + return ErrorResponse( + message=user_message, + error=f"generation_failed:{error_type}", + details={ + "description": description[:100], + "service_error": error_msg, + "error_type": error_type, + }, + session_id=session_id, + ) agent_name = agent_json.get("name", "Generated Agent") agent_description = agent_json.get("description", "") node_count = len(agent_json.get("nodes", [])) link_count = len(agent_json.get("links", [])) - # Step 4: Preview or save + # Step 3: Preview or save if not save: return AgentPreviewResponse( message=( diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py index a50a89c5c7..ee8eee53ce 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py @@ -3,18 +3,14 @@ import logging from typing import Any -from langfuse import observe - from backend.api.features.chat.model import ChatSession from .agent_generator import ( - apply_agent_patch, - apply_all_fixes, + AgentGeneratorNotConfiguredError, generate_agent_patch, get_agent_as_json, - get_blocks_info, + get_user_message_for_error, save_agent_to_library, - validate_agent, ) from .base import BaseTool from .models import ( @@ -28,9 +24,6 @@ from .models import ( logger = logging.getLogger(__name__) -# Maximum retries for patch generation with validation feedback -MAX_GENERATION_RETRIES = 2 - class EditAgentTool(BaseTool): """Tool for editing existing agents using natural language.""" @@ -43,13 +36,17 @@ class EditAgentTool(BaseTool): def description(self) -> str: return ( "Edit an existing agent from the user's library using natural language. " - "Generates a patch to update the agent while preserving unchanged parts." + "Generates updates to the agent while preserving unchanged parts." ) @property def requires_auth(self) -> bool: return True + @property + def is_long_running(self) -> bool: + return True + @property def parameters(self) -> dict[str, Any]: return { @@ -87,7 +84,6 @@ class EditAgentTool(BaseTool): "required": ["agent_id", "changes"], } - @observe(as_type="tool", name="edit_agent") async def _execute( self, user_id: str | None, @@ -98,9 +94,8 @@ class EditAgentTool(BaseTool): Flow: 1. Fetch the current agent - 2. Generate a patch based on the requested changes - 3. Apply the patch to create an updated agent - 4. Preview or save based on the save parameter + 2. Generate updated agent (external service handles fixing and validation) + 3. Preview or save based on the save parameter """ agent_id = kwargs.get("agent_id", "").strip() changes = kwargs.get("changes", "").strip() @@ -137,121 +132,81 @@ class EditAgentTool(BaseTool): if context: update_request = f"{changes}\n\nAdditional context:\n{context}" - # Step 2: Generate patch with retry on validation failure - blocks_info = get_blocks_info() - updated_agent = None - validation_errors = None - intent = "Applied requested changes" - - for attempt in range(MAX_GENERATION_RETRIES + 1): - # Generate patch (include validation errors from previous attempt) - try: - if attempt == 0: - patch_result = await generate_agent_patch( - update_request, current_agent - ) - else: - # Retry with validation error feedback - logger.info( - f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback" - ) - retry_request = ( - f"{update_request}\n\n" - f"IMPORTANT: The previous edit had validation errors. " - f"Please fix these issues:\n{validation_errors}" - ) - patch_result = await generate_agent_patch( - retry_request, current_agent - ) - except ValueError as e: - # Handle missing API key or configuration errors - return ErrorResponse( - message=f"Agent generation is not configured: {str(e)}", - error="configuration_error", - session_id=session_id, - ) - - if patch_result is None: - if attempt == MAX_GENERATION_RETRIES: - return ErrorResponse( - message="Failed to generate changes. Please try rephrasing.", - error="Patch generation failed", - session_id=session_id, - ) - continue - - # Check if LLM returned clarifying questions - if patch_result.get("type") == "clarifying_questions": - questions = patch_result.get("questions", []) - return ClarificationNeededResponse( - message=( - "I need some more information about the changes. " - "Please answer the following questions:" - ), - questions=[ - ClarifyingQuestion( - question=q.get("question", ""), - keyword=q.get("keyword", ""), - example=q.get("example"), - ) - for q in questions - ], - session_id=session_id, - ) - - # Step 3: Apply patch and fixes - try: - updated_agent = apply_agent_patch(current_agent, patch_result) - updated_agent = apply_all_fixes(updated_agent, blocks_info) - except Exception as e: - if attempt == MAX_GENERATION_RETRIES: - return ErrorResponse( - message=f"Failed to apply changes: {str(e)}", - error="patch_apply_failed", - details={"exception": str(e)}, - session_id=session_id, - ) - validation_errors = str(e) - continue - - # Step 4: Validate the updated agent - is_valid, validation_errors = validate_agent(updated_agent, blocks_info) - - if is_valid: - logger.info(f"Agent edited successfully on attempt {attempt + 1}") - intent = patch_result.get("intent", "Applied requested changes") - break - - logger.warning( - f"Validation failed on attempt {attempt + 1}: {validation_errors}" + # Step 2: Generate updated agent (external service handles fixing and validation) + try: + result = await generate_agent_patch(update_request, current_agent) + except AgentGeneratorNotConfiguredError: + return ErrorResponse( + message=( + "Agent editing is not available. " + "The Agent Generator service is not configured." + ), + error="service_not_configured", + session_id=session_id, ) - if attempt == MAX_GENERATION_RETRIES: - # Return error with validation details - return ErrorResponse( - message=( - f"Updated agent has validation errors after " - f"{MAX_GENERATION_RETRIES + 1} attempts. " - f"Please try rephrasing your request or simplify the changes." - ), - error="validation_failed", - details={"validation_errors": validation_errors}, - session_id=session_id, - ) + if result is None: + return ErrorResponse( + message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.", + error="update_generation_failed", + details={"agent_id": agent_id, "changes": changes[:100]}, + session_id=session_id, + ) - # At this point, updated_agent is guaranteed to be set (we return on all failure paths) - assert updated_agent is not None + # Check if the result is an error from the external service + if isinstance(result, dict) and result.get("type") == "error": + error_msg = result.get("error", "Unknown error") + error_type = result.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="generate the changes", + llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.", + validation_message="The generated changes failed validation. Please try rephrasing your request.", + ) + return ErrorResponse( + message=user_message, + error=f"update_generation_failed:{error_type}", + details={ + "agent_id": agent_id, + "changes": changes[:100], + "service_error": error_msg, + "error_type": error_type, + }, + session_id=session_id, + ) + + # Check if LLM returned clarifying questions + if result.get("type") == "clarifying_questions": + questions = result.get("questions", []) + return ClarificationNeededResponse( + message=( + "I need some more information about the changes. " + "Please answer the following questions:" + ), + questions=[ + ClarifyingQuestion( + question=q.get("question", ""), + keyword=q.get("keyword", ""), + example=q.get("example"), + ) + for q in questions + ], + session_id=session_id, + ) + + # Result is the updated agent JSON + updated_agent = result agent_name = updated_agent.get("name", "Updated Agent") agent_description = updated_agent.get("description", "") node_count = len(updated_agent.get("nodes", [])) link_count = len(updated_agent.get("links", [])) - # Step 5: Preview or save + # Step 3: Preview or save if not save: return AgentPreviewResponse( message=( - f"I've updated the agent. Changes: {intent}. " + f"I've updated the agent. " f"The agent now has {node_count} blocks. " f"Review it and call edit_agent with save=true to save the changes." ), @@ -277,10 +232,7 @@ class EditAgentTool(BaseTool): ) return AgentSavedResponse( - message=( - f"Updated agent '{created_graph.name}' has been saved to your library! " - f"Changes: {intent}" - ), + message=f"Updated agent '{created_graph.name}' has been saved to your library!", agent_id=created_graph.id, agent_name=created_graph.name, library_agent_id=library_agent.id, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py index f231ef4484..477522757d 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py @@ -2,8 +2,6 @@ from typing import Any -from langfuse import observe - from backend.api.features.chat.model import ChatSession from .agent_search import search_agents @@ -37,7 +35,6 @@ class FindAgentTool(BaseTool): "required": ["query"], } - @observe(as_type="tool", name="find_agent") async def _execute( self, user_id: str | None, session: ChatSession, **kwargs ) -> ToolResponseBase: diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py index fc20fdfc4a..7ca85961f9 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py @@ -1,7 +1,6 @@ import logging from typing import Any -from langfuse import observe from prisma.enums import ContentType from backend.api.features.chat.model import ChatSession @@ -56,7 +55,6 @@ class FindBlockTool(BaseTool): def requires_auth(self) -> bool: return True - @observe(as_type="tool", name="find_block") async def _execute( self, user_id: str | None, @@ -109,7 +107,8 @@ class FindBlockTool(BaseTool): block_id = result["content_id"] block = get_block(block_id) - if block: + # Skip disabled blocks + if block and not block.disabled: # Get input/output schemas input_schema = {} output_schema = {} diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py index d9b5edfa9b..108fba75ae 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py @@ -2,8 +2,6 @@ from typing import Any -from langfuse import observe - from backend.api.features.chat.model import ChatSession from .agent_search import search_agents @@ -43,7 +41,6 @@ class FindLibraryAgentTool(BaseTool): def requires_auth(self) -> bool: return True - @observe(as_type="tool", name="find_library_agent") async def _execute( self, user_id: str | None, session: ChatSession, **kwargs ) -> ToolResponseBase: diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py b/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py index b2fdcccfcd..7040cd7db5 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py @@ -4,8 +4,6 @@ import logging from pathlib import Path from typing import Any -from langfuse import observe - from backend.api.features.chat.model import ChatSession from backend.api.features.chat.tools.base import BaseTool from backend.api.features.chat.tools.models import ( @@ -73,7 +71,6 @@ class GetDocPageTool(BaseTool): url_path = path.rsplit(".", 1)[0] if "." in path else path return f"{DOCS_BASE_URL}/{url_path}" - @observe(as_type="tool", name="get_doc_page") async def _execute( self, user_id: str | None, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index 1736ddb9a8..49b233784e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -28,6 +28,16 @@ class ResponseType(str, Enum): BLOCK_OUTPUT = "block_output" DOC_SEARCH_RESULTS = "doc_search_results" DOC_PAGE = "doc_page" + # Workspace response types + WORKSPACE_FILE_LIST = "workspace_file_list" + WORKSPACE_FILE_CONTENT = "workspace_file_content" + WORKSPACE_FILE_METADATA = "workspace_file_metadata" + WORKSPACE_FILE_WRITTEN = "workspace_file_written" + WORKSPACE_FILE_DELETED = "workspace_file_deleted" + # Long-running operation types + OPERATION_STARTED = "operation_started" + OPERATION_PENDING = "operation_pending" + OPERATION_IN_PROGRESS = "operation_in_progress" # Base response model @@ -334,3 +344,39 @@ class BlockOutputResponse(ToolResponseBase): block_name: str outputs: dict[str, list[Any]] success: bool = True + + +# Long-running operation models +class OperationStartedResponse(ToolResponseBase): + """Response when a long-running operation has been started in the background. + + This is returned immediately to the client while the operation continues + to execute. The user can close the tab and check back later. + """ + + type: ResponseType = ResponseType.OPERATION_STARTED + operation_id: str + tool_name: str + + +class OperationPendingResponse(ToolResponseBase): + """Response stored in chat history while a long-running operation is executing. + + This is persisted to the database so users see a pending state when they + refresh before the operation completes. + """ + + type: ResponseType = ResponseType.OPERATION_PENDING + operation_id: str + tool_name: str + + +class OperationInProgressResponse(ToolResponseBase): + """Response when an operation is already in progress. + + Returned for idempotency when the same tool_call_id is requested again + while the background task is still running. + """ + + type: ResponseType = ResponseType.OPERATION_IN_PROGRESS + tool_call_id: str diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py index b212c11e8a..a7fa65348a 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py @@ -3,11 +3,14 @@ import logging from typing import Any -from langfuse import observe from pydantic import BaseModel, Field, field_validator from backend.api.features.chat.config import ChatConfig from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tracking import ( + track_agent_run_success, + track_agent_scheduled, +) from backend.api.features.library import db as library_db from backend.data.graph import GraphModel from backend.data.model import CredentialsMetaInput @@ -155,7 +158,6 @@ class RunAgentTool(BaseTool): """All operations require authentication.""" return True - @observe(as_type="tool", name="run_agent") async def _execute( self, user_id: str | None, @@ -453,6 +455,16 @@ class RunAgentTool(BaseTool): session.successful_agent_runs.get(library_agent.graph_id, 0) + 1 ) + # Track in PostHog + track_agent_run_success( + user_id=user_id, + session_id=session_id, + graph_id=library_agent.graph_id, + graph_name=library_agent.name, + execution_id=execution.id, + library_agent_id=library_agent.id, + ) + library_agent_link = f"/library/agents/{library_agent.id}" return ExecutionStartedResponse( message=( @@ -534,6 +546,18 @@ class RunAgentTool(BaseTool): session.successful_agent_schedules.get(library_agent.graph_id, 0) + 1 ) + # Track in PostHog + track_agent_scheduled( + user_id=user_id, + session_id=session_id, + graph_id=library_agent.graph_id, + graph_name=library_agent.name, + schedule_id=result.id, + schedule_name=schedule_name, + cron=cron, + library_agent_id=library_agent.id, + ) + library_agent_link = f"/library/agents/{library_agent.id}" return ExecutionStartedResponse( message=( diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py index 9e10304429..404df2adb6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py @@ -29,7 +29,7 @@ def mock_embedding_functions(): yield -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent(setup_test_data): """Test that the run_agent tool successfully executes an approved agent""" # Use test data from fixture @@ -70,7 +70,7 @@ async def test_run_agent(setup_test_data): assert result_data["graph_name"] == "Test Agent" -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_missing_inputs(setup_test_data): """Test that the run_agent tool returns error when inputs are missing""" # Use test data from fixture @@ -106,7 +106,7 @@ async def test_run_agent_missing_inputs(setup_test_data): assert "message" in result_data -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_invalid_agent_id(setup_test_data): """Test that the run_agent tool returns error for invalid agent ID""" # Use test data from fixture @@ -141,7 +141,7 @@ async def test_run_agent_invalid_agent_id(setup_test_data): ) -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_with_llm_credentials(setup_llm_test_data): """Test that run_agent works with an agent requiring LLM credentials""" # Use test data from fixture @@ -185,7 +185,7 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data): assert result_data["graph_name"] == "LLM Test Agent" -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_data): """Test that run_agent returns available inputs when called without inputs or use_defaults.""" user = setup_test_data["user"] @@ -219,7 +219,7 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da assert "inputs" in result_data["message"].lower() -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_with_use_defaults(setup_test_data): """Test that run_agent executes successfully with use_defaults=True.""" user = setup_test_data["user"] @@ -251,7 +251,7 @@ async def test_run_agent_with_use_defaults(setup_test_data): assert result_data["graph_id"] == graph.id -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_missing_credentials(setup_firecrawl_test_data): """Test that run_agent returns setup_requirements when credentials are missing.""" user = setup_firecrawl_test_data["user"] @@ -285,7 +285,7 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data): assert len(setup_info["user_readiness"]["missing_credentials"]) > 0 -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_invalid_slug_format(setup_test_data): """Test that run_agent returns error for invalid slug format (no slash).""" user = setup_test_data["user"] @@ -313,7 +313,7 @@ async def test_run_agent_invalid_slug_format(setup_test_data): assert "username/agent-name" in result_data["message"] -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_unauthenticated(): """Test that run_agent returns need_login for unauthenticated users.""" tool = RunAgentTool() @@ -340,7 +340,7 @@ async def test_run_agent_unauthenticated(): assert "sign in" in result_data["message"].lower() -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_schedule_without_cron(setup_test_data): """Test that run_agent returns error when scheduling without cron expression.""" user = setup_test_data["user"] @@ -372,7 +372,7 @@ async def test_run_agent_schedule_without_cron(setup_test_data): assert "cron" in result_data["message"].lower() -@pytest.mark.asyncio(scope="session") +@pytest.mark.asyncio(loop_scope="session") async def test_run_agent_schedule_without_name(setup_test_data): """Test that run_agent returns error when scheduling without schedule_name.""" user = setup_test_data["user"] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index c29cc92556..a59082b399 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -1,15 +1,15 @@ """Tool for executing blocks directly.""" import logging +import uuid from collections import defaultdict from typing import Any -from langfuse import observe - from backend.api.features.chat.model import ChatSession from backend.data.block import get_block from backend.data.execution import ExecutionContext from backend.data.model import CredentialsMetaInput +from backend.data.workspace import get_or_create_workspace from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import BlockError @@ -130,7 +130,6 @@ class RunBlockTool(BaseTool): return matched_credentials, missing_credentials - @observe(as_type="tool", name="run_block") async def _execute( self, user_id: str | None, @@ -179,6 +178,11 @@ class RunBlockTool(BaseTool): message=f"Block '{block_id}' not found", session_id=session_id, ) + if block.disabled: + return ErrorResponse( + message=f"Block '{block_id}' is disabled", + session_id=session_id, + ) logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") @@ -221,11 +225,48 @@ class RunBlockTool(BaseTool): ) try: - # Fetch actual credentials and prepare kwargs for block execution - # Create execution context with defaults (blocks may require it) + # Get or create user's workspace for CoPilot file operations + workspace = await get_or_create_workspace(user_id) + + # Generate synthetic IDs for CoPilot context + # Each chat session is treated as its own agent with one continuous run + # This means: + # - graph_id (agent) = session (memories scoped to session when limit_to_agent=True) + # - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True) + # - node_exec_id = unique per block execution + synthetic_graph_id = f"copilot-session-{session.session_id}" + synthetic_graph_exec_id = f"copilot-session-{session.session_id}" + synthetic_node_id = f"copilot-node-{block_id}" + synthetic_node_exec_id = ( + f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}" + ) + + # Create unified execution context with all required fields + execution_context = ExecutionContext( + # Execution identity + user_id=user_id, + graph_id=synthetic_graph_id, + graph_exec_id=synthetic_graph_exec_id, + graph_version=1, # Versions are 1-indexed + node_id=synthetic_node_id, + node_exec_id=synthetic_node_exec_id, + # Workspace with session scoping + workspace_id=workspace.id, + session_id=session.session_id, + ) + + # Prepare kwargs for block execution + # Keep individual kwargs for backwards compatibility with existing blocks exec_kwargs: dict[str, Any] = { "user_id": user_id, - "execution_context": ExecutionContext(), + "execution_context": execution_context, + # Legacy: individual kwargs for blocks not yet using execution_context + "workspace_id": workspace.id, + "graph_exec_id": synthetic_graph_exec_id, + "node_exec_id": synthetic_node_exec_id, + "node_id": synthetic_node_id, + "graph_version": 1, # Versions are 1-indexed + "graph_id": synthetic_graph_id, } for field_name, cred_meta in matched_credentials.items(): diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py b/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py index 4903230b40..edb0c0de1e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py @@ -3,7 +3,6 @@ import logging from typing import Any -from langfuse import observe from prisma.enums import ContentType from backend.api.features.chat.model import ChatSession @@ -88,7 +87,6 @@ class SearchDocsTool(BaseTool): url_path = path.rsplit(".", 1)[0] if "." in path else path return f"{DOCS_BASE_URL}/{url_path}" - @observe(as_type="tool", name="search_docs") async def _execute( self, user_id: str | None, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py new file mode 100644 index 0000000000..03532c8fee --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py @@ -0,0 +1,620 @@ +"""CoPilot tools for workspace file operations.""" + +import base64 +import logging +from typing import Any, Optional + +from pydantic import BaseModel + +from backend.api.features.chat.model import ChatSession +from backend.data.workspace import get_or_create_workspace +from backend.util.settings import Config +from backend.util.virus_scanner import scan_content_safe +from backend.util.workspace import WorkspaceManager + +from .base import BaseTool +from .models import ErrorResponse, ResponseType, ToolResponseBase + +logger = logging.getLogger(__name__) + + +class WorkspaceFileInfoData(BaseModel): + """Data model for workspace file information (not a response itself).""" + + file_id: str + name: str + path: str + mime_type: str + size_bytes: int + + +class WorkspaceFileListResponse(ToolResponseBase): + """Response containing list of workspace files.""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_LIST + files: list[WorkspaceFileInfoData] + total_count: int + + +class WorkspaceFileContentResponse(ToolResponseBase): + """Response containing workspace file content (legacy, for small text files).""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT + file_id: str + name: str + path: str + mime_type: str + content_base64: str + + +class WorkspaceFileMetadataResponse(ToolResponseBase): + """Response containing workspace file metadata and download URL (prevents context bloat).""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA + file_id: str + name: str + path: str + mime_type: str + size_bytes: int + download_url: str + preview: str | None = None # First 500 chars for text files + + +class WorkspaceWriteResponse(ToolResponseBase): + """Response after writing a file to workspace.""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN + file_id: str + name: str + path: str + size_bytes: int + + +class WorkspaceDeleteResponse(ToolResponseBase): + """Response after deleting a file from workspace.""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED + file_id: str + success: bool + + +class ListWorkspaceFilesTool(BaseTool): + """Tool for listing files in user's workspace.""" + + @property + def name(self) -> str: + return "list_workspace_files" + + @property + def description(self) -> str: + return ( + "List files in the user's workspace. " + "Returns file names, paths, sizes, and metadata. " + "Optionally filter by path prefix." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "path_prefix": { + "type": "string", + "description": ( + "Optional path prefix to filter files " + "(e.g., '/documents/' to list only files in documents folder). " + "By default, only files from the current session are listed." + ), + }, + "limit": { + "type": "integer", + "description": "Maximum number of files to return (default 50, max 100)", + "minimum": 1, + "maximum": 100, + }, + "include_all_sessions": { + "type": "boolean", + "description": ( + "If true, list files from all sessions. " + "Default is false (only current session's files)." + ), + }, + }, + "required": [], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + session_id = session.session_id + + if not user_id: + return ErrorResponse( + message="Authentication required", + session_id=session_id, + ) + + path_prefix: Optional[str] = kwargs.get("path_prefix") + limit = min(kwargs.get("limit", 50), 100) + include_all_sessions: bool = kwargs.get("include_all_sessions", False) + + try: + workspace = await get_or_create_workspace(user_id) + # Pass session_id for session-scoped file access + manager = WorkspaceManager(user_id, workspace.id, session_id) + + files = await manager.list_files( + path=path_prefix, + limit=limit, + include_all_sessions=include_all_sessions, + ) + total = await manager.get_file_count( + path=path_prefix, + include_all_sessions=include_all_sessions, + ) + + file_infos = [ + WorkspaceFileInfoData( + file_id=f.id, + name=f.name, + path=f.path, + mime_type=f.mimeType, + size_bytes=f.sizeBytes, + ) + for f in files + ] + + scope_msg = "all sessions" if include_all_sessions else "current session" + return WorkspaceFileListResponse( + files=file_infos, + total_count=total, + message=f"Found {len(files)} files in workspace ({scope_msg})", + session_id=session_id, + ) + + except Exception as e: + logger.error(f"Error listing workspace files: {e}", exc_info=True) + return ErrorResponse( + message=f"Failed to list workspace files: {str(e)}", + error=str(e), + session_id=session_id, + ) + + +class ReadWorkspaceFileTool(BaseTool): + """Tool for reading file content from workspace.""" + + # Size threshold for returning full content vs metadata+URL + # Files larger than this return metadata with download URL to prevent context bloat + MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB + # Preview size for text files + PREVIEW_SIZE = 500 + + @property + def name(self) -> str: + return "read_workspace_file" + + @property + def description(self) -> str: + return ( + "Read a file from the user's workspace. " + "Specify either file_id or path to identify the file. " + "For small text files, returns content directly. " + "For large or binary files, returns metadata and a download URL. " + "Paths are scoped to the current session by default. " + "Use /sessions//... for cross-session access." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "file_id": { + "type": "string", + "description": "The file's unique ID (from list_workspace_files)", + }, + "path": { + "type": "string", + "description": ( + "The virtual file path (e.g., '/documents/report.pdf'). " + "Scoped to current session by default." + ), + }, + "force_download_url": { + "type": "boolean", + "description": ( + "If true, always return metadata+URL instead of inline content. " + "Default is false (auto-selects based on file size/type)." + ), + }, + }, + "required": [], # At least one must be provided + } + + @property + def requires_auth(self) -> bool: + return True + + def _is_text_mime_type(self, mime_type: str) -> bool: + """Check if the MIME type is a text-based type.""" + text_types = [ + "text/", + "application/json", + "application/xml", + "application/javascript", + "application/x-python", + "application/x-sh", + ] + return any(mime_type.startswith(t) for t in text_types) + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + session_id = session.session_id + + if not user_id: + return ErrorResponse( + message="Authentication required", + session_id=session_id, + ) + + file_id: Optional[str] = kwargs.get("file_id") + path: Optional[str] = kwargs.get("path") + force_download_url: bool = kwargs.get("force_download_url", False) + + if not file_id and not path: + return ErrorResponse( + message="Please provide either file_id or path", + session_id=session_id, + ) + + try: + workspace = await get_or_create_workspace(user_id) + # Pass session_id for session-scoped file access + manager = WorkspaceManager(user_id, workspace.id, session_id) + + # Get file info + if file_id: + file_info = await manager.get_file_info(file_id) + if file_info is None: + return ErrorResponse( + message=f"File not found: {file_id}", + session_id=session_id, + ) + target_file_id = file_id + else: + # path is guaranteed to be non-None here due to the check above + assert path is not None + file_info = await manager.get_file_info_by_path(path) + if file_info is None: + return ErrorResponse( + message=f"File not found at path: {path}", + session_id=session_id, + ) + target_file_id = file_info.id + + # Decide whether to return inline content or metadata+URL + is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES + is_text_file = self._is_text_mime_type(file_info.mimeType) + + # Return inline content for small text files (unless force_download_url) + if is_small_file and is_text_file and not force_download_url: + content = await manager.read_file_by_id(target_file_id) + content_b64 = base64.b64encode(content).decode("utf-8") + + return WorkspaceFileContentResponse( + file_id=file_info.id, + name=file_info.name, + path=file_info.path, + mime_type=file_info.mimeType, + content_base64=content_b64, + message=f"Successfully read file: {file_info.name}", + session_id=session_id, + ) + + # Return metadata + workspace:// reference for large or binary files + # This prevents context bloat (100KB file = ~133KB as base64) + # Use workspace:// format so frontend urlTransform can add proxy prefix + download_url = f"workspace://{target_file_id}" + + # Generate preview for text files + preview: str | None = None + if is_text_file: + try: + content = await manager.read_file_by_id(target_file_id) + preview_text = content[: self.PREVIEW_SIZE].decode( + "utf-8", errors="replace" + ) + if len(content) > self.PREVIEW_SIZE: + preview_text += "..." + preview = preview_text + except Exception: + pass # Preview is optional + + return WorkspaceFileMetadataResponse( + file_id=file_info.id, + name=file_info.name, + path=file_info.path, + mime_type=file_info.mimeType, + size_bytes=file_info.sizeBytes, + download_url=download_url, + preview=preview, + message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.", + session_id=session_id, + ) + + except FileNotFoundError as e: + return ErrorResponse( + message=str(e), + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error reading workspace file: {e}", exc_info=True) + return ErrorResponse( + message=f"Failed to read workspace file: {str(e)}", + error=str(e), + session_id=session_id, + ) + + +class WriteWorkspaceFileTool(BaseTool): + """Tool for writing files to workspace.""" + + @property + def name(self) -> str: + return "write_workspace_file" + + @property + def description(self) -> str: + return ( + "Write or create a file in the user's workspace. " + "Provide the content as a base64-encoded string. " + f"Maximum file size is {Config().max_file_size_mb}MB. " + "Files are saved to the current session's folder by default. " + "Use /sessions//... for cross-session access." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "Name for the file (e.g., 'report.pdf')", + }, + "content_base64": { + "type": "string", + "description": "Base64-encoded file content", + }, + "path": { + "type": "string", + "description": ( + "Optional virtual path where to save the file " + "(e.g., '/documents/report.pdf'). " + "Defaults to '/{filename}'. Scoped to current session." + ), + }, + "mime_type": { + "type": "string", + "description": ( + "Optional MIME type of the file. " + "Auto-detected from filename if not provided." + ), + }, + "overwrite": { + "type": "boolean", + "description": "Whether to overwrite if file exists at path (default: false)", + }, + }, + "required": ["filename", "content_base64"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + session_id = session.session_id + + if not user_id: + return ErrorResponse( + message="Authentication required", + session_id=session_id, + ) + + filename: str = kwargs.get("filename", "") + content_b64: str = kwargs.get("content_base64", "") + path: Optional[str] = kwargs.get("path") + mime_type: Optional[str] = kwargs.get("mime_type") + overwrite: bool = kwargs.get("overwrite", False) + + if not filename: + return ErrorResponse( + message="Please provide a filename", + session_id=session_id, + ) + + if not content_b64: + return ErrorResponse( + message="Please provide content_base64", + session_id=session_id, + ) + + # Decode content + try: + content = base64.b64decode(content_b64) + except Exception: + return ErrorResponse( + message="Invalid base64-encoded content", + session_id=session_id, + ) + + # Check size + max_file_size = Config().max_file_size_mb * 1024 * 1024 + if len(content) > max_file_size: + return ErrorResponse( + message=f"File too large. Maximum size is {Config().max_file_size_mb}MB", + session_id=session_id, + ) + + try: + # Virus scan + await scan_content_safe(content, filename=filename) + + workspace = await get_or_create_workspace(user_id) + # Pass session_id for session-scoped file access + manager = WorkspaceManager(user_id, workspace.id, session_id) + + file_record = await manager.write_file( + content=content, + filename=filename, + path=path, + mime_type=mime_type, + overwrite=overwrite, + ) + + return WorkspaceWriteResponse( + file_id=file_record.id, + name=file_record.name, + path=file_record.path, + size_bytes=file_record.sizeBytes, + message=f"Successfully wrote file: {file_record.name}", + session_id=session_id, + ) + + except ValueError as e: + return ErrorResponse( + message=str(e), + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error writing workspace file: {e}", exc_info=True) + return ErrorResponse( + message=f"Failed to write workspace file: {str(e)}", + error=str(e), + session_id=session_id, + ) + + +class DeleteWorkspaceFileTool(BaseTool): + """Tool for deleting files from workspace.""" + + @property + def name(self) -> str: + return "delete_workspace_file" + + @property + def description(self) -> str: + return ( + "Delete a file from the user's workspace. " + "Specify either file_id or path to identify the file. " + "Paths are scoped to the current session by default. " + "Use /sessions//... for cross-session access." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "file_id": { + "type": "string", + "description": "The file's unique ID (from list_workspace_files)", + }, + "path": { + "type": "string", + "description": ( + "The virtual file path (e.g., '/documents/report.pdf'). " + "Scoped to current session by default." + ), + }, + }, + "required": [], # At least one must be provided + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + session_id = session.session_id + + if not user_id: + return ErrorResponse( + message="Authentication required", + session_id=session_id, + ) + + file_id: Optional[str] = kwargs.get("file_id") + path: Optional[str] = kwargs.get("path") + + if not file_id and not path: + return ErrorResponse( + message="Please provide either file_id or path", + session_id=session_id, + ) + + try: + workspace = await get_or_create_workspace(user_id) + # Pass session_id for session-scoped file access + manager = WorkspaceManager(user_id, workspace.id, session_id) + + # Determine the file_id to delete + target_file_id: str + if file_id: + target_file_id = file_id + else: + # path is guaranteed to be non-None here due to the check above + assert path is not None + file_info = await manager.get_file_info_by_path(path) + if file_info is None: + return ErrorResponse( + message=f"File not found at path: {path}", + session_id=session_id, + ) + target_file_id = file_info.id + + success = await manager.delete_file(target_file_id) + + if not success: + return ErrorResponse( + message=f"File not found: {target_file_id}", + session_id=session_id, + ) + + return WorkspaceDeleteResponse( + file_id=target_file_id, + success=True, + message="File deleted successfully", + session_id=session_id, + ) + + except Exception as e: + logger.error(f"Error deleting workspace file: {e}", exc_info=True) + return ErrorResponse( + message=f"Failed to delete workspace file: {str(e)}", + error=str(e), + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tracking.py b/autogpt_platform/backend/backend/api/features/chat/tracking.py new file mode 100644 index 0000000000..b2c0fd032f --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tracking.py @@ -0,0 +1,250 @@ +"""PostHog analytics tracking for the chat system.""" + +import atexit +import logging +from typing import Any + +from posthog import Posthog + +from backend.util.settings import Settings + +logger = logging.getLogger(__name__) +settings = Settings() + +# PostHog client instance (lazily initialized) +_posthog_client: Posthog | None = None + + +def _shutdown_posthog() -> None: + """Flush and shutdown PostHog client on process exit.""" + if _posthog_client is not None: + _posthog_client.flush() + _posthog_client.shutdown() + + +atexit.register(_shutdown_posthog) + + +def _get_posthog_client() -> Posthog | None: + """Get or create the PostHog client instance.""" + global _posthog_client + if _posthog_client is not None: + return _posthog_client + + if not settings.secrets.posthog_api_key: + logger.debug("PostHog API key not configured, analytics disabled") + return None + + _posthog_client = Posthog( + settings.secrets.posthog_api_key, + host=settings.secrets.posthog_host, + ) + logger.info( + f"PostHog client initialized with host: {settings.secrets.posthog_host}" + ) + return _posthog_client + + +def _get_base_properties() -> dict[str, Any]: + """Get base properties included in all events.""" + return { + "environment": settings.config.app_env.value, + "source": "chat_copilot", + } + + +def track_user_message( + user_id: str | None, + session_id: str, + message_length: int, +) -> None: + """Track when a user sends a message in chat. + + Args: + user_id: The user's ID (or None for anonymous) + session_id: The chat session ID + message_length: Length of the user's message + """ + client = _get_posthog_client() + if not client: + return + + try: + properties = { + **_get_base_properties(), + "session_id": session_id, + "message_length": message_length, + } + client.capture( + distinct_id=user_id or f"anonymous_{session_id}", + event="copilot_message_sent", + properties=properties, + ) + except Exception as e: + logger.warning(f"Failed to track user message: {e}") + + +def track_tool_called( + user_id: str | None, + session_id: str, + tool_name: str, + tool_call_id: str, +) -> None: + """Track when a tool is called in chat. + + Args: + user_id: The user's ID (or None for anonymous) + session_id: The chat session ID + tool_name: Name of the tool being called + tool_call_id: Unique ID of the tool call + """ + client = _get_posthog_client() + if not client: + logger.info("PostHog client not available for tool tracking") + return + + try: + properties = { + **_get_base_properties(), + "session_id": session_id, + "tool_name": tool_name, + "tool_call_id": tool_call_id, + } + distinct_id = user_id or f"anonymous_{session_id}" + logger.info( + f"Sending copilot_tool_called event to PostHog: distinct_id={distinct_id}, " + f"tool_name={tool_name}" + ) + client.capture( + distinct_id=distinct_id, + event="copilot_tool_called", + properties=properties, + ) + except Exception as e: + logger.warning(f"Failed to track tool call: {e}") + + +def track_agent_run_success( + user_id: str, + session_id: str, + graph_id: str, + graph_name: str, + execution_id: str, + library_agent_id: str, +) -> None: + """Track when an agent is successfully run. + + Args: + user_id: The user's ID + session_id: The chat session ID + graph_id: ID of the agent graph + graph_name: Name of the agent + execution_id: ID of the execution + library_agent_id: ID of the library agent + """ + client = _get_posthog_client() + if not client: + return + + try: + properties = { + **_get_base_properties(), + "session_id": session_id, + "graph_id": graph_id, + "graph_name": graph_name, + "execution_id": execution_id, + "library_agent_id": library_agent_id, + } + client.capture( + distinct_id=user_id, + event="copilot_agent_run_success", + properties=properties, + ) + except Exception as e: + logger.warning(f"Failed to track agent run: {e}") + + +def track_agent_scheduled( + user_id: str, + session_id: str, + graph_id: str, + graph_name: str, + schedule_id: str, + schedule_name: str, + cron: str, + library_agent_id: str, +) -> None: + """Track when an agent is successfully scheduled. + + Args: + user_id: The user's ID + session_id: The chat session ID + graph_id: ID of the agent graph + graph_name: Name of the agent + schedule_id: ID of the schedule + schedule_name: Name of the schedule + cron: Cron expression for the schedule + library_agent_id: ID of the library agent + """ + client = _get_posthog_client() + if not client: + return + + try: + properties = { + **_get_base_properties(), + "session_id": session_id, + "graph_id": graph_id, + "graph_name": graph_name, + "schedule_id": schedule_id, + "schedule_name": schedule_name, + "cron": cron, + "library_agent_id": library_agent_id, + } + client.capture( + distinct_id=user_id, + event="copilot_agent_scheduled", + properties=properties, + ) + except Exception as e: + logger.warning(f"Failed to track agent schedule: {e}") + + +def track_trigger_setup( + user_id: str, + session_id: str, + graph_id: str, + graph_name: str, + trigger_type: str, + library_agent_id: str, +) -> None: + """Track when a trigger is set up for an agent. + + Args: + user_id: The user's ID + session_id: The chat session ID + graph_id: ID of the agent graph + graph_name: Name of the agent + trigger_type: Type of trigger (e.g., 'webhook') + library_agent_id: ID of the library agent + """ + client = _get_posthog_client() + if not client: + return + + try: + properties = { + **_get_base_properties(), + "session_id": session_id, + "graph_id": graph_id, + "graph_name": graph_name, + "trigger_type": trigger_type, + "library_agent_id": library_agent_id, + } + client.capture( + distinct_id=user_id, + event="copilot_trigger_setup", + properties=properties, + ) + except Exception as e: + logger.warning(f"Failed to track trigger setup: {e}") diff --git a/autogpt_platform/backend/backend/api/features/executions/review/model.py b/autogpt_platform/backend/backend/api/features/executions/review/model.py index 74f72fe1ff..bad8b8d304 100644 --- a/autogpt_platform/backend/backend/api/features/executions/review/model.py +++ b/autogpt_platform/backend/backend/api/features/executions/review/model.py @@ -23,6 +23,7 @@ class PendingHumanReviewModel(BaseModel): id: Unique identifier for the review record user_id: ID of the user who must perform the review node_exec_id: ID of the node execution that created this review + node_id: ID of the node definition (for grouping reviews from same node) graph_exec_id: ID of the graph execution containing the node graph_id: ID of the graph template being executed graph_version: Version number of the graph template @@ -37,6 +38,10 @@ class PendingHumanReviewModel(BaseModel): """ node_exec_id: str = Field(description="Node execution ID (primary key)") + node_id: str = Field( + description="Node definition ID (for grouping)", + default="", # Temporary default for test compatibility + ) user_id: str = Field(description="User ID associated with the review") graph_exec_id: str = Field(description="Graph execution ID") graph_id: str = Field(description="Graph ID") @@ -66,7 +71,9 @@ class PendingHumanReviewModel(BaseModel): ) @classmethod - def from_db(cls, review: "PendingHumanReview") -> "PendingHumanReviewModel": + def from_db( + cls, review: "PendingHumanReview", node_id: str + ) -> "PendingHumanReviewModel": """ Convert a database model to a response model. @@ -74,9 +81,14 @@ class PendingHumanReviewModel(BaseModel): payload, instructions, and editable flag. Handles invalid data gracefully by using safe defaults. + + Args: + review: Database review object + node_id: Node definition ID (fetched from NodeExecution) """ return cls( node_exec_id=review.nodeExecId, + node_id=node_id, user_id=review.userId, graph_exec_id=review.graphExecId, graph_id=review.graphId, @@ -107,6 +119,13 @@ class ReviewItem(BaseModel): reviewed_data: SafeJsonData | None = Field( None, description="Optional edited data (ignored if approved=False)" ) + auto_approve_future: bool = Field( + default=False, + description=( + "If true and this review is approved, future executions of this same " + "block (node) will be automatically approved. This only affects approved reviews." + ), + ) @field_validator("reviewed_data") @classmethod @@ -174,6 +193,9 @@ class ReviewRequest(BaseModel): This request must include ALL pending reviews for a graph execution. Each review will be either approved (with optional data modifications) or rejected (data ignored). The execution will resume only after ALL reviews are processed. + + Each review item can individually specify whether to auto-approve future executions + of the same block via the `auto_approve_future` field on ReviewItem. """ reviews: List[ReviewItem] = Field( diff --git a/autogpt_platform/backend/backend/api/features/executions/review/review_routes_test.py b/autogpt_platform/backend/backend/api/features/executions/review/review_routes_test.py index c4eba0befc..c8bbfe4bb0 100644 --- a/autogpt_platform/backend/backend/api/features/executions/review/review_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/executions/review/review_routes_test.py @@ -1,35 +1,43 @@ import datetime +from typing import AsyncGenerator -import fastapi -import fastapi.testclient +import httpx import pytest +import pytest_asyncio import pytest_mock from prisma.enums import ReviewStatus from pytest_snapshot.plugin import Snapshot -from backend.api.rest_api import handle_internal_http_error +from backend.api.rest_api import app +from backend.data.execution import ( + ExecutionContext, + ExecutionStatus, + NodeExecutionResult, +) +from backend.data.graph import GraphSettings from .model import PendingHumanReviewModel -from .routes import router # Using a fixed timestamp for reproducible tests FIXED_NOW = datetime.datetime(2023, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) -app = fastapi.FastAPI() -app.include_router(router, prefix="/api/review") -app.add_exception_handler(ValueError, handle_internal_http_error(400)) -client = fastapi.testclient.TestClient(app) - - -@pytest.fixture(autouse=True) -def setup_app_auth(mock_jwt_user): - """Setup auth overrides for all tests in this module""" +@pytest_asyncio.fixture(loop_scope="session") +async def client(server, mock_jwt_user) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create async HTTP client with auth overrides""" from autogpt_libs.auth.jwt_utils import get_jwt_payload + # Override get_jwt_payload dependency to return our test user app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"] - yield - app.dependency_overrides.clear() + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + ) as http_client: + yield http_client + + # Clean up overrides + app.dependency_overrides.pop(get_jwt_payload, None) @pytest.fixture @@ -37,6 +45,7 @@ def sample_pending_review(test_user_id: str) -> PendingHumanReviewModel: """Create a sample pending review for testing""" return PendingHumanReviewModel( node_exec_id="test_node_123", + node_id="test_node_def_456", user_id=test_user_id, graph_exec_id="test_graph_exec_456", graph_id="test_graph_789", @@ -54,7 +63,9 @@ def sample_pending_review(test_user_id: str) -> PendingHumanReviewModel: ) -def test_get_pending_reviews_empty( +@pytest.mark.asyncio(loop_scope="session") +async def test_get_pending_reviews_empty( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, snapshot: Snapshot, test_user_id: str, @@ -65,14 +76,16 @@ def test_get_pending_reviews_empty( ) mock_get_reviews.return_value = [] - response = client.get("/api/review/pending") + response = await client.get("/api/review/pending") assert response.status_code == 200 assert response.json() == [] mock_get_reviews.assert_called_once_with(test_user_id, 1, 25) -def test_get_pending_reviews_with_data( +@pytest.mark.asyncio(loop_scope="session") +async def test_get_pending_reviews_with_data( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, sample_pending_review: PendingHumanReviewModel, snapshot: Snapshot, @@ -84,7 +97,7 @@ def test_get_pending_reviews_with_data( ) mock_get_reviews.return_value = [sample_pending_review] - response = client.get("/api/review/pending?page=2&page_size=10") + response = await client.get("/api/review/pending?page=2&page_size=10") assert response.status_code == 200 data = response.json() @@ -94,7 +107,9 @@ def test_get_pending_reviews_with_data( mock_get_reviews.assert_called_once_with(test_user_id, 2, 10) -def test_get_pending_reviews_for_execution_success( +@pytest.mark.asyncio(loop_scope="session") +async def test_get_pending_reviews_for_execution_success( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, sample_pending_review: PendingHumanReviewModel, snapshot: Snapshot, @@ -114,7 +129,7 @@ def test_get_pending_reviews_for_execution_success( ) mock_get_reviews.return_value = [sample_pending_review] - response = client.get("/api/review/execution/test_graph_exec_456") + response = await client.get("/api/review/execution/test_graph_exec_456") assert response.status_code == 200 data = response.json() @@ -122,7 +137,9 @@ def test_get_pending_reviews_for_execution_success( assert data[0]["graph_exec_id"] == "test_graph_exec_456" -def test_get_pending_reviews_for_execution_not_available( +@pytest.mark.asyncio(loop_scope="session") +async def test_get_pending_reviews_for_execution_not_available( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, ) -> None: """Test access denied when user doesn't own the execution""" @@ -131,13 +148,15 @@ def test_get_pending_reviews_for_execution_not_available( ) mock_get_graph_execution.return_value = None - response = client.get("/api/review/execution/test_graph_exec_456") + response = await client.get("/api/review/execution/test_graph_exec_456") assert response.status_code == 404 assert "not found" in response.json()["detail"] -def test_process_review_action_approve_success( +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_approve_success( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, sample_pending_review: PendingHumanReviewModel, test_user_id: str, @@ -145,6 +164,12 @@ def test_process_review_action_approve_success( """Test successful review approval""" # Mock the route functions + # Mock get_reviews_by_node_exec_ids (called to find the graph_exec_id) + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" + ) + mock_get_reviews_for_user.return_value = {"test_node_123": sample_pending_review} + mock_get_reviews_for_execution = mocker.patch( "backend.api.features.executions.review.routes.get_pending_reviews_for_execution" ) @@ -173,6 +198,14 @@ def test_process_review_action_approve_success( ) mock_process_all_reviews.return_value = {"test_node_123": approved_review} + # Mock get_graph_execution_meta to return execution in REVIEW status + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" + ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta + mock_has_pending = mocker.patch( "backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec" ) @@ -191,7 +224,7 @@ def test_process_review_action_approve_success( ] } - response = client.post("/api/review/action", json=request_data) + response = await client.post("/api/review/action", json=request_data) assert response.status_code == 200 data = response.json() @@ -201,7 +234,9 @@ def test_process_review_action_approve_success( assert data["error"] is None -def test_process_review_action_reject_success( +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_reject_success( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, sample_pending_review: PendingHumanReviewModel, test_user_id: str, @@ -209,6 +244,20 @@ def test_process_review_action_reject_success( """Test successful review rejection""" # Mock the route functions + # Mock get_reviews_by_node_exec_ids (called to find the graph_exec_id) + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" + ) + mock_get_reviews_for_user.return_value = {"test_node_123": sample_pending_review} + + # Mock get_graph_execution_meta to return execution in REVIEW status + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" + ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta + mock_get_reviews_for_execution = mocker.patch( "backend.api.features.executions.review.routes.get_pending_reviews_for_execution" ) @@ -251,7 +300,7 @@ def test_process_review_action_reject_success( ] } - response = client.post("/api/review/action", json=request_data) + response = await client.post("/api/review/action", json=request_data) assert response.status_code == 200 data = response.json() @@ -261,7 +310,9 @@ def test_process_review_action_reject_success( assert data["error"] is None -def test_process_review_action_mixed_success( +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_mixed_success( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, sample_pending_review: PendingHumanReviewModel, test_user_id: str, @@ -288,6 +339,15 @@ def test_process_review_action_mixed_success( # Mock the route functions + # Mock get_reviews_by_node_exec_ids (called to find the graph_exec_id) + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" + ) + mock_get_reviews_for_user.return_value = { + "test_node_123": sample_pending_review, + "test_node_456": second_review, + } + mock_get_reviews_for_execution = mocker.patch( "backend.api.features.executions.review.routes.get_pending_reviews_for_execution" ) @@ -337,6 +397,14 @@ def test_process_review_action_mixed_success( "test_node_456": rejected_review, } + # Mock get_graph_execution_meta to return execution in REVIEW status + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" + ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta + mock_has_pending = mocker.patch( "backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec" ) @@ -358,7 +426,7 @@ def test_process_review_action_mixed_success( ] } - response = client.post("/api/review/action", json=request_data) + response = await client.post("/api/review/action", json=request_data) assert response.status_code == 200 data = response.json() @@ -368,14 +436,16 @@ def test_process_review_action_mixed_success( assert data["error"] is None -def test_process_review_action_empty_request( +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_empty_request( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, test_user_id: str, ) -> None: """Test error when no reviews provided""" request_data = {"reviews": []} - response = client.post("/api/review/action", json=request_data) + response = await client.post("/api/review/action", json=request_data) assert response.status_code == 422 response_data = response.json() @@ -385,11 +455,29 @@ def test_process_review_action_empty_request( assert "At least one review must be provided" in response_data["detail"][0]["msg"] -def test_process_review_action_review_not_found( +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_review_not_found( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, + sample_pending_review: PendingHumanReviewModel, test_user_id: str, ) -> None: """Test error when review is not found""" + # Mock get_reviews_by_node_exec_ids (called to find the graph_exec_id) + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" + ) + # Return empty dict to simulate review not found + mock_get_reviews_for_user.return_value = {} + + # Mock get_graph_execution_meta to return execution in REVIEW status + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" + ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta + # Mock the functions that extract graph execution ID from the request mock_get_reviews_for_execution = mocker.patch( "backend.api.features.executions.review.routes.get_pending_reviews_for_execution" @@ -415,18 +503,34 @@ def test_process_review_action_review_not_found( ] } - response = client.post("/api/review/action", json=request_data) + response = await client.post("/api/review/action", json=request_data) - assert response.status_code == 400 - assert "Reviews not found" in response.json()["detail"] + assert response.status_code == 404 + assert "Review(s) not found" in response.json()["detail"] -def test_process_review_action_partial_failure( +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_partial_failure( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, sample_pending_review: PendingHumanReviewModel, test_user_id: str, ) -> None: """Test handling of partial failures in review processing""" + # Mock get_reviews_by_node_exec_ids (called to find the graph_exec_id) + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" + ) + mock_get_reviews_for_user.return_value = {"test_node_123": sample_pending_review} + + # Mock get_graph_execution_meta to return execution in REVIEW status + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" + ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta + # Mock the route functions mock_get_reviews_for_execution = mocker.patch( "backend.api.features.executions.review.routes.get_pending_reviews_for_execution" @@ -449,31 +553,34 @@ def test_process_review_action_partial_failure( ] } - response = client.post("/api/review/action", json=request_data) + response = await client.post("/api/review/action", json=request_data) assert response.status_code == 400 assert "Some reviews failed validation" in response.json()["detail"] -def test_process_review_action_invalid_node_exec_id( +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_invalid_node_exec_id( + client: httpx.AsyncClient, mocker: pytest_mock.MockerFixture, sample_pending_review: PendingHumanReviewModel, test_user_id: str, ) -> None: """Test failure when trying to process review with invalid node execution ID""" - # Mock the route functions - mock_get_reviews_for_execution = mocker.patch( - "backend.api.features.executions.review.routes.get_pending_reviews_for_execution" + # Mock get_reviews_by_node_exec_ids (called to find the graph_exec_id) + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" ) - mock_get_reviews_for_execution.return_value = [sample_pending_review] + # Return empty dict to simulate review not found + mock_get_reviews_for_user.return_value = {} - # Mock validation failure - this should return 400, not 500 - mock_process_all_reviews = mocker.patch( - "backend.api.features.executions.review.routes.process_all_reviews_for_execution" - ) - mock_process_all_reviews.side_effect = ValueError( - "Invalid node execution ID format" + # Mock get_graph_execution_meta to return execution in REVIEW status + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta request_data = { "reviews": [ @@ -485,8 +592,638 @@ def test_process_review_action_invalid_node_exec_id( ] } - response = client.post("/api/review/action", json=request_data) + response = await client.post("/api/review/action", json=request_data) - # Should be a 400 Bad Request, not 500 Internal Server Error - assert response.status_code == 400 - assert "Invalid node execution ID format" in response.json()["detail"] + # Returns 404 when review is not found + assert response.status_code == 404 + assert "Review(s) not found" in response.json()["detail"] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_auto_approve_creates_auto_approval_records( + client: httpx.AsyncClient, + mocker: pytest_mock.MockerFixture, + sample_pending_review: PendingHumanReviewModel, + test_user_id: str, +) -> None: + """Test that auto_approve_future_actions flag creates auto-approval records""" + # Mock get_reviews_by_node_exec_ids (called to find the graph_exec_id) + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" + ) + mock_get_reviews_for_user.return_value = {"test_node_123": sample_pending_review} + + # Mock process_all_reviews + mock_process_all_reviews = mocker.patch( + "backend.api.features.executions.review.routes.process_all_reviews_for_execution" + ) + approved_review = PendingHumanReviewModel( + node_exec_id="test_node_123", + user_id=test_user_id, + graph_exec_id="test_graph_exec_456", + graph_id="test_graph_789", + graph_version=1, + payload={"data": "test payload"}, + instructions="Please review", + editable=True, + status=ReviewStatus.APPROVED, + review_message="Approved", + was_edited=False, + processed=False, + created_at=FIXED_NOW, + updated_at=FIXED_NOW, + reviewed_at=FIXED_NOW, + ) + mock_process_all_reviews.return_value = {"test_node_123": approved_review} + + # Mock get_node_executions to return node_id mapping + mock_get_node_executions = mocker.patch( + "backend.data.execution.get_node_executions" + ) + mock_node_exec = mocker.Mock(spec=NodeExecutionResult) + mock_node_exec.node_exec_id = "test_node_123" + mock_node_exec.node_id = "test_node_def_456" + mock_get_node_executions.return_value = [mock_node_exec] + + # Mock create_auto_approval_record + mock_create_auto_approval = mocker.patch( + "backend.api.features.executions.review.routes.create_auto_approval_record" + ) + + # Mock get_graph_execution_meta to return execution in REVIEW status + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" + ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta + + # Mock has_pending_reviews_for_graph_exec + mock_has_pending = mocker.patch( + "backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec" + ) + mock_has_pending.return_value = False + + # Mock get_graph_settings to return custom settings + mock_get_settings = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_settings" + ) + mock_get_settings.return_value = GraphSettings( + human_in_the_loop_safe_mode=True, + sensitive_action_safe_mode=True, + ) + + # Mock get_user_by_id to prevent database access + mock_get_user = mocker.patch( + "backend.api.features.executions.review.routes.get_user_by_id" + ) + mock_user = mocker.Mock() + mock_user.timezone = "UTC" + mock_get_user.return_value = mock_user + + # Mock add_graph_execution + mock_add_execution = mocker.patch( + "backend.api.features.executions.review.routes.add_graph_execution" + ) + + request_data = { + "reviews": [ + { + "node_exec_id": "test_node_123", + "approved": True, + "message": "Approved", + "auto_approve_future": True, + } + ], + } + + response = await client.post("/api/review/action", json=request_data) + + assert response.status_code == 200 + + # Verify process_all_reviews_for_execution was called (without auto_approve param) + mock_process_all_reviews.assert_called_once() + + # Verify create_auto_approval_record was called for the approved review + mock_create_auto_approval.assert_called_once_with( + user_id=test_user_id, + graph_exec_id="test_graph_exec_456", + graph_id="test_graph_789", + graph_version=1, + node_id="test_node_def_456", + payload={"data": "test payload"}, + ) + + # Verify get_graph_settings was called with correct parameters + mock_get_settings.assert_called_once_with( + user_id=test_user_id, graph_id="test_graph_789" + ) + + # Verify add_graph_execution was called with proper ExecutionContext + mock_add_execution.assert_called_once() + call_kwargs = mock_add_execution.call_args.kwargs + execution_context = call_kwargs["execution_context"] + + assert isinstance(execution_context, ExecutionContext) + assert execution_context.human_in_the_loop_safe_mode is True + assert execution_context.sensitive_action_safe_mode is True + + +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_without_auto_approve_still_loads_settings( + client: httpx.AsyncClient, + mocker: pytest_mock.MockerFixture, + sample_pending_review: PendingHumanReviewModel, + test_user_id: str, +) -> None: + """Test that execution context is created with settings even without auto-approve""" + # Mock get_reviews_by_node_exec_ids (called to find the graph_exec_id) + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" + ) + mock_get_reviews_for_user.return_value = {"test_node_123": sample_pending_review} + + # Mock process_all_reviews + mock_process_all_reviews = mocker.patch( + "backend.api.features.executions.review.routes.process_all_reviews_for_execution" + ) + approved_review = PendingHumanReviewModel( + node_exec_id="test_node_123", + user_id=test_user_id, + graph_exec_id="test_graph_exec_456", + graph_id="test_graph_789", + graph_version=1, + payload={"data": "test payload"}, + instructions="Please review", + editable=True, + status=ReviewStatus.APPROVED, + review_message="Approved", + was_edited=False, + processed=False, + created_at=FIXED_NOW, + updated_at=FIXED_NOW, + reviewed_at=FIXED_NOW, + ) + mock_process_all_reviews.return_value = {"test_node_123": approved_review} + + # Mock create_auto_approval_record - should NOT be called when auto_approve is False + mock_create_auto_approval = mocker.patch( + "backend.api.features.executions.review.routes.create_auto_approval_record" + ) + + # Mock get_graph_execution_meta to return execution in REVIEW status + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" + ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta + + # Mock has_pending_reviews_for_graph_exec + mock_has_pending = mocker.patch( + "backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec" + ) + mock_has_pending.return_value = False + + # Mock get_graph_settings with sensitive_action_safe_mode enabled + mock_get_settings = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_settings" + ) + mock_get_settings.return_value = GraphSettings( + human_in_the_loop_safe_mode=False, + sensitive_action_safe_mode=True, + ) + + # Mock get_user_by_id to prevent database access + mock_get_user = mocker.patch( + "backend.api.features.executions.review.routes.get_user_by_id" + ) + mock_user = mocker.Mock() + mock_user.timezone = "UTC" + mock_get_user.return_value = mock_user + + # Mock add_graph_execution + mock_add_execution = mocker.patch( + "backend.api.features.executions.review.routes.add_graph_execution" + ) + + # Request WITHOUT auto_approve_future (defaults to False) + request_data = { + "reviews": [ + { + "node_exec_id": "test_node_123", + "approved": True, + "message": "Approved", + # auto_approve_future defaults to False + } + ], + } + + response = await client.post("/api/review/action", json=request_data) + + assert response.status_code == 200 + + # Verify process_all_reviews_for_execution was called + mock_process_all_reviews.assert_called_once() + + # Verify create_auto_approval_record was NOT called (auto_approve_future=False) + mock_create_auto_approval.assert_not_called() + + # Verify settings were loaded + mock_get_settings.assert_called_once() + + # Verify ExecutionContext has proper settings + mock_add_execution.assert_called_once() + call_kwargs = mock_add_execution.call_args.kwargs + execution_context = call_kwargs["execution_context"] + + assert isinstance(execution_context, ExecutionContext) + assert execution_context.human_in_the_loop_safe_mode is False + assert execution_context.sensitive_action_safe_mode is True + + +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_auto_approve_only_applies_to_approved_reviews( + client: httpx.AsyncClient, + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """Test that auto_approve record is created only for approved reviews""" + # Create two reviews - one approved, one rejected + approved_review = PendingHumanReviewModel( + node_exec_id="node_exec_approved", + user_id=test_user_id, + graph_exec_id="test_graph_exec_456", + graph_id="test_graph_789", + graph_version=1, + payload={"data": "approved"}, + instructions="Review", + editable=True, + status=ReviewStatus.APPROVED, + review_message=None, + was_edited=False, + processed=False, + created_at=FIXED_NOW, + updated_at=FIXED_NOW, + reviewed_at=FIXED_NOW, + ) + rejected_review = PendingHumanReviewModel( + node_exec_id="node_exec_rejected", + user_id=test_user_id, + graph_exec_id="test_graph_exec_456", + graph_id="test_graph_789", + graph_version=1, + payload={"data": "rejected"}, + instructions="Review", + editable=True, + status=ReviewStatus.REJECTED, + review_message="Rejected", + was_edited=False, + processed=False, + created_at=FIXED_NOW, + updated_at=FIXED_NOW, + reviewed_at=FIXED_NOW, + ) + + # Mock get_reviews_by_node_exec_ids (called to find the graph_exec_id) + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" + ) + # Need to return both reviews in WAITING state (before processing) + approved_review_waiting = PendingHumanReviewModel( + node_exec_id="node_exec_approved", + user_id=test_user_id, + graph_exec_id="test_graph_exec_456", + graph_id="test_graph_789", + graph_version=1, + payload={"data": "approved"}, + instructions="Review", + editable=True, + status=ReviewStatus.WAITING, + review_message=None, + was_edited=False, + processed=False, + created_at=FIXED_NOW, + ) + rejected_review_waiting = PendingHumanReviewModel( + node_exec_id="node_exec_rejected", + user_id=test_user_id, + graph_exec_id="test_graph_exec_456", + graph_id="test_graph_789", + graph_version=1, + payload={"data": "rejected"}, + instructions="Review", + editable=True, + status=ReviewStatus.WAITING, + review_message=None, + was_edited=False, + processed=False, + created_at=FIXED_NOW, + ) + mock_get_reviews_for_user.return_value = { + "node_exec_approved": approved_review_waiting, + "node_exec_rejected": rejected_review_waiting, + } + + # Mock process_all_reviews + mock_process_all_reviews = mocker.patch( + "backend.api.features.executions.review.routes.process_all_reviews_for_execution" + ) + mock_process_all_reviews.return_value = { + "node_exec_approved": approved_review, + "node_exec_rejected": rejected_review, + } + + # Mock get_node_executions to return node_id mapping + mock_get_node_executions = mocker.patch( + "backend.data.execution.get_node_executions" + ) + mock_node_exec = mocker.Mock(spec=NodeExecutionResult) + mock_node_exec.node_exec_id = "node_exec_approved" + mock_node_exec.node_id = "test_node_def_approved" + mock_get_node_executions.return_value = [mock_node_exec] + + # Mock create_auto_approval_record + mock_create_auto_approval = mocker.patch( + "backend.api.features.executions.review.routes.create_auto_approval_record" + ) + + # Mock get_graph_execution_meta to return execution in REVIEW status + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" + ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta + + # Mock has_pending_reviews_for_graph_exec + mock_has_pending = mocker.patch( + "backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec" + ) + mock_has_pending.return_value = False + + # Mock get_graph_settings + mock_get_settings = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_settings" + ) + mock_get_settings.return_value = GraphSettings() + + # Mock get_user_by_id to prevent database access + mock_get_user = mocker.patch( + "backend.api.features.executions.review.routes.get_user_by_id" + ) + mock_user = mocker.Mock() + mock_user.timezone = "UTC" + mock_get_user.return_value = mock_user + + # Mock add_graph_execution + mock_add_execution = mocker.patch( + "backend.api.features.executions.review.routes.add_graph_execution" + ) + + request_data = { + "reviews": [ + { + "node_exec_id": "node_exec_approved", + "approved": True, + "auto_approve_future": True, + }, + { + "node_exec_id": "node_exec_rejected", + "approved": False, + "auto_approve_future": True, # Should be ignored since rejected + }, + ], + } + + response = await client.post("/api/review/action", json=request_data) + + assert response.status_code == 200 + + # Verify process_all_reviews_for_execution was called + mock_process_all_reviews.assert_called_once() + + # Verify create_auto_approval_record was called ONLY for the approved review + # (not for the rejected one) + mock_create_auto_approval.assert_called_once_with( + user_id=test_user_id, + graph_exec_id="test_graph_exec_456", + graph_id="test_graph_789", + graph_version=1, + node_id="test_node_def_approved", + payload={"data": "approved"}, + ) + + # Verify get_node_executions was called to batch-fetch node data + mock_get_node_executions.assert_called_once() + + # Verify ExecutionContext was created (auto-approval is now DB-based) + call_kwargs = mock_add_execution.call_args.kwargs + execution_context = call_kwargs["execution_context"] + assert isinstance(execution_context, ExecutionContext) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_process_review_action_per_review_auto_approve_granularity( + client: httpx.AsyncClient, + mocker: pytest_mock.MockerFixture, + sample_pending_review: PendingHumanReviewModel, + test_user_id: str, +) -> None: + """Test that auto-approval can be set per-review (granular control)""" + # Mock get_reviews_by_node_exec_ids - return different reviews based on node_exec_id + mock_get_reviews_for_user = mocker.patch( + "backend.api.features.executions.review.routes.get_reviews_by_node_exec_ids" + ) + + # Create a mapping of node_exec_id to review + review_map = { + "node_1_auto": PendingHumanReviewModel( + node_exec_id="node_1_auto", + user_id=test_user_id, + graph_exec_id="test_graph_exec", + graph_id="test_graph", + graph_version=1, + payload={"data": "node1"}, + instructions="Review 1", + editable=True, + status=ReviewStatus.WAITING, + review_message=None, + was_edited=False, + processed=False, + created_at=FIXED_NOW, + ), + "node_2_manual": PendingHumanReviewModel( + node_exec_id="node_2_manual", + user_id=test_user_id, + graph_exec_id="test_graph_exec", + graph_id="test_graph", + graph_version=1, + payload={"data": "node2"}, + instructions="Review 2", + editable=True, + status=ReviewStatus.WAITING, + review_message=None, + was_edited=False, + processed=False, + created_at=FIXED_NOW, + ), + "node_3_auto": PendingHumanReviewModel( + node_exec_id="node_3_auto", + user_id=test_user_id, + graph_exec_id="test_graph_exec", + graph_id="test_graph", + graph_version=1, + payload={"data": "node3"}, + instructions="Review 3", + editable=True, + status=ReviewStatus.WAITING, + review_message=None, + was_edited=False, + processed=False, + created_at=FIXED_NOW, + ), + } + + # Return the review map dict (batch function returns all requested reviews) + mock_get_reviews_for_user.return_value = review_map + + # Mock process_all_reviews - return 3 approved reviews + mock_process_all_reviews = mocker.patch( + "backend.api.features.executions.review.routes.process_all_reviews_for_execution" + ) + mock_process_all_reviews.return_value = { + "node_1_auto": PendingHumanReviewModel( + node_exec_id="node_1_auto", + user_id=test_user_id, + graph_exec_id="test_graph_exec", + graph_id="test_graph", + graph_version=1, + payload={"data": "node1"}, + instructions="Review 1", + editable=True, + status=ReviewStatus.APPROVED, + review_message=None, + was_edited=False, + processed=False, + created_at=FIXED_NOW, + updated_at=FIXED_NOW, + reviewed_at=FIXED_NOW, + ), + "node_2_manual": PendingHumanReviewModel( + node_exec_id="node_2_manual", + user_id=test_user_id, + graph_exec_id="test_graph_exec", + graph_id="test_graph", + graph_version=1, + payload={"data": "node2"}, + instructions="Review 2", + editable=True, + status=ReviewStatus.APPROVED, + review_message=None, + was_edited=False, + processed=False, + created_at=FIXED_NOW, + updated_at=FIXED_NOW, + reviewed_at=FIXED_NOW, + ), + "node_3_auto": PendingHumanReviewModel( + node_exec_id="node_3_auto", + user_id=test_user_id, + graph_exec_id="test_graph_exec", + graph_id="test_graph", + graph_version=1, + payload={"data": "node3"}, + instructions="Review 3", + editable=True, + status=ReviewStatus.APPROVED, + review_message=None, + was_edited=False, + processed=False, + created_at=FIXED_NOW, + updated_at=FIXED_NOW, + reviewed_at=FIXED_NOW, + ), + } + + # Mock get_node_executions to return batch node data + mock_get_node_executions = mocker.patch( + "backend.data.execution.get_node_executions" + ) + # Create mock node executions for each review + mock_node_execs = [] + for node_exec_id in ["node_1_auto", "node_2_manual", "node_3_auto"]: + mock_node = mocker.Mock(spec=NodeExecutionResult) + mock_node.node_exec_id = node_exec_id + mock_node.node_id = f"node_def_{node_exec_id}" + mock_node_execs.append(mock_node) + mock_get_node_executions.return_value = mock_node_execs + + # Mock create_auto_approval_record + mock_create_auto_approval = mocker.patch( + "backend.api.features.executions.review.routes.create_auto_approval_record" + ) + + # Mock get_graph_execution_meta + mock_get_graph_exec = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_execution_meta" + ) + mock_graph_exec_meta = mocker.Mock() + mock_graph_exec_meta.status = ExecutionStatus.REVIEW + mock_get_graph_exec.return_value = mock_graph_exec_meta + + # Mock has_pending_reviews_for_graph_exec + mock_has_pending = mocker.patch( + "backend.api.features.executions.review.routes.has_pending_reviews_for_graph_exec" + ) + mock_has_pending.return_value = False + + # Mock settings and execution + mock_get_settings = mocker.patch( + "backend.api.features.executions.review.routes.get_graph_settings" + ) + mock_get_settings.return_value = GraphSettings( + human_in_the_loop_safe_mode=False, sensitive_action_safe_mode=False + ) + + mocker.patch("backend.api.features.executions.review.routes.add_graph_execution") + mocker.patch("backend.api.features.executions.review.routes.get_user_by_id") + + # Request with granular auto-approval: + # - node_1_auto: auto_approve_future=True + # - node_2_manual: auto_approve_future=False (explicit) + # - node_3_auto: auto_approve_future=True + request_data = { + "reviews": [ + { + "node_exec_id": "node_1_auto", + "approved": True, + "auto_approve_future": True, + }, + { + "node_exec_id": "node_2_manual", + "approved": True, + "auto_approve_future": False, # Don't auto-approve this one + }, + { + "node_exec_id": "node_3_auto", + "approved": True, + "auto_approve_future": True, + }, + ], + } + + response = await client.post("/api/review/action", json=request_data) + + assert response.status_code == 200 + + # Verify create_auto_approval_record was called ONLY for reviews with auto_approve_future=True + assert mock_create_auto_approval.call_count == 2 + + # Check that it was called for node_1 and node_3, but NOT node_2 + call_args_list = [call.kwargs for call in mock_create_auto_approval.call_args_list] + node_ids_with_auto_approval = [args["node_id"] for args in call_args_list] + + assert "node_def_node_1_auto" in node_ids_with_auto_approval + assert "node_def_node_3_auto" in node_ids_with_auto_approval + assert "node_def_node_2_manual" not in node_ids_with_auto_approval diff --git a/autogpt_platform/backend/backend/api/features/executions/review/routes.py b/autogpt_platform/backend/backend/api/features/executions/review/routes.py index 88646046da..539c7fd87b 100644 --- a/autogpt_platform/backend/backend/api/features/executions/review/routes.py +++ b/autogpt_platform/backend/backend/api/features/executions/review/routes.py @@ -1,17 +1,27 @@ +import asyncio import logging -from typing import List +from typing import Any, List import autogpt_libs.auth as autogpt_auth_lib from fastapi import APIRouter, HTTPException, Query, Security, status from prisma.enums import ReviewStatus -from backend.data.execution import get_graph_execution_meta +from backend.data.execution import ( + ExecutionContext, + ExecutionStatus, + get_graph_execution_meta, +) +from backend.data.graph import get_graph_settings from backend.data.human_review import ( + create_auto_approval_record, get_pending_reviews_for_execution, get_pending_reviews_for_user, + get_reviews_by_node_exec_ids, has_pending_reviews_for_graph_exec, process_all_reviews_for_execution, ) +from backend.data.model import USER_TIMEZONE_NOT_SET +from backend.data.user import get_user_by_id from backend.executor.utils import add_graph_execution from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse @@ -127,17 +137,70 @@ async def process_review_action( detail="At least one review must be provided", ) - # Build review decisions map + # Batch fetch all requested reviews (regardless of status for idempotent handling) + reviews_map = await get_reviews_by_node_exec_ids( + list(all_request_node_ids), user_id + ) + + # Validate all reviews were found (must exist, any status is OK for now) + missing_ids = all_request_node_ids - set(reviews_map.keys()) + if missing_ids: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Review(s) not found: {', '.join(missing_ids)}", + ) + + # Validate all reviews belong to the same execution + graph_exec_ids = {review.graph_exec_id for review in reviews_map.values()} + if len(graph_exec_ids) > 1: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="All reviews in a single request must belong to the same execution.", + ) + + graph_exec_id = next(iter(graph_exec_ids)) + + # Validate execution status before processing reviews + graph_exec_meta = await get_graph_execution_meta( + user_id=user_id, execution_id=graph_exec_id + ) + + if not graph_exec_meta: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Graph execution #{graph_exec_id} not found", + ) + + # Only allow processing reviews if execution is paused for review + # or incomplete (partial execution with some reviews already processed) + if graph_exec_meta.status not in ( + ExecutionStatus.REVIEW, + ExecutionStatus.INCOMPLETE, + ): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. " + f"Reviews can only be processed when execution is paused (REVIEW status). " + f"Current status: {graph_exec_meta.status}", + ) + + # Build review decisions map and track which reviews requested auto-approval + # Auto-approved reviews use original data (no modifications allowed) review_decisions = {} + auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag + for review in request.reviews: review_status = ( ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED ) + # If this review requested auto-approval, don't allow data modifications + reviewed_data = None if review.auto_approve_future else review.reviewed_data review_decisions[review.node_exec_id] = ( review_status, - review.reviewed_data, + reviewed_data, review.message, ) + auto_approve_requests[review.node_exec_id] = review.auto_approve_future # Process all reviews updated_reviews = await process_all_reviews_for_execution( @@ -145,6 +208,87 @@ async def process_review_action( review_decisions=review_decisions, ) + # Create auto-approval records for approved reviews that requested it + # Deduplicate by node_id to avoid race conditions when multiple reviews + # for the same node are processed in parallel + async def create_auto_approval_for_node( + node_id: str, review_result + ) -> tuple[str, bool]: + """ + Create auto-approval record for a node. + Returns (node_id, success) tuple for tracking failures. + """ + try: + await create_auto_approval_record( + user_id=user_id, + graph_exec_id=review_result.graph_exec_id, + graph_id=review_result.graph_id, + graph_version=review_result.graph_version, + node_id=node_id, + payload=review_result.payload, + ) + return (node_id, True) + except Exception as e: + logger.error( + f"Failed to create auto-approval record for node {node_id}", + exc_info=e, + ) + return (node_id, False) + + # Collect node_exec_ids that need auto-approval + node_exec_ids_needing_auto_approval = [ + node_exec_id + for node_exec_id, review_result in updated_reviews.items() + if review_result.status == ReviewStatus.APPROVED + and auto_approve_requests.get(node_exec_id, False) + ] + + # Batch-fetch node executions to get node_ids + nodes_needing_auto_approval: dict[str, Any] = {} + if node_exec_ids_needing_auto_approval: + from backend.data.execution import get_node_executions + + node_execs = await get_node_executions( + graph_exec_id=graph_exec_id, include_exec_data=False + ) + node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs} + + for node_exec_id in node_exec_ids_needing_auto_approval: + node_exec = node_exec_map.get(node_exec_id) + if node_exec: + review_result = updated_reviews[node_exec_id] + # Use the first approved review for this node (deduplicate by node_id) + if node_exec.node_id not in nodes_needing_auto_approval: + nodes_needing_auto_approval[node_exec.node_id] = review_result + else: + logger.error( + f"Failed to create auto-approval record for {node_exec_id}: " + f"Node execution not found. This may indicate a race condition " + f"or data inconsistency." + ) + + # Execute all auto-approval creations in parallel (deduplicated by node_id) + auto_approval_results = await asyncio.gather( + *[ + create_auto_approval_for_node(node_id, review_result) + for node_id, review_result in nodes_needing_auto_approval.items() + ], + return_exceptions=True, + ) + + # Count auto-approval failures + auto_approval_failed_count = 0 + for result in auto_approval_results: + if isinstance(result, Exception): + # Unexpected exception during auto-approval creation + auto_approval_failed_count += 1 + logger.error( + f"Unexpected exception during auto-approval creation: {result}" + ) + elif isinstance(result, tuple) and len(result) == 2 and not result[1]: + # Auto-approval creation failed (returned False) + auto_approval_failed_count += 1 + # Count results approved_count = sum( 1 @@ -157,30 +301,53 @@ async def process_review_action( if review.status == ReviewStatus.REJECTED ) - # Resume execution if we processed some reviews + # Resume execution only if ALL pending reviews for this execution have been processed if updated_reviews: - # Get graph execution ID from any processed review - first_review = next(iter(updated_reviews.values())) - graph_exec_id = first_review.graph_exec_id - - # Check if any pending reviews remain for this execution still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id) if not still_has_pending: - # Resume execution + # Get the graph_id from any processed review + first_review = next(iter(updated_reviews.values())) + try: + # Fetch user and settings to build complete execution context + user = await get_user_by_id(user_id) + settings = await get_graph_settings( + user_id=user_id, graph_id=first_review.graph_id + ) + + # Preserve user's timezone preference when resuming execution + user_timezone = ( + user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC" + ) + + execution_context = ExecutionContext( + human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode, + sensitive_action_safe_mode=settings.sensitive_action_safe_mode, + user_timezone=user_timezone, + ) + await add_graph_execution( graph_id=first_review.graph_id, user_id=user_id, graph_exec_id=graph_exec_id, + execution_context=execution_context, ) logger.info(f"Resumed execution {graph_exec_id}") except Exception as e: logger.error(f"Failed to resume execution {graph_exec_id}: {str(e)}") + # Build error message if auto-approvals failed + error_message = None + if auto_approval_failed_count > 0: + error_message = ( + f"{auto_approval_failed_count} auto-approval setting(s) could not be saved. " + f"You may need to manually approve these reviews in future executions." + ) + return ReviewResponse( approved_count=approved_count, rejected_count=rejected_count, - failed_count=0, - error=None, + failed_count=auto_approval_failed_count, + error=error_message, ) diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index 0c775802db..872fe66b28 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -21,7 +21,7 @@ from backend.data.model import CredentialsMetaInput from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate from backend.util.clients import get_scheduler_client -from backend.util.exceptions import DatabaseError, NotFoundError +from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError from backend.util.json import SafeJson from backend.util.models import Pagination from backend.util.settings import Config @@ -64,11 +64,11 @@ async def list_library_agents( if page < 1 or page_size < 1: logger.warning(f"Invalid pagination: page={page}, page_size={page_size}") - raise DatabaseError("Invalid pagination input") + raise InvalidInputError("Invalid pagination input") if search_term and len(search_term.strip()) > 100: logger.warning(f"Search term too long: {repr(search_term)}") - raise DatabaseError("Search term is too long") + raise InvalidInputError("Search term is too long") where_clause: prisma.types.LibraryAgentWhereInput = { "userId": user_id, @@ -175,7 +175,7 @@ async def list_favorite_library_agents( if page < 1 or page_size < 1: logger.warning(f"Invalid pagination: page={page}, page_size={page_size}") - raise DatabaseError("Invalid pagination input") + raise InvalidInputError("Invalid pagination input") where_clause: prisma.types.LibraryAgentWhereInput = { "userId": user_id, @@ -583,7 +583,13 @@ async def update_library_agent( ) update_fields["isDeleted"] = is_deleted if settings is not None: - update_fields["settings"] = SafeJson(settings.model_dump()) + existing_agent = await get_library_agent(id=library_agent_id, user_id=user_id) + current_settings_dict = ( + existing_agent.settings.model_dump() if existing_agent.settings else {} + ) + new_settings = settings.model_dump(exclude_unset=True) + merged_settings = {**current_settings_dict, **new_settings} + update_fields["settings"] = SafeJson(merged_settings) try: # If graph_version is provided, update to that specific version diff --git a/autogpt_platform/backend/backend/api/features/library/routes/agents.py b/autogpt_platform/backend/backend/api/features/library/routes/agents.py index 38c34dd3b8..fa3d1a0f0c 100644 --- a/autogpt_platform/backend/backend/api/features/library/routes/agents.py +++ b/autogpt_platform/backend/backend/api/features/library/routes/agents.py @@ -1,4 +1,3 @@ -import logging from typing import Literal, Optional import autogpt_libs.auth as autogpt_auth_lib @@ -6,15 +5,11 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status from fastapi.responses import Response from prisma.enums import OnboardingStep -import backend.api.features.store.exceptions as store_exceptions from backend.data.onboarding import complete_onboarding_step -from backend.util.exceptions import DatabaseError, NotFoundError from .. import db as library_db from .. import model as library_model -logger = logging.getLogger(__name__) - router = APIRouter( prefix="/agents", tags=["library", "private"], @@ -26,10 +21,6 @@ router = APIRouter( "", summary="List Library Agents", response_model=library_model.LibraryAgentResponse, - responses={ - 200: {"description": "List of library agents"}, - 500: {"description": "Server error", "content": {"application/json": {}}}, - }, ) async def list_library_agents( user_id: str = Security(autogpt_auth_lib.get_user_id), @@ -53,43 +44,19 @@ async def list_library_agents( ) -> library_model.LibraryAgentResponse: """ Get all agents in the user's library (both created and saved). - - Args: - user_id: ID of the authenticated user. - search_term: Optional search term to filter agents by name/description. - filter_by: List of filters to apply (favorites, created by user). - sort_by: List of sorting criteria (created date, updated date). - page: Page number to retrieve. - page_size: Number of agents per page. - - Returns: - A LibraryAgentResponse containing agents and pagination metadata. - - Raises: - HTTPException: If a server/database error occurs. """ - try: - return await library_db.list_library_agents( - user_id=user_id, - search_term=search_term, - sort_by=sort_by, - page=page, - page_size=page_size, - ) - except Exception as e: - logger.error(f"Could not list library agents for user #{user_id}: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e), - ) from e + return await library_db.list_library_agents( + user_id=user_id, + search_term=search_term, + sort_by=sort_by, + page=page, + page_size=page_size, + ) @router.get( "/favorites", summary="List Favorite Library Agents", - responses={ - 500: {"description": "Server error", "content": {"application/json": {}}}, - }, ) async def list_favorite_library_agents( user_id: str = Security(autogpt_auth_lib.get_user_id), @@ -106,30 +73,12 @@ async def list_favorite_library_agents( ) -> library_model.LibraryAgentResponse: """ Get all favorite agents in the user's library. - - Args: - user_id: ID of the authenticated user. - page: Page number to retrieve. - page_size: Number of agents per page. - - Returns: - A LibraryAgentResponse containing favorite agents and pagination metadata. - - Raises: - HTTPException: If a server/database error occurs. """ - try: - return await library_db.list_favorite_library_agents( - user_id=user_id, - page=page, - page_size=page_size, - ) - except Exception as e: - logger.error(f"Could not list favorite library agents for user #{user_id}: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e), - ) from e + return await library_db.list_favorite_library_agents( + user_id=user_id, + page=page, + page_size=page_size, + ) @router.get("/{library_agent_id}", summary="Get Library Agent") @@ -162,10 +111,6 @@ async def get_library_agent_by_graph_id( summary="Get Agent By Store ID", tags=["store", "library"], response_model=library_model.LibraryAgent | None, - responses={ - 200: {"description": "Library agent found"}, - 404: {"description": "Agent not found"}, - }, ) async def get_library_agent_by_store_listing_version_id( store_listing_version_id: str, @@ -174,32 +119,15 @@ async def get_library_agent_by_store_listing_version_id( """ Get Library Agent from Store Listing Version ID. """ - try: - return await library_db.get_library_agent_by_store_version_id( - store_listing_version_id, user_id - ) - except NotFoundError as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=str(e), - ) - except Exception as e: - logger.error(f"Could not fetch library agent from store version ID: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e), - ) from e + return await library_db.get_library_agent_by_store_version_id( + store_listing_version_id, user_id + ) @router.post( "", summary="Add Marketplace Agent", status_code=status.HTTP_201_CREATED, - responses={ - 201: {"description": "Agent added successfully"}, - 404: {"description": "Store listing version not found"}, - 500: {"description": "Server error"}, - }, ) async def add_marketplace_agent_to_library( store_listing_version_id: str = Body(embed=True), @@ -210,59 +138,19 @@ async def add_marketplace_agent_to_library( ) -> library_model.LibraryAgent: """ Add an agent from the marketplace to the user's library. - - Args: - store_listing_version_id: ID of the store listing version to add. - user_id: ID of the authenticated user. - - Returns: - library_model.LibraryAgent: Agent added to the library - - Raises: - HTTPException(404): If the listing version is not found. - HTTPException(500): If a server/database error occurs. """ - try: - agent = await library_db.add_store_agent_to_library( - store_listing_version_id=store_listing_version_id, - user_id=user_id, - ) - if source != "onboarding": - await complete_onboarding_step( - user_id, OnboardingStep.MARKETPLACE_ADD_AGENT - ) - return agent - - except store_exceptions.AgentNotFoundError as e: - logger.warning( - f"Could not find store listing version {store_listing_version_id} " - "to add to library" - ) - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) - except DatabaseError as e: - logger.error(f"Database error while adding agent to library: {e}", e) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={"message": str(e), "hint": "Inspect DB logs for details."}, - ) from e - except Exception as e: - logger.error(f"Unexpected error while adding agent to library: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={ - "message": str(e), - "hint": "Check server logs for more information.", - }, - ) from e + agent = await library_db.add_store_agent_to_library( + store_listing_version_id=store_listing_version_id, + user_id=user_id, + ) + if source != "onboarding": + await complete_onboarding_step(user_id, OnboardingStep.MARKETPLACE_ADD_AGENT) + return agent @router.patch( "/{library_agent_id}", summary="Update Library Agent", - responses={ - 200: {"description": "Agent updated successfully"}, - 500: {"description": "Server error"}, - }, ) async def update_library_agent( library_agent_id: str, @@ -271,52 +159,21 @@ async def update_library_agent( ) -> library_model.LibraryAgent: """ Update the library agent with the given fields. - - Args: - library_agent_id: ID of the library agent to update. - payload: Fields to update (auto_update_version, is_favorite, etc.). - user_id: ID of the authenticated user. - - Raises: - HTTPException(500): If a server/database error occurs. """ - try: - return await library_db.update_library_agent( - library_agent_id=library_agent_id, - user_id=user_id, - auto_update_version=payload.auto_update_version, - graph_version=payload.graph_version, - is_favorite=payload.is_favorite, - is_archived=payload.is_archived, - settings=payload.settings, - ) - except NotFoundError as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=str(e), - ) from e - except DatabaseError as e: - logger.error(f"Database error while updating library agent: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={"message": str(e), "hint": "Verify DB connection."}, - ) from e - except Exception as e: - logger.error(f"Unexpected error while updating library agent: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={"message": str(e), "hint": "Check server logs."}, - ) from e + return await library_db.update_library_agent( + library_agent_id=library_agent_id, + user_id=user_id, + auto_update_version=payload.auto_update_version, + graph_version=payload.graph_version, + is_favorite=payload.is_favorite, + is_archived=payload.is_archived, + settings=payload.settings, + ) @router.delete( "/{library_agent_id}", summary="Delete Library Agent", - responses={ - 204: {"description": "Agent deleted successfully"}, - 404: {"description": "Agent not found"}, - 500: {"description": "Server error"}, - }, ) async def delete_library_agent( library_agent_id: str, @@ -324,28 +181,11 @@ async def delete_library_agent( ) -> Response: """ Soft-delete the specified library agent. - - Args: - library_agent_id: ID of the library agent to delete. - user_id: ID of the authenticated user. - - Returns: - 204 No Content if successful. - - Raises: - HTTPException(404): If the agent does not exist. - HTTPException(500): If a server/database error occurs. """ - try: - await library_db.delete_library_agent( - library_agent_id=library_agent_id, user_id=user_id - ) - return Response(status_code=status.HTTP_204_NO_CONTENT) - except NotFoundError as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=str(e), - ) from e + await library_db.delete_library_agent( + library_agent_id=library_agent_id, user_id=user_id + ) + return Response(status_code=status.HTTP_204_NO_CONTENT) @router.post("/{library_agent_id}/fork", summary="Fork Library Agent") diff --git a/autogpt_platform/backend/backend/api/features/library/routes_test.py b/autogpt_platform/backend/backend/api/features/library/routes_test.py index ca604af760..4d83812891 100644 --- a/autogpt_platform/backend/backend/api/features/library/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/library/routes_test.py @@ -118,21 +118,6 @@ async def test_get_library_agents_success( ) -def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str): - mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents") - mock_db_call.side_effect = Exception("Test error") - - response = client.get("/agents?search_term=test") - assert response.status_code == 500 - mock_db_call.assert_called_once_with( - user_id=test_user_id, - search_term="test", - sort_by=library_model.LibraryAgentSort.UPDATED_AT, - page=1, - page_size=15, - ) - - @pytest.mark.asyncio async def test_get_favorite_library_agents_success( mocker: pytest_mock.MockFixture, @@ -190,23 +175,6 @@ async def test_get_favorite_library_agents_success( ) -def test_get_favorite_library_agents_error( - mocker: pytest_mock.MockFixture, test_user_id: str -): - mock_db_call = mocker.patch( - "backend.api.features.library.db.list_favorite_library_agents" - ) - mock_db_call.side_effect = Exception("Test error") - - response = client.get("/agents/favorites") - assert response.status_code == 500 - mock_db_call.assert_called_once_with( - user_id=test_user_id, - page=1, - page_size=15, - ) - - def test_add_agent_to_library_success( mocker: pytest_mock.MockFixture, test_user_id: str ): @@ -258,19 +226,3 @@ def test_add_agent_to_library_success( store_listing_version_id="test-version-id", user_id=test_user_id ) mock_complete_onboarding.assert_awaited_once() - - -def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str): - mock_db_call = mocker.patch( - "backend.api.features.library.db.add_store_agent_to_library" - ) - mock_db_call.side_effect = Exception("Test error") - - response = client.post( - "/agents", json={"store_listing_version_id": "test-version-id"} - ) - assert response.status_code == 500 - assert "detail" in response.json() # Verify error response structure - mock_db_call.assert_called_once_with( - store_listing_version_id="test-version-id", user_id=test_user_id - ) diff --git a/autogpt_platform/backend/backend/api/features/oauth_test.py b/autogpt_platform/backend/backend/api/features/oauth_test.py index 5f6b85a88a..5fd35f82e7 100644 --- a/autogpt_platform/backend/backend/api/features/oauth_test.py +++ b/autogpt_platform/backend/backend/api/features/oauth_test.py @@ -20,6 +20,7 @@ from typing import AsyncGenerator import httpx import pytest +import pytest_asyncio from autogpt_libs.api_key.keysmith import APIKeySmith from prisma.enums import APIKeyPermission from prisma.models import OAuthAccessToken as PrismaOAuthAccessToken @@ -38,13 +39,13 @@ keysmith = APIKeySmith() # ============================================================================ -@pytest.fixture +@pytest.fixture(scope="session") def test_user_id() -> str: """Test user ID for OAuth tests.""" return str(uuid.uuid4()) -@pytest.fixture +@pytest_asyncio.fixture(scope="session", loop_scope="session") async def test_user(server, test_user_id: str): """Create a test user in the database.""" await PrismaUser.prisma().create( @@ -67,7 +68,7 @@ async def test_user(server, test_user_id: str): await PrismaUser.prisma().delete(where={"id": test_user_id}) -@pytest.fixture +@pytest_asyncio.fixture async def test_oauth_app(test_user: str): """Create a test OAuth application in the database.""" app_id = str(uuid.uuid4()) @@ -122,7 +123,7 @@ def pkce_credentials() -> tuple[str, str]: return generate_pkce() -@pytest.fixture +@pytest_asyncio.fixture async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]: """ Create an async HTTP client that talks directly to the FastAPI app. @@ -287,7 +288,7 @@ async def test_authorize_invalid_client_returns_error( assert query_params["error"][0] == "invalid_client" -@pytest.fixture +@pytest_asyncio.fixture async def inactive_oauth_app(test_user: str): """Create an inactive test OAuth application in the database.""" app_id = str(uuid.uuid4()) @@ -1004,7 +1005,7 @@ async def test_token_refresh_revoked( assert "revoked" in response.json()["detail"].lower() -@pytest.fixture +@pytest_asyncio.fixture async def other_oauth_app(test_user: str): """Create a second OAuth application for cross-app tests.""" app_id = str(uuid.uuid4()) diff --git a/autogpt_platform/backend/backend/api/features/store/content_handlers.py b/autogpt_platform/backend/backend/api/features/store/content_handlers.py index 1560db421c..cbbdcfbebf 100644 --- a/autogpt_platform/backend/backend/api/features/store/content_handlers.py +++ b/autogpt_platform/backend/backend/api/features/store/content_handlers.py @@ -188,6 +188,10 @@ class BlockHandler(ContentHandler): try: block_instance = block_cls() + # Skip disabled blocks - they shouldn't be indexed + if block_instance.disabled: + continue + # Build searchable text from block metadata parts = [] if hasattr(block_instance, "name") and block_instance.name: @@ -248,12 +252,19 @@ class BlockHandler(ContentHandler): from backend.data.block import get_blocks all_blocks = get_blocks() - total_blocks = len(all_blocks) + + # Filter out disabled blocks - they're not indexed + enabled_block_ids = [ + block_id + for block_id, block_cls in all_blocks.items() + if not block_cls().disabled + ] + total_blocks = len(enabled_block_ids) if total_blocks == 0: return {"total": 0, "with_embeddings": 0, "without_embeddings": 0} - block_ids = list(all_blocks.keys()) + block_ids = enabled_block_ids placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))]) embedded_result = await query_raw_with_schema( diff --git a/autogpt_platform/backend/backend/api/features/store/content_handlers_test.py b/autogpt_platform/backend/backend/api/features/store/content_handlers_test.py index 28bc88e270..fee879fae0 100644 --- a/autogpt_platform/backend/backend/api/features/store/content_handlers_test.py +++ b/autogpt_platform/backend/backend/api/features/store/content_handlers_test.py @@ -81,6 +81,7 @@ async def test_block_handler_get_missing_items(mocker): mock_block_instance.name = "Calculator Block" mock_block_instance.description = "Performs calculations" mock_block_instance.categories = [MagicMock(value="MATH")] + mock_block_instance.disabled = False mock_block_instance.input_schema.model_json_schema.return_value = { "properties": {"expression": {"description": "Math expression to evaluate"}} } @@ -116,11 +117,18 @@ async def test_block_handler_get_stats(mocker): """Test BlockHandler returns correct stats.""" handler = BlockHandler() - # Mock get_blocks + # Mock get_blocks - each block class returns an instance with disabled=False + def make_mock_block_class(): + mock_class = MagicMock() + mock_instance = MagicMock() + mock_instance.disabled = False + mock_class.return_value = mock_instance + return mock_class + mock_blocks = { - "block-1": MagicMock(), - "block-2": MagicMock(), - "block-3": MagicMock(), + "block-1": make_mock_block_class(), + "block-2": make_mock_block_class(), + "block-3": make_mock_block_class(), } # Mock embedded count query (2 blocks have embeddings) @@ -309,6 +317,7 @@ async def test_block_handler_handles_missing_attributes(): mock_block_class = MagicMock() mock_block_instance = MagicMock() mock_block_instance.name = "Minimal Block" + mock_block_instance.disabled = False # No description, categories, or schema del mock_block_instance.description del mock_block_instance.categories @@ -342,6 +351,7 @@ async def test_block_handler_skips_failed_blocks(): good_instance.name = "Good Block" good_instance.description = "Works fine" good_instance.categories = [] + good_instance.disabled = False good_block.return_value = good_instance bad_block = MagicMock() diff --git a/autogpt_platform/backend/backend/api/features/store/db.py b/autogpt_platform/backend/backend/api/features/store/db.py index e6aa3853f6..956fdfa7da 100644 --- a/autogpt_platform/backend/backend/api/features/store/db.py +++ b/autogpt_platform/backend/backend/api/features/store/db.py @@ -1552,7 +1552,7 @@ async def review_store_submission( # Generate embedding for approved listing (blocking - admin operation) # Inside transaction: if embedding fails, entire transaction rolls back - embedding_success = await ensure_embedding( + await ensure_embedding( version_id=store_listing_version_id, name=store_listing_version.name, description=store_listing_version.description, @@ -1560,12 +1560,6 @@ async def review_store_submission( categories=store_listing_version.categories or [], tx=tx, ) - if not embedding_success: - raise ValueError( - f"Failed to generate embedding for listing {store_listing_version_id}. " - "This is likely due to OpenAI API being unavailable. " - "Please try again later or contact support if the issue persists." - ) await prisma.models.StoreListing.prisma(tx).update( where={"id": store_listing_version.StoreListing.id}, diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings.py b/autogpt_platform/backend/backend/api/features/store/embeddings.py index efe896f665..434f2fe2ce 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings.py @@ -21,7 +21,6 @@ from backend.util.json import dumps logger = logging.getLogger(__name__) - # OpenAI embedding model configuration EMBEDDING_MODEL = "text-embedding-3-small" # Embedding dimension for the model above @@ -63,49 +62,42 @@ def build_searchable_text( return " ".join(parts) -async def generate_embedding(text: str) -> list[float] | None: +async def generate_embedding(text: str) -> list[float]: """ Generate embedding for text using OpenAI API. - Returns None if embedding generation fails. - Fail-fast: no retries to maintain consistency with approval flow. + Raises exceptions on failure - caller should handle. """ - try: - client = get_openai_client() - if not client: - logger.error("openai_internal_api_key not set, cannot generate embedding") - return None + client = get_openai_client() + if not client: + raise RuntimeError("openai_internal_api_key not set, cannot generate embedding") - # Truncate text to token limit using tiktoken - # Character-based truncation is insufficient because token ratios vary by content type - enc = encoding_for_model(EMBEDDING_MODEL) - tokens = enc.encode(text) - if len(tokens) > EMBEDDING_MAX_TOKENS: - tokens = tokens[:EMBEDDING_MAX_TOKENS] - truncated_text = enc.decode(tokens) - logger.info( - f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens" - ) - else: - truncated_text = text - - start_time = time.time() - response = await client.embeddings.create( - model=EMBEDDING_MODEL, - input=truncated_text, - ) - latency_ms = (time.time() - start_time) * 1000 - - embedding = response.data[0].embedding + # Truncate text to token limit using tiktoken + # Character-based truncation is insufficient because token ratios vary by content type + enc = encoding_for_model(EMBEDDING_MODEL) + tokens = enc.encode(text) + if len(tokens) > EMBEDDING_MAX_TOKENS: + tokens = tokens[:EMBEDDING_MAX_TOKENS] + truncated_text = enc.decode(tokens) logger.info( - f"Generated embedding: {len(embedding)} dims, " - f"{len(tokens)} tokens, {latency_ms:.0f}ms" + f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens" ) - return embedding + else: + truncated_text = text - except Exception as e: - logger.error(f"Failed to generate embedding: {e}") - return None + start_time = time.time() + response = await client.embeddings.create( + model=EMBEDDING_MODEL, + input=truncated_text, + ) + latency_ms = (time.time() - start_time) * 1000 + + embedding = response.data[0].embedding + logger.info( + f"Generated embedding: {len(embedding)} dims, " + f"{len(tokens)} tokens, {latency_ms:.0f}ms" + ) + return embedding async def store_embedding( @@ -144,48 +136,45 @@ async def store_content_embedding( New function for unified content embedding storage. Uses raw SQL since Prisma doesn't natively support pgvector. + + Raises exceptions on failure - caller should handle. """ - try: - client = tx if tx else prisma.get_client() + client = tx if tx else prisma.get_client() - # Convert embedding to PostgreSQL vector format - embedding_str = embedding_to_vector_string(embedding) - metadata_json = dumps(metadata or {}) + # Convert embedding to PostgreSQL vector format + embedding_str = embedding_to_vector_string(embedding) + metadata_json = dumps(metadata or {}) - # Upsert the embedding - # WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT - # Use unqualified ::vector - pgvector is in search_path on all environments - await execute_raw_with_schema( - """ - INSERT INTO {schema_prefix}"UnifiedContentEmbedding" ( - "id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt" - ) - VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW()) - ON CONFLICT ("contentType", "contentId", "userId") - DO UPDATE SET - "embedding" = $4::vector, - "searchableText" = $5, - "metadata" = $6::jsonb, - "updatedAt" = NOW() - WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType" - AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2 - AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL)) - """, - content_type, - content_id, - user_id, - embedding_str, - searchable_text, - metadata_json, - client=client, + # Upsert the embedding + # WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT + # Use unqualified ::vector - pgvector is in search_path on all environments + await execute_raw_with_schema( + """ + INSERT INTO {schema_prefix}"UnifiedContentEmbedding" ( + "id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt" ) + VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW()) + ON CONFLICT ("contentType", "contentId", "userId") + DO UPDATE SET + "embedding" = $4::vector, + "searchableText" = $5, + "metadata" = $6::jsonb, + "updatedAt" = NOW() + WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType" + AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2 + AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL)) + """, + content_type, + content_id, + user_id, + embedding_str, + searchable_text, + metadata_json, + client=client, + ) - logger.info(f"Stored embedding for {content_type}:{content_id}") - return True - - except Exception as e: - logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}") - return False + logger.info(f"Stored embedding for {content_type}:{content_id}") + return True async def get_embedding(version_id: str) -> dict[str, Any] | None: @@ -217,34 +206,31 @@ async def get_content_embedding( New function for unified content embedding retrieval. Returns dict with contentType, contentId, embedding, timestamps or None if not found. + + Raises exceptions on failure - caller should handle. """ - try: - result = await query_raw_with_schema( - """ - SELECT - "contentType", - "contentId", - "userId", - "embedding"::text as "embedding", - "searchableText", - "metadata", - "createdAt", - "updatedAt" - FROM {schema_prefix}"UnifiedContentEmbedding" - WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL)) - """, - content_type, - content_id, - user_id, - ) + result = await query_raw_with_schema( + """ + SELECT + "contentType", + "contentId", + "userId", + "embedding"::text as "embedding", + "searchableText", + "metadata", + "createdAt", + "updatedAt" + FROM {schema_prefix}"UnifiedContentEmbedding" + WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL)) + """, + content_type, + content_id, + user_id, + ) - if result and len(result) > 0: - return result[0] - return None - - except Exception as e: - logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}") - return None + if result and len(result) > 0: + return result[0] + return None async def ensure_embedding( @@ -272,46 +258,38 @@ async def ensure_embedding( tx: Optional transaction client Returns: - True if embedding exists/was created, False on failure + True if embedding exists/was created + + Raises exceptions on failure - caller should handle. """ - try: - # Check if embedding already exists - if not force: - existing = await get_embedding(version_id) - if existing and existing.get("embedding"): - logger.debug(f"Embedding for version {version_id} already exists") - return True + # Check if embedding already exists + if not force: + existing = await get_embedding(version_id) + if existing and existing.get("embedding"): + logger.debug(f"Embedding for version {version_id} already exists") + return True - # Build searchable text for embedding - searchable_text = build_searchable_text( - name, description, sub_heading, categories - ) + # Build searchable text for embedding + searchable_text = build_searchable_text(name, description, sub_heading, categories) - # Generate new embedding - embedding = await generate_embedding(searchable_text) - if embedding is None: - logger.warning(f"Could not generate embedding for version {version_id}") - return False + # Generate new embedding + embedding = await generate_embedding(searchable_text) - # Store the embedding with metadata using new function - metadata = { - "name": name, - "subHeading": sub_heading, - "categories": categories, - } - return await store_content_embedding( - content_type=ContentType.STORE_AGENT, - content_id=version_id, - embedding=embedding, - searchable_text=searchable_text, - metadata=metadata, - user_id=None, # Store agents are public - tx=tx, - ) - - except Exception as e: - logger.error(f"Failed to ensure embedding for version {version_id}: {e}") - return False + # Store the embedding with metadata using new function + metadata = { + "name": name, + "subHeading": sub_heading, + "categories": categories, + } + return await store_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id=version_id, + embedding=embedding, + searchable_text=searchable_text, + metadata=metadata, + user_id=None, # Store agents are public + tx=tx, + ) async def delete_embedding(version_id: str) -> bool: @@ -476,6 +454,7 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]: total_processed = 0 total_success = 0 total_failed = 0 + all_errors: dict[str, int] = {} # Aggregate errors across all content types # Process content types in explicit order processing_order = [ @@ -521,6 +500,13 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]: success = sum(1 for result in results if result is True) failed = len(results) - success + # Aggregate errors across all content types + if failed > 0: + for result in results: + if isinstance(result, Exception): + error_key = f"{type(result).__name__}: {str(result)}" + all_errors[error_key] = all_errors.get(error_key, 0) + 1 + results_by_type[content_type.value] = { "processed": len(missing_items), "success": success, @@ -546,6 +532,13 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]: "error": str(e), } + # Log aggregated errors once at the end + if all_errors: + error_details = ", ".join( + f"{error} ({count}x)" for error, count in all_errors.items() + ) + logger.error(f"Embedding backfill errors: {error_details}") + return { "by_type": results_by_type, "totals": { @@ -557,11 +550,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]: } -async def embed_query(query: str) -> list[float] | None: +async def embed_query(query: str) -> list[float]: """ Generate embedding for a search query. Same as generate_embedding but with clearer intent. + Raises exceptions on failure - caller should handle. """ return await generate_embedding(query) @@ -594,40 +588,30 @@ async def ensure_content_embedding( tx: Optional transaction client Returns: - True if embedding exists/was created, False on failure + True if embedding exists/was created + + Raises exceptions on failure - caller should handle. """ - try: - # Check if embedding already exists - if not force: - existing = await get_content_embedding(content_type, content_id, user_id) - if existing and existing.get("embedding"): - logger.debug( - f"Embedding for {content_type}:{content_id} already exists" - ) - return True + # Check if embedding already exists + if not force: + existing = await get_content_embedding(content_type, content_id, user_id) + if existing and existing.get("embedding"): + logger.debug(f"Embedding for {content_type}:{content_id} already exists") + return True - # Generate new embedding - embedding = await generate_embedding(searchable_text) - if embedding is None: - logger.warning( - f"Could not generate embedding for {content_type}:{content_id}" - ) - return False + # Generate new embedding + embedding = await generate_embedding(searchable_text) - # Store the embedding - return await store_content_embedding( - content_type=content_type, - content_id=content_id, - embedding=embedding, - searchable_text=searchable_text, - metadata=metadata or {}, - user_id=user_id, - tx=tx, - ) - - except Exception as e: - logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}") - return False + # Store the embedding + return await store_content_embedding( + content_type=content_type, + content_id=content_id, + embedding=embedding, + searchable_text=searchable_text, + metadata=metadata or {}, + user_id=user_id, + tx=tx, + ) async def cleanup_orphaned_embeddings() -> dict[str, Any]: @@ -854,9 +838,8 @@ async def semantic_search( limit = 100 # Generate query embedding - query_embedding = await embed_query(query) - - if query_embedding is not None: + try: + query_embedding = await embed_query(query) # Semantic search with embeddings embedding_str = embedding_to_vector_string(query_embedding) @@ -907,24 +890,21 @@ async def semantic_search( """ ) - try: - results = await query_raw_with_schema(sql, *params) - return [ - { - "content_id": row["content_id"], - "content_type": row["content_type"], - "searchable_text": row["searchable_text"], - "metadata": row["metadata"], - "similarity": float(row["similarity"]), - } - for row in results - ] - except Exception as e: - logger.error(f"Semantic search failed: {e}") - # Fall through to lexical search below + results = await query_raw_with_schema(sql, *params) + return [ + { + "content_id": row["content_id"], + "content_type": row["content_type"], + "searchable_text": row["searchable_text"], + "metadata": row["metadata"], + "similarity": float(row["similarity"]), + } + for row in results + ] + except Exception as e: + logger.warning(f"Semantic search failed, falling back to lexical search: {e}") # Fallback to lexical search if embeddings unavailable - logger.warning("Falling back to lexical search (embeddings unavailable)") params_lexical: list[Any] = [limit] user_filter = "" diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py index 7ba200fda0..5aa13b4d23 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_schema_test.py @@ -298,17 +298,16 @@ async def test_schema_handling_error_cases(): mock_client.execute_raw.side_effect = Exception("Database error") mock_get_client.return_value = mock_client - result = await embeddings.store_content_embedding( - content_type=ContentType.STORE_AGENT, - content_id="test-id", - embedding=[0.1] * EMBEDDING_DIM, - searchable_text="test", - metadata=None, - user_id=None, - ) - - # Should return False on error, not raise - assert result is False + # Should raise exception on error + with pytest.raises(Exception, match="Database error"): + await embeddings.store_content_embedding( + content_type=ContentType.STORE_AGENT, + content_id="test-id", + embedding=[0.1] * EMBEDDING_DIM, + searchable_text="test", + metadata=None, + user_id=None, + ) if __name__ == "__main__": diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_test.py index 8cb471379b..0d5e5ce4a2 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings_test.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_test.py @@ -80,9 +80,8 @@ async def test_generate_embedding_no_api_key(): ) as mock_get_client: mock_get_client.return_value = None - result = await embeddings.generate_embedding("test text") - - assert result is None + with pytest.raises(RuntimeError, match="openai_internal_api_key not set"): + await embeddings.generate_embedding("test text") @pytest.mark.asyncio(loop_scope="session") @@ -97,9 +96,8 @@ async def test_generate_embedding_api_error(): ) as mock_get_client: mock_get_client.return_value = mock_client - result = await embeddings.generate_embedding("test text") - - assert result is None + with pytest.raises(Exception, match="API Error"): + await embeddings.generate_embedding("test text") @pytest.mark.asyncio(loop_scope="session") @@ -173,11 +171,10 @@ async def test_store_embedding_database_error(mocker): embedding = [0.1, 0.2, 0.3] - result = await embeddings.store_embedding( - version_id="test-version-id", embedding=embedding, tx=mock_client - ) - - assert result is False + with pytest.raises(Exception, match="Database error"): + await embeddings.store_embedding( + version_id="test-version-id", embedding=embedding, tx=mock_client + ) @pytest.mark.asyncio(loop_scope="session") @@ -277,17 +274,16 @@ async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate): async def test_ensure_embedding_generation_fails(mock_get, mock_generate): """Test ensure_embedding when generation fails.""" mock_get.return_value = None - mock_generate.return_value = None + mock_generate.side_effect = Exception("Generation failed") - result = await embeddings.ensure_embedding( - version_id="test-id", - name="Test", - description="Test description", - sub_heading="Test heading", - categories=["test"], - ) - - assert result is False + with pytest.raises(Exception, match="Generation failed"): + await embeddings.ensure_embedding( + version_id="test-id", + name="Test", + description="Test description", + sub_heading="Test heading", + categories=["test"], + ) @pytest.mark.asyncio(loop_scope="session") diff --git a/autogpt_platform/backend/backend/api/features/store/hybrid_search.py b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py index 95ec3f4ff9..8b0884bb24 100644 --- a/autogpt_platform/backend/backend/api/features/store/hybrid_search.py +++ b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py @@ -186,13 +186,12 @@ async def unified_hybrid_search( offset = (page - 1) * page_size - # Generate query embedding - query_embedding = await embed_query(query) - - # Graceful degradation if embedding unavailable - if query_embedding is None or not query_embedding: + # Generate query embedding with graceful degradation + try: + query_embedding = await embed_query(query) + except Exception as e: logger.warning( - "Failed to generate query embedding - falling back to lexical-only search. " + f"Failed to generate query embedding - falling back to lexical-only search: {e}. " "Check that openai_internal_api_key is configured and OpenAI API is accessible." ) query_embedding = [0.0] * EMBEDDING_DIM @@ -464,13 +463,12 @@ async def hybrid_search( offset = (page - 1) * page_size - # Generate query embedding - query_embedding = await embed_query(query) - - # Graceful degradation - if query_embedding is None or not query_embedding: + # Generate query embedding with graceful degradation + try: + query_embedding = await embed_query(query) + except Exception as e: logger.warning( - "Failed to generate query embedding - falling back to lexical-only search." + f"Failed to generate query embedding - falling back to lexical-only search: {e}" ) query_embedding = [0.0] * EMBEDDING_DIM total_non_semantic = ( diff --git a/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py b/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py index 7f942927a5..58989fbb41 100644 --- a/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py +++ b/autogpt_platform/backend/backend/api/features/store/hybrid_search_test.py @@ -172,8 +172,8 @@ async def test_hybrid_search_without_embeddings(): with patch( "backend.api.features.store.hybrid_search.query_raw_with_schema" ) as mock_query: - # Simulate embedding failure - mock_embed.return_value = None + # Simulate embedding failure by raising exception + mock_embed.side_effect = Exception("Embedding generation failed") mock_query.return_value = mock_results # Should NOT raise - graceful degradation @@ -613,7 +613,9 @@ async def test_unified_hybrid_search_graceful_degradation(): "backend.api.features.store.hybrid_search.embed_query" ) as mock_embed: mock_query.return_value = mock_results - mock_embed.return_value = None # Embedding failure + mock_embed.side_effect = Exception( + "Embedding generation failed" + ) # Embedding failure # Should NOT raise - graceful degradation results, total = await unified_hybrid_search( diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index 3a5dd3ec12..09d3759a65 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -261,14 +261,36 @@ async def get_onboarding_agents( return await get_recommended_agents(user_id) +class OnboardingStatusResponse(pydantic.BaseModel): + """Response for onboarding status check.""" + + is_onboarding_enabled: bool + is_chat_enabled: bool + + @v1_router.get( "/onboarding/enabled", summary="Is onboarding enabled", tags=["onboarding", "public"], - dependencies=[Security(requires_user)], + response_model=OnboardingStatusResponse, ) -async def is_onboarding_enabled() -> bool: - return await onboarding_enabled() +async def is_onboarding_enabled( + user_id: Annotated[str, Security(get_user_id)], +) -> OnboardingStatusResponse: + # Check if chat is enabled for user + is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False) + + # If chat is enabled, skip legacy onboarding + if is_chat_enabled: + return OnboardingStatusResponse( + is_onboarding_enabled=False, + is_chat_enabled=True, + ) + + return OnboardingStatusResponse( + is_onboarding_enabled=await onboarding_enabled(), + is_chat_enabled=False, + ) @v1_router.post( @@ -364,6 +386,8 @@ async def execute_graph_block( obj = get_block(block_id) if not obj: raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.") + if obj.disabled: + raise HTTPException(status_code=403, detail=f"Block #{block_id} is disabled.") user = await get_user_by_id(user_id) if not user: diff --git a/autogpt_platform/backend/backend/api/features/v1_test.py b/autogpt_platform/backend/backend/api/features/v1_test.py index a186d38810..d57ad49949 100644 --- a/autogpt_platform/backend/backend/api/features/v1_test.py +++ b/autogpt_platform/backend/backend/api/features/v1_test.py @@ -138,6 +138,7 @@ def test_execute_graph_block( """Test execute block endpoint""" # Mock block mock_block = Mock() + mock_block.disabled = False async def mock_execute(*args, **kwargs): yield "output1", {"data": "result1"} diff --git a/autogpt_platform/backend/backend/api/features/workspace/__init__.py b/autogpt_platform/backend/backend/api/features/workspace/__init__.py new file mode 100644 index 0000000000..688ada9937 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/workspace/__init__.py @@ -0,0 +1 @@ +# Workspace API feature module diff --git a/autogpt_platform/backend/backend/api/features/workspace/routes.py b/autogpt_platform/backend/backend/api/features/workspace/routes.py new file mode 100644 index 0000000000..b6d0c84572 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/workspace/routes.py @@ -0,0 +1,122 @@ +""" +Workspace API routes for managing user file storage. +""" + +import logging +import re +from typing import Annotated +from urllib.parse import quote + +import fastapi +from autogpt_libs.auth.dependencies import get_user_id, requires_user +from fastapi.responses import Response + +from backend.data.workspace import get_workspace, get_workspace_file +from backend.util.workspace_storage import get_workspace_storage + + +def _sanitize_filename_for_header(filename: str) -> str: + """ + Sanitize filename for Content-Disposition header to prevent header injection. + + Removes/replaces characters that could break the header or inject new headers. + Uses RFC5987 encoding for non-ASCII characters. + """ + # Remove CR, LF, and null bytes (header injection prevention) + sanitized = re.sub(r"[\r\n\x00]", "", filename) + # Escape quotes + sanitized = sanitized.replace('"', '\\"') + # For non-ASCII, use RFC5987 filename* parameter + # Check if filename has non-ASCII characters + try: + sanitized.encode("ascii") + return f'attachment; filename="{sanitized}"' + except UnicodeEncodeError: + # Use RFC5987 encoding for UTF-8 filenames + encoded = quote(sanitized, safe="") + return f"attachment; filename*=UTF-8''{encoded}" + + +logger = logging.getLogger(__name__) + +router = fastapi.APIRouter( + dependencies=[fastapi.Security(requires_user)], +) + + +def _create_streaming_response(content: bytes, file) -> Response: + """Create a streaming response for file content.""" + return Response( + content=content, + media_type=file.mimeType, + headers={ + "Content-Disposition": _sanitize_filename_for_header(file.name), + "Content-Length": str(len(content)), + }, + ) + + +async def _create_file_download_response(file) -> Response: + """ + Create a download response for a workspace file. + + Handles both local storage (direct streaming) and GCS (signed URL redirect + with fallback to streaming). + """ + storage = await get_workspace_storage() + + # For local storage, stream the file directly + if file.storagePath.startswith("local://"): + content = await storage.retrieve(file.storagePath) + return _create_streaming_response(content, file) + + # For GCS, try to redirect to signed URL, fall back to streaming + try: + url = await storage.get_download_url(file.storagePath, expires_in=300) + # If we got back an API path (fallback), stream directly instead + if url.startswith("/api/"): + content = await storage.retrieve(file.storagePath) + return _create_streaming_response(content, file) + return fastapi.responses.RedirectResponse(url=url, status_code=302) + except Exception as e: + # Log the signed URL failure with context + logger.error( + f"Failed to get signed URL for file {file.id} " + f"(storagePath={file.storagePath}): {e}", + exc_info=True, + ) + # Fall back to streaming directly from GCS + try: + content = await storage.retrieve(file.storagePath) + return _create_streaming_response(content, file) + except Exception as fallback_error: + logger.error( + f"Fallback streaming also failed for file {file.id} " + f"(storagePath={file.storagePath}): {fallback_error}", + exc_info=True, + ) + raise + + +@router.get( + "/files/{file_id}/download", + summary="Download file by ID", +) +async def download_file( + user_id: Annotated[str, fastapi.Security(get_user_id)], + file_id: str, +) -> Response: + """ + Download a file by its ID. + + Returns the file content directly or redirects to a signed URL for GCS. + """ + workspace = await get_workspace(user_id) + if workspace is None: + raise fastapi.HTTPException(status_code=404, detail="Workspace not found") + + file = await get_workspace_file(file_id, workspace.id) + if file is None: + raise fastapi.HTTPException(status_code=404, detail="File not found") + + return await _create_file_download_response(file) diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index e9556e992f..b936312ce1 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -32,6 +32,7 @@ import backend.api.features.postmark.postmark import backend.api.features.store.model import backend.api.features.store.routes import backend.api.features.v1 +import backend.api.features.workspace.routes as workspace_routes import backend.data.block import backend.data.db import backend.data.graph @@ -52,6 +53,7 @@ from backend.util.exceptions import ( ) from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly from backend.util.service import UnhealthyServiceError +from backend.util.workspace_storage import shutdown_workspace_storage from .external.fastapi_app import external_api from .features.analytics import router as analytics_router @@ -124,6 +126,11 @@ async def lifespan_context(app: fastapi.FastAPI): except Exception as e: logger.warning(f"Error shutting down cloud storage handler: {e}") + try: + await shutdown_workspace_storage() + except Exception as e: + logger.warning(f"Error shutting down workspace storage: {e}") + await backend.data.db.disconnect() @@ -315,6 +322,11 @@ app.include_router( tags=["v2", "chat"], prefix="/api/chat", ) +app.include_router( + workspace_routes.router, + tags=["workspace"], + prefix="/api/workspace", +) app.include_router( backend.api.features.oauth.router, tags=["oauth"], diff --git a/autogpt_platform/backend/backend/blocks/ai_image_customizer.py b/autogpt_platform/backend/backend/blocks/ai_image_customizer.py index 83178e924d..91be33a60e 100644 --- a/autogpt_platform/backend/backend/blocks/ai_image_customizer.py +++ b/autogpt_platform/backend/backend/blocks/ai_image_customizer.py @@ -13,6 +13,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block): "credentials": TEST_CREDENTIALS_INPUT, }, test_output=[ - ("image_url", "https://replicate.delivery/generated-image.jpg"), + # Output will be a workspace ref or data URI depending on context + ("image_url", lambda x: x.startswith(("workspace://", "data:"))), ], test_mock={ + # Use data URI to avoid HTTP requests during tests "run_model": lambda *args, **kwargs: MediaFileType( - "https://replicate.delivery/generated-image.jpg" + "data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q==" ), }, test_credentials=TEST_CREDENTIALS, @@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block): input_data: Input, *, credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: try: @@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block): processed_images = await asyncio.gather( *( store_media_file( - graph_exec_id=graph_exec_id, file=img, - user_id=user_id, - return_content=True, + execution_context=execution_context, + return_format="for_external_api", # Get content for Replicate API ) for img in input_data.images ) @@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block): aspect_ratio=input_data.aspect_ratio.value, output_format=input_data.output_format.value, ) - yield "image_url", result + + # Store the generated image to the user's workspace for persistence + stored_url = await store_media_file( + file=result, + execution_context=execution_context, + return_format="for_block_output", + ) + yield "image_url", stored_url except Exception as e: yield "error", str(e) diff --git a/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py b/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py index 8c7b6e6102..e40731cd97 100644 --- a/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py +++ b/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py @@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient from replicate.helpers import FileOutput from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -13,6 +14,8 @@ from backend.data.model import ( SchemaField, ) from backend.integrations.providers import ProviderName +from backend.util.file import store_media_file +from backend.util.type import MediaFileType class ImageSize(str, Enum): @@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block): test_output=[ ( "image_url", - "https://replicate.delivery/generated-image.webp", + # Test output is a data URI since we now store images + lambda x: x.startswith("data:image/"), ), ], test_mock={ - "_run_client": lambda *args, **kwargs: "https://replicate.delivery/generated-image.webp" + # Return a data URI directly so store_media_file doesn't need to download + "_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO" }, ) @@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block): style_text = style_map.get(style, "") return f"{style_text} of" if style_text else "" - async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs): + async def run( + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, + ): try: url = await self.generate_image(input_data, credentials) if url: - yield "image_url", url + # Store the generated image to the user's workspace/execution folder + stored_url = await store_media_file( + file=MediaFileType(url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "image_url", stored_url else: yield "error", "Image generation returned an empty result." except Exception as e: diff --git a/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py b/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py index a9e96890d3..eb60843185 100644 --- a/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py +++ b/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py @@ -13,6 +13,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -21,7 +22,9 @@ from backend.data.model import ( ) from backend.integrations.providers import ProviderName from backend.util.exceptions import BlockExecutionError +from backend.util.file import store_media_file from backend.util.request import Requests +from backend.util.type import MediaFileType TEST_CREDENTIALS = APIKeyCredentials( id="01234567-89ab-cdef-0123-456789abcdef", @@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block): "voice": Voice.LILY, "video_style": VisualMediaType.STOCK_VIDEOS, }, - test_output=("video_url", "https://example.com/video.mp4"), + test_output=( + "video_url", + lambda x: x.startswith(("workspace://", "data:")), + ), test_mock={ "create_webhook": lambda *args, **kwargs: ( "test_uuid", @@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block): "create_video": lambda *args, **kwargs: {"pid": "test_pid"}, "check_video_status": lambda *args, **kwargs: { "status": "ready", - "videoUrl": "https://example.com/video.mp4", + "videoUrl": "data:video/mp4;base64,AAAA", }, - "wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4", + # Use data URI to avoid HTTP requests during tests + "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA", }, test_credentials=TEST_CREDENTIALS, ) async def run( - self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, ) -> BlockOutput: # Create a new Webhook.site URL webhook_token, webhook_url = await self.create_webhook() @@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block): ) video_url = await self.wait_for_video(credentials.api_key, pid) logger.debug(f"Video ready: {video_url}") - yield "video_url", video_url + # Store the generated video to the user's workspace for persistence + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url class AIAdMakerVideoCreatorBlock(Block): @@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block): "https://cdn.revid.ai/uploads/1747076315114-image.png", ], }, - test_output=("video_url", "https://example.com/ad.mp4"), + test_output=( + "video_url", + lambda x: x.startswith(("workspace://", "data:")), + ), test_mock={ "create_webhook": lambda *args, **kwargs: ( "test_uuid", @@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block): "create_video": lambda *args, **kwargs: {"pid": "test_pid"}, "check_video_status": lambda *args, **kwargs: { "status": "ready", - "videoUrl": "https://example.com/ad.mp4", + "videoUrl": "data:video/mp4;base64,AAAA", }, - "wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4", + "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA", }, test_credentials=TEST_CREDENTIALS, ) - async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs): + async def run( + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, + ): webhook_token, webhook_url = await self.create_webhook() payload = { @@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block): raise RuntimeError("Failed to create video: No project ID returned") video_url = await self.wait_for_video(credentials.api_key, pid) - yield "video_url", video_url + # Store the generated video to the user's workspace for persistence + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url class AIScreenshotToVideoAdBlock(Block): @@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block): "script": "Amazing numbers!", "screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png", }, - test_output=("video_url", "https://example.com/screenshot.mp4"), + test_output=( + "video_url", + lambda x: x.startswith(("workspace://", "data:")), + ), test_mock={ "create_webhook": lambda *args, **kwargs: ( "test_uuid", @@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block): "create_video": lambda *args, **kwargs: {"pid": "test_pid"}, "check_video_status": lambda *args, **kwargs: { "status": "ready", - "videoUrl": "https://example.com/screenshot.mp4", + "videoUrl": "data:video/mp4;base64,AAAA", }, - "wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4", + "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA", }, test_credentials=TEST_CREDENTIALS, ) - async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs): + async def run( + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, + ): webhook_token, webhook_url = await self.create_webhook() payload = { @@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block): raise RuntimeError("Failed to create video: No project ID returned") video_url = await self.wait_for_video(credentials.api_key, pid) - yield "video_url", video_url + # Store the generated video to the user's workspace for persistence + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url diff --git a/autogpt_platform/backend/backend/blocks/bannerbear/text_overlay.py b/autogpt_platform/backend/backend/blocks/bannerbear/text_overlay.py index 16d46c0d99..62aaf63d88 100644 --- a/autogpt_platform/backend/backend/blocks/bannerbear/text_overlay.py +++ b/autogpt_platform/backend/backend/blocks/bannerbear/text_overlay.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from pydantic import SecretStr +from backend.data.execution import ExecutionContext from backend.sdk import ( APIKeyCredentials, Block, @@ -17,6 +18,8 @@ from backend.sdk import ( Requests, SchemaField, ) +from backend.util.file import store_media_file +from backend.util.type import MediaFileType from ._config import bannerbear @@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block): }, test_output=[ ("success", True), - ("image_url", "https://cdn.bannerbear.com/test-image.jpg"), + # Output will be a workspace ref or data URI depending on context + ("image_url", lambda x: x.startswith(("workspace://", "data:"))), ("uid", "test-uid-123"), ("status", "completed"), ], test_mock={ + # Use data URI to avoid HTTP requests during tests "_make_api_request": lambda *args, **kwargs: { "uid": "test-uid-123", "status": "completed", - "image_url": "https://cdn.bannerbear.com/test-image.jpg", + "image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAABAAEBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+v//Z", } }, test_credentials=TEST_CREDENTIALS, @@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block): raise Exception(error_msg) async def run( - self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, ) -> BlockOutput: # Build the modifications array modifications = [] @@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block): # Synchronous request - image should be ready yield "success", True - yield "image_url", data.get("image_url", "") + + # Store the generated image to workspace for persistence + image_url = data.get("image_url", "") + if image_url: + stored_url = await store_media_file( + file=MediaFileType(image_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "image_url", stored_url + else: + yield "image_url", "" + yield "uid", data.get("uid", "") yield "status", data.get("status", "completed") diff --git a/autogpt_platform/backend/backend/blocks/basic.py b/autogpt_platform/backend/backend/blocks/basic.py index 4d452f3b34..95193b3feb 100644 --- a/autogpt_platform/backend/backend/blocks/basic.py +++ b/autogpt_platform/backend/backend/blocks/basic.py @@ -9,6 +9,7 @@ from backend.data.block import ( BlockSchemaOutput, BlockType, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.file import store_media_file from backend.util.type import MediaFileType, convert @@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert class FileStoreBlock(Block): class Input(BlockSchemaInput): file_in: MediaFileType = SchemaField( - description="The file to store in the temporary directory, it can be a URL, data URI, or local path." + description="The file to download and store. Can be a URL (https://...), data URI, or local path." ) base_64: bool = SchemaField( - description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).", + description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).", default=False, advanced=True, title="Produce Base64 Output", @@ -28,13 +29,18 @@ class FileStoreBlock(Block): class Output(BlockSchemaOutput): file_out: MediaFileType = SchemaField( - description="The relative path to the stored file in the temporary directory." + description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks." ) def __init__(self): super().__init__( id="cbb50872-625b-42f0-8203-a2ae78242d8a", - description="Stores the input file in the temporary directory.", + description=( + "Downloads and stores a file from a URL, data URI, or local path. " + "Use this to fetch images, documents, or other files for processing. " + "In CoPilot: saves to workspace (use list_workspace_files to see it). " + "In graphs: outputs a data URI to pass to other blocks." + ), categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA}, input_schema=FileStoreBlock.Input, output_schema=FileStoreBlock.Output, @@ -45,15 +51,18 @@ class FileStoreBlock(Block): self, input_data: Input, *, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: + # Determine return format based on user preference + # for_external_api: always returns data URI (base64) - honors "Produce Base64 Output" + # for_block_output: smart format - workspace:// in CoPilot, data URI in graphs + return_format = "for_external_api" if input_data.base_64 else "for_block_output" + yield "file_out", await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.file_in, - user_id=user_id, - return_content=input_data.base_64, + execution_context=execution_context, + return_format=return_format, ) @@ -116,6 +125,7 @@ class PrintToConsoleBlock(Block): input_schema=PrintToConsoleBlock.Input, output_schema=PrintToConsoleBlock.Output, test_input={"text": "Hello, World!"}, + is_sensitive_action=True, test_output=[ ("output", "Hello, World!"), ("status", "printed"), diff --git a/autogpt_platform/backend/backend/blocks/discord/bot_blocks.py b/autogpt_platform/backend/backend/blocks/discord/bot_blocks.py index 5ecd730f47..4438af1955 100644 --- a/autogpt_platform/backend/backend/blocks/discord/bot_blocks.py +++ b/autogpt_platform/backend/backend/blocks/discord/bot_blocks.py @@ -15,6 +15,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import APIKeyCredentials, SchemaField from backend.util.file import store_media_file from backend.util.request import Requests @@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block): file: MediaFileType, filename: str, message_content: str, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, ) -> dict: intents = discord.Intents.default() intents.guilds = True @@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block): # Local file path - read from stored media file # This would be a path from a previous block's output stored_file = await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=True, # Get as data URI + execution_context=execution_context, + return_format="for_external_api", # Get content to send to Discord ) # Now process as data URI header, encoded = stored_file.split(",", 1) @@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block): input_data: Input, *, credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: try: @@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block): file=input_data.file, filename=input_data.filename, message_content=input_data.message_content, - graph_exec_id=graph_exec_id, - user_id=user_id, + execution_context=execution_context, ) yield "status", result.get("status", "Unknown error") diff --git a/autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py b/autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py index 2a71548dcc..c2079ef159 100644 --- a/autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py +++ b/autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py @@ -17,8 +17,11 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField +from backend.util.file import store_media_file from backend.util.request import ClientResponseError, Requests +from backend.util.type import MediaFileType logger = logging.getLogger(__name__) @@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block): "credentials": TEST_CREDENTIALS_INPUT, }, test_credentials=TEST_CREDENTIALS, - test_output=[("video_url", "https://fal.media/files/example/video.mp4")], + test_output=[ + # Output will be a workspace ref or data URI depending on context + ("video_url", lambda x: x.startswith(("workspace://", "data:"))), + ], test_mock={ - "generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4" + # Use data URI to avoid HTTP requests during tests + "generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA" }, ) @@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block): raise RuntimeError(f"API request failed: {str(e)}") async def run( - self, input_data: Input, *, credentials: FalCredentials, **kwargs + self, + input_data: Input, + *, + credentials: FalCredentials, + execution_context: ExecutionContext, + **kwargs, ) -> BlockOutput: try: video_url = await self.generate_video(input_data, credentials) - yield "video_url", video_url + # Store the generated video to the user's workspace for persistence + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url except Exception as e: error_message = str(e) yield "error", error_message diff --git a/autogpt_platform/backend/backend/blocks/flux_kontext.py b/autogpt_platform/backend/backend/blocks/flux_kontext.py index dd8375c4ce..d56baa6d92 100644 --- a/autogpt_platform/backend/backend/blocks/flux_kontext.py +++ b/autogpt_platform/backend/backend/blocks/flux_kontext.py @@ -12,6 +12,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -121,10 +122,12 @@ class AIImageEditorBlock(Block): "credentials": TEST_CREDENTIALS_INPUT, }, test_output=[ - ("output_image", "https://replicate.com/output/edited-image.png"), + # Output will be a workspace ref or data URI depending on context + ("output_image", lambda x: x.startswith(("workspace://", "data:"))), ], test_mock={ - "run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png", + # Use data URI to avoid HTTP requests during tests + "run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==", }, test_credentials=TEST_CREDENTIALS, ) @@ -134,8 +137,7 @@ class AIImageEditorBlock(Block): input_data: Input, *, credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: result = await self.run_model( @@ -144,20 +146,25 @@ class AIImageEditorBlock(Block): prompt=input_data.prompt, input_image_b64=( await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.input_image, - user_id=user_id, - return_content=True, + execution_context=execution_context, + return_format="for_external_api", # Get content for Replicate API ) if input_data.input_image else None ), aspect_ratio=input_data.aspect_ratio.value, seed=input_data.seed, - user_id=user_id, - graph_exec_id=graph_exec_id, + user_id=execution_context.user_id or "", + graph_exec_id=execution_context.graph_exec_id or "", ) - yield "output_image", result + # Store the generated image to the user's workspace for persistence + stored_url = await store_media_file( + file=result, + execution_context=execution_context, + return_format="for_block_output", + ) + yield "output_image", stored_url async def run_model( self, diff --git a/autogpt_platform/backend/backend/blocks/google/gmail.py b/autogpt_platform/backend/backend/blocks/google/gmail.py index d1b3ecd4bf..2040cabe3f 100644 --- a/autogpt_platform/backend/backend/blocks/google/gmail.py +++ b/autogpt_platform/backend/backend/blocks/google/gmail.py @@ -21,6 +21,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.file import MediaFileType, get_exec_file_path, store_media_file from backend.util.settings import Settings @@ -95,8 +96,7 @@ def _make_mime_text( async def create_mime_message( input_data, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, ) -> str: """Create a MIME message with attachments and return base64-encoded raw message.""" @@ -117,12 +117,12 @@ async def create_mime_message( if input_data.attachments: for attach in input_data.attachments: local_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=attach, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) - abs_path = get_exec_file_path(graph_exec_id, local_path) + assert execution_context.graph_exec_id # Validated by store_media_file + abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path) part = MIMEBase("application", "octet-stream") with open(abs_path, "rb") as f: part.set_payload(f.read()) @@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) result = await self._send_email( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "result", result async def _send_email( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: if not input_data.to or not input_data.subject or not input_data.body: raise ValueError( "At least one recipient, subject, and body are required for sending an email" ) - raw_message = await create_mime_message(input_data, graph_exec_id, user_id) + raw_message = await create_mime_message(input_data, execution_context) sent_message = await asyncio.to_thread( lambda: service.users() .messages() @@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) result = await self._create_draft( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "result", GmailDraftResult( id=result["id"], message_id=result["message"]["id"], status="draft_created" ) async def _create_draft( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: if not input_data.to or not input_data.subject: raise ValueError( "At least one recipient and subject are required for creating a draft" ) - raw_message = await create_mime_message(input_data, graph_exec_id, user_id) + raw_message = await create_mime_message(input_data, execution_context) draft = await asyncio.to_thread( lambda: service.users() .drafts() @@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase): async def _build_reply_message( - service, input_data, graph_exec_id: str, user_id: str + service, input_data, execution_context: ExecutionContext ) -> tuple[str, str]: """ Builds a reply MIME message for Gmail threads. @@ -1190,12 +1186,12 @@ async def _build_reply_message( # Handle attachments for attach in input_data.attachments: local_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=attach, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) - abs_path = get_exec_file_path(graph_exec_id, local_path) + assert execution_context.graph_exec_id # Validated by store_media_file + abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path) part = MIMEBase("application", "octet-stream") with open(abs_path, "rb") as f: part.set_payload(f.read()) @@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) message = await self._reply( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "messageId", message["id"] yield "threadId", message.get("threadId", input_data.threadId) @@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase): yield "email", email async def _reply( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: # Build the reply message using the shared helper raw, thread_id = await _build_reply_message( - service, input_data, graph_exec_id, user_id + service, input_data, execution_context ) # Send the message @@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) draft = await self._create_draft_reply( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "draftId", draft["id"] yield "messageId", draft["message"]["id"] @@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase): yield "status", "draft_created" async def _create_draft_reply( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: # Build the reply message using the shared helper raw, thread_id = await _build_reply_message( - service, input_data, graph_exec_id, user_id + service, input_data, execution_context ) # Create draft with proper thread association @@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) result = await self._forward_message( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "messageId", result["id"] yield "threadId", result.get("threadId", "") yield "status", "forwarded" async def _forward_message( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: if not input_data.to: raise ValueError("At least one recipient is required for forwarding") @@ -1727,12 +1717,12 @@ To: {original_to} # Add any additional attachments for attach in input_data.additionalAttachments: local_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=attach, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) - abs_path = get_exec_file_path(graph_exec_id, local_path) + assert execution_context.graph_exec_id # Validated by store_media_file + abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path) part = MIMEBase("application", "octet-stream") with open(abs_path, "rb") as f: part.set_payload(f.read()) diff --git a/autogpt_platform/backend/backend/blocks/helpers/review.py b/autogpt_platform/backend/backend/blocks/helpers/review.py index 80c28cfd14..4bd85e424b 100644 --- a/autogpt_platform/backend/backend/blocks/helpers/review.py +++ b/autogpt_platform/backend/backend/blocks/helpers/review.py @@ -9,7 +9,7 @@ from typing import Any, Optional from prisma.enums import ReviewStatus from pydantic import BaseModel -from backend.data.execution import ExecutionContext, ExecutionStatus +from backend.data.execution import ExecutionStatus from backend.data.human_review import ReviewResult from backend.executor.manager import async_update_node_execution_status from backend.util.clients import get_database_manager_async_client @@ -28,6 +28,11 @@ class ReviewDecision(BaseModel): class HITLReviewHelper: """Helper class for Human-In-The-Loop review operations.""" + @staticmethod + async def check_approval(**kwargs) -> Optional[ReviewResult]: + """Check if there's an existing approval for this node execution.""" + return await get_database_manager_async_client().check_approval(**kwargs) + @staticmethod async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]: """Create or retrieve a human review from the database.""" @@ -55,11 +60,11 @@ class HITLReviewHelper: async def _handle_review_request( input_data: Any, user_id: str, + node_id: str, node_exec_id: str, graph_exec_id: str, graph_id: str, graph_version: int, - execution_context: ExecutionContext, block_name: str = "Block", editable: bool = False, ) -> Optional[ReviewResult]: @@ -69,11 +74,11 @@ class HITLReviewHelper: Args: input_data: The input data to be reviewed user_id: ID of the user requesting the review + node_id: ID of the node in the graph definition node_exec_id: ID of the node execution graph_exec_id: ID of the graph execution graph_id: ID of the graph graph_version: Version of the graph - execution_context: Current execution context block_name: Name of the block requesting review editable: Whether the reviewer can edit the data @@ -83,15 +88,41 @@ class HITLReviewHelper: Raises: Exception: If review creation or status update fails """ - # Skip review if safe mode is disabled - return auto-approved result - if not execution_context.human_in_the_loop_safe_mode: + # Note: Safe mode checks (human_in_the_loop_safe_mode, sensitive_action_safe_mode) + # are handled by the caller: + # - HITL blocks check human_in_the_loop_safe_mode in their run() method + # - Sensitive action blocks check sensitive_action_safe_mode in is_block_exec_need_review() + # This function only handles checking for existing approvals. + + # Check if this node has already been approved (normal or auto-approval) + if approval_result := await HITLReviewHelper.check_approval( + node_exec_id=node_exec_id, + graph_exec_id=graph_exec_id, + node_id=node_id, + user_id=user_id, + input_data=input_data, + ): logger.info( - f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled" + f"Block {block_name} skipping review for node {node_exec_id} - " + f"found existing approval" + ) + # Return a new ReviewResult with the current node_exec_id but approved status + # For auto-approvals, always use current input_data + # For normal approvals, use approval_result.data unless it's None + is_auto_approval = approval_result.node_exec_id != node_exec_id + approved_data = ( + input_data + if is_auto_approval + else ( + approval_result.data + if approval_result.data is not None + else input_data + ) ) return ReviewResult( - data=input_data, + data=approved_data, status=ReviewStatus.APPROVED, - message="Auto-approved (safe mode disabled)", + message=approval_result.message, processed=True, node_exec_id=node_exec_id, ) @@ -103,7 +134,7 @@ class HITLReviewHelper: graph_id=graph_id, graph_version=graph_version, input_data=input_data, - message=f"Review required for {block_name} execution", + message=block_name, # Use block_name directly as the message editable=editable, ) @@ -129,11 +160,11 @@ class HITLReviewHelper: async def handle_review_decision( input_data: Any, user_id: str, + node_id: str, node_exec_id: str, graph_exec_id: str, graph_id: str, graph_version: int, - execution_context: ExecutionContext, block_name: str = "Block", editable: bool = False, ) -> Optional[ReviewDecision]: @@ -143,11 +174,11 @@ class HITLReviewHelper: Args: input_data: The input data to be reviewed user_id: ID of the user requesting the review + node_id: ID of the node in the graph definition node_exec_id: ID of the node execution graph_exec_id: ID of the graph execution graph_id: ID of the graph graph_version: Version of the graph - execution_context: Current execution context block_name: Name of the block requesting review editable: Whether the reviewer can edit the data @@ -158,11 +189,11 @@ class HITLReviewHelper: review_result = await HITLReviewHelper._handle_review_request( input_data=input_data, user_id=user_id, + node_id=node_id, node_exec_id=node_exec_id, graph_exec_id=graph_exec_id, graph_id=graph_id, graph_version=graph_version, - execution_context=execution_context, block_name=block_name, editable=editable, ) diff --git a/autogpt_platform/backend/backend/blocks/http.py b/autogpt_platform/backend/backend/blocks/http.py index 9b27a3b129..77e7fe243f 100644 --- a/autogpt_platform/backend/backend/blocks/http.py +++ b/autogpt_platform/backend/backend/blocks/http.py @@ -15,6 +15,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( CredentialsField, CredentialsMetaInput, @@ -116,10 +117,9 @@ class SendWebRequestBlock(Block): @staticmethod async def _prepare_files( - graph_exec_id: str, + execution_context: ExecutionContext, files_name: str, files: list[MediaFileType], - user_id: str, ) -> list[tuple[str, tuple[str, BytesIO, str]]]: """ Prepare files for the request by storing them and reading their content. @@ -127,11 +127,16 @@ class SendWebRequestBlock(Block): (files_name, (filename, BytesIO, mime_type)) """ files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = [] + graph_exec_id = execution_context.graph_exec_id + if graph_exec_id is None: + raise ValueError("graph_exec_id is required for file operations") for media in files: # Normalise to a list so we can repeat the same key rel_path = await store_media_file( - graph_exec_id, media, user_id, return_content=False + file=media, + execution_context=execution_context, + return_format="for_local_processing", ) abs_path = get_exec_file_path(graph_exec_id, rel_path) async with aiofiles.open(abs_path, "rb") as f: @@ -143,7 +148,7 @@ class SendWebRequestBlock(Block): return files_payload async def run( - self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs + self, input_data: Input, *, execution_context: ExecutionContext, **kwargs ) -> BlockOutput: # ─── Parse/normalise body ──────────────────────────────────── body = input_data.body @@ -174,7 +179,7 @@ class SendWebRequestBlock(Block): files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = [] if use_files: files_payload = await self._prepare_files( - graph_exec_id, input_data.files_name, input_data.files, user_id + execution_context, input_data.files_name, input_data.files ) # Enforce body format rules @@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock): self, input_data: Input, *, - graph_exec_id: str, + execution_context: ExecutionContext, credentials: HostScopedCredentials, - user_id: str, **kwargs, ) -> BlockOutput: # Create SendWebRequestBlock.Input from our input (removing credentials field) @@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock): # Use parent class run method async for output_name, output_data in super().run( - base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs + base_input, execution_context=execution_context, **kwargs ): yield output_name, output_data diff --git a/autogpt_platform/backend/backend/blocks/human_in_the_loop.py b/autogpt_platform/backend/backend/blocks/human_in_the_loop.py index b6106843bd..568ac4b33f 100644 --- a/autogpt_platform/backend/backend/blocks/human_in_the_loop.py +++ b/autogpt_platform/backend/backend/blocks/human_in_the_loop.py @@ -97,6 +97,7 @@ class HumanInTheLoopBlock(Block): input_data: Input, *, user_id: str, + node_id: str, node_exec_id: str, graph_exec_id: str, graph_id: str, @@ -115,12 +116,12 @@ class HumanInTheLoopBlock(Block): decision = await self.handle_review_decision( input_data=input_data.data, user_id=user_id, + node_id=node_id, node_exec_id=node_exec_id, graph_exec_id=graph_exec_id, graph_id=graph_id, graph_version=graph_version, - execution_context=execution_context, - block_name=self.name, + block_name=input_data.name, # Use user-provided name instead of block type editable=input_data.editable, ) diff --git a/autogpt_platform/backend/backend/blocks/io.py b/autogpt_platform/backend/backend/blocks/io.py index 6f8e62e339..a9c3859490 100644 --- a/autogpt_platform/backend/backend/blocks/io.py +++ b/autogpt_platform/backend/backend/blocks/io.py @@ -12,6 +12,7 @@ from backend.data.block import ( BlockSchemaInput, BlockType, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.file import store_media_file from backend.util.mock import MockObject @@ -462,18 +463,21 @@ class AgentFileInputBlock(AgentInputBlock): self, input_data: Input, *, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: if not input_data.value: return + # Determine return format based on user preference + # for_external_api: always returns data URI (base64) - honors "Produce Base64 Output" + # for_block_output: smart format - workspace:// in CoPilot, data URI in graphs + return_format = "for_external_api" if input_data.base_64 else "for_block_output" + yield "result", await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.value, - user_id=user_id, - return_content=input_data.base_64, + execution_context=execution_context, + return_format=return_format, ) diff --git a/autogpt_platform/backend/backend/blocks/media.py b/autogpt_platform/backend/backend/blocks/media.py new file mode 100644 index 0000000000..a8d145bc64 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/media.py @@ -0,0 +1,246 @@ +import os +import tempfile +from typing import Optional + +from moviepy.audio.io.AudioFileClip import AudioFileClip +from moviepy.video.fx.Loop import Loop +from moviepy.video.io.VideoFileClip import VideoFileClip + +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.execution import ExecutionContext +from backend.data.model import SchemaField +from backend.util.file import MediaFileType, get_exec_file_path, store_media_file + + +class MediaDurationBlock(Block): + + class Input(BlockSchemaInput): + media_in: MediaFileType = SchemaField( + description="Media input (URL, data URI, or local path)." + ) + is_video: bool = SchemaField( + description="Whether the media is a video (True) or audio (False).", + default=True, + ) + + class Output(BlockSchemaOutput): + duration: float = SchemaField( + description="Duration of the media file (in seconds)." + ) + + def __init__(self): + super().__init__( + id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6", + description="Block to get the duration of a media file.", + categories={BlockCategory.MULTIMEDIA}, + input_schema=MediaDurationBlock.Input, + output_schema=MediaDurationBlock.Output, + ) + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + **kwargs, + ) -> BlockOutput: + # 1) Store the input media locally + local_media_path = await store_media_file( + file=input_data.media_in, + execution_context=execution_context, + return_format="for_local_processing", + ) + assert execution_context.graph_exec_id is not None + media_abspath = get_exec_file_path( + execution_context.graph_exec_id, local_media_path + ) + + # 2) Load the clip + if input_data.is_video: + clip = VideoFileClip(media_abspath) + else: + clip = AudioFileClip(media_abspath) + + yield "duration", clip.duration + + +class LoopVideoBlock(Block): + """ + Block for looping (repeating) a video clip until a given duration or number of loops. + """ + + class Input(BlockSchemaInput): + video_in: MediaFileType = SchemaField( + description="The input video (can be a URL, data URI, or local path)." + ) + # Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`. + duration: Optional[float] = SchemaField( + description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.", + default=None, + ge=0.0, + ) + n_loops: Optional[int] = SchemaField( + description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).", + default=None, + ge=1, + ) + + class Output(BlockSchemaOutput): + video_out: str = SchemaField( + description="Looped video returned either as a relative path or a data URI." + ) + + def __init__(self): + super().__init__( + id="8bf9eef6-5451-4213-b265-25306446e94b", + description="Block to loop a video to a given duration or number of repeats.", + categories={BlockCategory.MULTIMEDIA}, + input_schema=LoopVideoBlock.Input, + output_schema=LoopVideoBlock.Output, + ) + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + **kwargs, + ) -> BlockOutput: + assert execution_context.graph_exec_id is not None + assert execution_context.node_exec_id is not None + graph_exec_id = execution_context.graph_exec_id + node_exec_id = execution_context.node_exec_id + + # 1) Store the input video locally + local_video_path = await store_media_file( + file=input_data.video_in, + execution_context=execution_context, + return_format="for_local_processing", + ) + input_abspath = get_exec_file_path(graph_exec_id, local_video_path) + + # 2) Load the clip + clip = VideoFileClip(input_abspath) + + # 3) Apply the loop effect + looped_clip = clip + if input_data.duration: + # Loop until we reach the specified duration + looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)]) + elif input_data.n_loops: + looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)]) + else: + raise ValueError("Either 'duration' or 'n_loops' must be provided.") + + assert isinstance(looped_clip, VideoFileClip) + + # 4) Save the looped output + output_filename = MediaFileType( + f"{node_exec_id}_looped_{os.path.basename(local_video_path)}" + ) + output_abspath = get_exec_file_path(graph_exec_id, output_filename) + + looped_clip = looped_clip.with_audio(clip.audio) + looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac") + + # Return output - for_block_output returns workspace:// if available, else data URI + video_out = await store_media_file( + file=output_filename, + execution_context=execution_context, + return_format="for_block_output", + ) + + yield "video_out", video_out + + +class AddAudioToVideoBlock(Block): + """ + Block that adds (attaches) an audio track to an existing video. + Optionally scale the volume of the new track. + """ + + class Input(BlockSchemaInput): + video_in: MediaFileType = SchemaField( + description="Video input (URL, data URI, or local path)." + ) + audio_in: MediaFileType = SchemaField( + description="Audio input (URL, data URI, or local path)." + ) + volume: float = SchemaField( + description="Volume scale for the newly attached audio track (1.0 = original).", + default=1.0, + ) + + class Output(BlockSchemaOutput): + video_out: MediaFileType = SchemaField( + description="Final video (with attached audio), as a path or data URI." + ) + + def __init__(self): + super().__init__( + id="3503748d-62b6-4425-91d6-725b064af509", + description="Block to attach an audio file to a video file using moviepy.", + categories={BlockCategory.MULTIMEDIA}, + input_schema=AddAudioToVideoBlock.Input, + output_schema=AddAudioToVideoBlock.Output, + ) + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + **kwargs, + ) -> BlockOutput: + assert execution_context.graph_exec_id is not None + assert execution_context.node_exec_id is not None + graph_exec_id = execution_context.graph_exec_id + node_exec_id = execution_context.node_exec_id + + # 1) Store the inputs locally + local_video_path = await store_media_file( + file=input_data.video_in, + execution_context=execution_context, + return_format="for_local_processing", + ) + local_audio_path = await store_media_file( + file=input_data.audio_in, + execution_context=execution_context, + return_format="for_local_processing", + ) + + abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id) + video_abspath = os.path.join(abs_temp_dir, local_video_path) + audio_abspath = os.path.join(abs_temp_dir, local_audio_path) + + # 2) Load video + audio with moviepy + video_clip = VideoFileClip(video_abspath) + audio_clip = AudioFileClip(audio_abspath) + # Optionally scale volume + if input_data.volume != 1.0: + audio_clip = audio_clip.with_volume_scaled(input_data.volume) + + # 3) Attach the new audio track + final_clip = video_clip.with_audio(audio_clip) + + # 4) Write to output file + output_filename = MediaFileType( + f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}" + ) + output_abspath = os.path.join(abs_temp_dir, output_filename) + final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac") + + # 5) Return output - for_block_output returns workspace:// if available, else data URI + video_out = await store_media_file( + file=output_filename, + execution_context=execution_context, + return_format="for_block_output", + ) + + yield "video_out", video_out diff --git a/autogpt_platform/backend/backend/blocks/screenshotone.py b/autogpt_platform/backend/backend/blocks/screenshotone.py index 1f8947376b..ee998f8da2 100644 --- a/autogpt_platform/backend/backend/blocks/screenshotone.py +++ b/autogpt_platform/backend/backend/blocks/screenshotone.py @@ -11,6 +11,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block): @staticmethod async def take_screenshot( credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, url: str, viewport_width: int, viewport_height: int, @@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block): return { "image": await store_media_file( - graph_exec_id=graph_exec_id, file=MediaFileType( f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}" ), - user_id=user_id, - return_content=True, + execution_context=execution_context, + return_format="for_block_output", ) } @@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block): input_data: Input, *, credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: try: screenshot_data = await self.take_screenshot( credentials=credentials, - graph_exec_id=graph_exec_id, - user_id=user_id, + execution_context=execution_context, url=input_data.url, viewport_width=input_data.viewport_width, viewport_height=input_data.viewport_height, diff --git a/autogpt_platform/backend/backend/blocks/spreadsheet.py b/autogpt_platform/backend/backend/blocks/spreadsheet.py index 211aac23f4..a13f9e2f6d 100644 --- a/autogpt_platform/backend/backend/blocks/spreadsheet.py +++ b/autogpt_platform/backend/backend/blocks/spreadsheet.py @@ -7,6 +7,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ContributorDetails, SchemaField from backend.util.file import get_exec_file_path, store_media_file from backend.util.type import MediaFileType @@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block): ) async def run( - self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs + self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs ) -> BlockOutput: import csv from io import StringIO @@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block): # Determine data source - prefer file_input if provided, otherwise use contents if input_data.file_input: stored_file_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=input_data.file_input, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) # Get full file path - file_path = get_exec_file_path(graph_exec_id, stored_file_path) + assert execution_context.graph_exec_id # Validated by store_media_file + file_path = get_exec_file_path( + execution_context.graph_exec_id, stored_file_path + ) if not Path(file_path).exists(): raise ValueError(f"File does not exist: {file_path}") diff --git a/autogpt_platform/backend/backend/blocks/talking_head.py b/autogpt_platform/backend/backend/blocks/talking_head.py index 7a466bec7e..e01e3d4023 100644 --- a/autogpt_platform/backend/backend/blocks/talking_head.py +++ b/autogpt_platform/backend/backend/blocks/talking_head.py @@ -10,6 +10,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -17,7 +18,9 @@ from backend.data.model import ( SchemaField, ) from backend.integrations.providers import ProviderName +from backend.util.file import store_media_file from backend.util.request import Requests +from backend.util.type import MediaFileType TEST_CREDENTIALS = APIKeyCredentials( id="01234567-89ab-cdef-0123-456789abcdef", @@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block): test_output=[ ( "video_url", - "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video", + lambda x: x.startswith(("workspace://", "data:")), ), ], test_mock={ @@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block): "id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx", "status": "created", }, + # Use data URI to avoid HTTP requests during tests "get_clip_status": lambda *args, **kwargs: { "status": "done", - "result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video", + "result_url": "data:video/mp4;base64,AAAA", }, }, test_credentials=TEST_CREDENTIALS, @@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block): return response.json() async def run( - self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, ) -> BlockOutput: # Create the clip payload = { @@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block): for _ in range(input_data.max_polling_attempts): status_response = await self.get_clip_status(credentials.api_key, clip_id) if status_response["status"] == "done": - yield "video_url", status_response["result_url"] + # Store the generated video to the user's workspace for persistence + video_url = status_response["result_url"] + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url return elif status_response["status"] == "error": raise RuntimeError( diff --git a/autogpt_platform/backend/backend/blocks/test/test_blocks_dos_vulnerability.py b/autogpt_platform/backend/backend/blocks/test/test_blocks_dos_vulnerability.py index 389bb5c636..e2e44b194c 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_blocks_dos_vulnerability.py +++ b/autogpt_platform/backend/backend/blocks/test/test_blocks_dos_vulnerability.py @@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock from backend.blocks.llm import AITextSummarizerBlock from backend.blocks.text import ExtractTextInformationBlock from backend.blocks.xml_parser import XMLParserBlock +from backend.data.execution import ExecutionContext from backend.util.file import store_media_file from backend.util.type import MediaFileType @@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity: with pytest.raises(ValueError, match="File too large"): await store_media_file( - graph_exec_id="test", file=MediaFileType(large_data_uri), - user_id="test_user", + execution_context=ExecutionContext( + user_id="test_user", + graph_exec_id="test", + ), + return_format="for_local_processing", ) @patch("backend.util.file.Path") @@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity: # Should raise an error when directory size exceeds limit with pytest.raises(ValueError, match="Disk usage limit exceeded"): await store_media_file( - graph_exec_id="test", file=MediaFileType( "data:text/plain;base64,dGVzdA==" ), # Small test file - user_id="test_user", + execution_context=ExecutionContext( + user_id="test_user", + graph_exec_id="test", + ), + return_format="for_local_processing", ) diff --git a/autogpt_platform/backend/backend/blocks/test/test_http.py b/autogpt_platform/backend/backend/blocks/test/test_http.py index bdc30f3ecf..e01b8e2c5b 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_http.py +++ b/autogpt_platform/backend/backend/blocks/test/test_http.py @@ -11,10 +11,22 @@ from backend.blocks.http import ( HttpMethod, SendAuthenticatedWebRequestBlock, ) +from backend.data.execution import ExecutionContext from backend.data.model import HostScopedCredentials from backend.util.request import Response +def make_test_context( + graph_exec_id: str = "test-exec-id", + user_id: str = "test-user-id", +) -> ExecutionContext: + """Helper to create test ExecutionContext.""" + return ExecutionContext( + user_id=user_id, + graph_exec_id=graph_exec_id, + ) + + class TestHttpBlockWithHostScopedCredentials: """Test suite for HTTP block integration with HostScopedCredentials.""" @@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=exact_match_credentials, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=wildcard_credentials, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=non_matching_credentials, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=exact_match_credentials, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=auto_discovered_creds, # Execution manager found these - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=multi_header_creds, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=test_creds, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) diff --git a/autogpt_platform/backend/backend/blocks/text.py b/autogpt_platform/backend/backend/blocks/text.py index 5e58e27101..359e22a84f 100644 --- a/autogpt_platform/backend/backend/blocks/text.py +++ b/autogpt_platform/backend/backend/blocks/text.py @@ -11,6 +11,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util import json, text from backend.util.file import get_exec_file_path, store_media_file @@ -444,18 +445,21 @@ class FileReadBlock(Block): ) async def run( - self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs + self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs ) -> BlockOutput: # Store the media file properly (handles URLs, data URIs, etc.) stored_file_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=input_data.file_input, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) - # Get full file path - file_path = get_exec_file_path(graph_exec_id, stored_file_path) + # Get full file path (graph_exec_id validated by store_media_file above) + if not execution_context.graph_exec_id: + raise ValueError("execution_context.graph_exec_id is required") + file_path = get_exec_file_path( + execution_context.graph_exec_id, stored_file_path + ) if not Path(file_path).exists(): raise ValueError(f"File does not exist: {file_path}") diff --git a/autogpt_platform/backend/backend/blocks/video/__init__.py b/autogpt_platform/backend/backend/blocks/video/__init__.py index 11afc26443..417903a409 100644 --- a/autogpt_platform/backend/backend/blocks/video/__init__.py +++ b/autogpt_platform/backend/backend/blocks/video/__init__.py @@ -6,9 +6,9 @@ This module provides blocks for: - Concatenating multiple videos - Adding text overlays - Adding AI-generated narration -- Getting media duration -- Looping videos -- Adding audio to videos + +Note: MediaDurationBlock, LoopVideoBlock, and AddAudioToVideoBlock are +provided by backend/blocks/media.py. Dependencies: - yt-dlp: For video downloading @@ -16,19 +16,13 @@ Dependencies: - requests: For API calls (narration block) """ -from backend.blocks.video.add_audio import AddAudioToVideoBlock from backend.blocks.video.clip import VideoClipBlock from backend.blocks.video.concat import VideoConcatBlock from backend.blocks.video.download import VideoDownloadBlock -from backend.blocks.video.duration import MediaDurationBlock -from backend.blocks.video.loop import LoopVideoBlock from backend.blocks.video.narration import VideoNarrationBlock from backend.blocks.video.text_overlay import VideoTextOverlayBlock __all__ = [ - "AddAudioToVideoBlock", - "LoopVideoBlock", - "MediaDurationBlock", "VideoClipBlock", "VideoConcatBlock", "VideoDownloadBlock", diff --git a/autogpt_platform/backend/backend/blocks/video/add_audio.py b/autogpt_platform/backend/backend/blocks/video/add_audio.py deleted file mode 100644 index 02334e3234..0000000000 --- a/autogpt_platform/backend/backend/blocks/video/add_audio.py +++ /dev/null @@ -1,127 +0,0 @@ -"""AddAudioToVideoBlock - Attach an audio track to a video.""" - -import os -from typing import Literal - -from moviepy.audio.io.AudioFileClip import AudioFileClip -from moviepy.video.io.VideoFileClip import VideoFileClip - -from backend.blocks.video._utils import get_video_codecs -from backend.data.block import ( - Block, - BlockCategory, - BlockOutput, - BlockSchemaInput, - BlockSchemaOutput, -) -from backend.data.model import SchemaField -from backend.util.file import MediaFileType, get_exec_file_path, store_media_file - - -class AddAudioToVideoBlock(Block): - """Attach an audio track to an existing video.""" - - class Input(BlockSchemaInput): - video_in: MediaFileType = SchemaField( - description="Video input (URL, data URI, or local path)." - ) - audio_in: MediaFileType = SchemaField( - description="Audio input (URL, data URI, or local path)." - ) - volume: float = SchemaField( - description="Volume scale for the newly attached audio track (1.0 = original).", - default=1.0, - ) - output_return_type: Literal["file_path", "data_uri"] = SchemaField( - description="Return the final output as a relative path or base64 data URI.", - default="file_path", - ) - - class Output(BlockSchemaOutput): - video_out: MediaFileType = SchemaField( - description="Final video (with attached audio), as a path or data URI." - ) - - def __init__(self): - super().__init__( - id="3503748d-62b6-4425-91d6-725b064af509", - description="Block to attach an audio file to a video file using moviepy.", - categories={BlockCategory.MULTIMEDIA}, - input_schema=AddAudioToVideoBlock.Input, - output_schema=AddAudioToVideoBlock.Output, - ) - - async def run( - self, - input_data: Input, - *, - node_exec_id: str, - graph_exec_id: str, - user_id: str, - **kwargs, - ) -> BlockOutput: - # 1) Store the inputs locally - local_video_path = await store_media_file( - graph_exec_id=graph_exec_id, - file=input_data.video_in, - user_id=user_id, - return_content=False, - ) - local_audio_path = await store_media_file( - graph_exec_id=graph_exec_id, - file=input_data.audio_in, - user_id=user_id, - return_content=False, - ) - - video_abspath = get_exec_file_path(graph_exec_id, local_video_path) - audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path) - - video_clip = None - audio_clip_original = None - audio_clip_scaled = None - final_clip = None - try: - # 2) Load video + audio with moviepy - video_clip = VideoFileClip(video_abspath) - audio_clip_original = AudioFileClip(audio_abspath) - - # Optionally scale volume - audio_to_use = audio_clip_original - if input_data.volume != 1.0: - audio_clip_scaled = audio_clip_original.with_volume_scaled( - input_data.volume - ) - audio_to_use = audio_clip_scaled - - # 3) Attach the new audio track - final_clip = video_clip.with_audio(audio_to_use) - - # 4) Write to output file - output_filename = MediaFileType( - f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}" - ) - output_abspath = get_exec_file_path(graph_exec_id, output_filename) - video_codec, audio_codec = get_video_codecs(output_abspath) - final_clip.write_videofile( - output_abspath, codec=video_codec, audio_codec=audio_codec - ) - - # 5) Return either path or data URI - video_out = await store_media_file( - graph_exec_id=graph_exec_id, - file=output_filename, - user_id=user_id, - return_content=input_data.output_return_type == "data_uri", - ) - - yield "video_out", video_out - finally: - if final_clip: - final_clip.close() - if audio_clip_scaled: - audio_clip_scaled.close() - if audio_clip_original: - audio_clip_original.close() - if video_clip: - video_clip.close() diff --git a/autogpt_platform/backend/backend/blocks/video/clip.py b/autogpt_platform/backend/backend/blocks/video/clip.py index aee7e83a68..01b6fff34d 100644 --- a/autogpt_platform/backend/backend/blocks/video/clip.py +++ b/autogpt_platform/backend/backend/blocks/video/clip.py @@ -13,6 +13,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.exceptions import BlockExecutionError from backend.util.file import MediaFileType, get_exec_file_path, store_media_file @@ -30,10 +31,6 @@ class VideoClipBlock(Block): output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField( description="Output format", default="mp4", advanced=True ) - output_return_type: Literal["file_path", "data_uri"] = SchemaField( - description="Return the output as a relative path or base64 data URI.", - default="file_path", - ) class Output(BlockSchemaOutput): video_out: MediaFileType = SchemaField( @@ -62,29 +59,23 @@ class VideoClipBlock(Block): ) async def _store_input_video( - self, graph_exec_id: str, file: MediaFileType, user_id: str + self, execution_context: ExecutionContext, file: MediaFileType ) -> MediaFileType: """Store input video. Extracted for testability.""" return await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) async def _store_output_video( - self, - graph_exec_id: str, - file: MediaFileType, - user_id: str, - return_content: bool, + self, execution_context: ExecutionContext, file: MediaFileType ) -> MediaFileType: """Store output video. Extracted for testability.""" return await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=return_content, + execution_context=execution_context, + return_format="for_block_output", ) def _clip_video( @@ -115,9 +106,8 @@ class VideoClipBlock(Block): self, input_data: Input, *, + execution_context: ExecutionContext, node_exec_id: str, - graph_exec_id: str, - user_id: str, **kwargs, ) -> BlockOutput: # Validate time range @@ -129,11 +119,15 @@ class VideoClipBlock(Block): ) try: + assert execution_context.graph_exec_id is not None + # Store the input video locally local_video_path = await self._store_input_video( - graph_exec_id, input_data.video_in, user_id + execution_context, input_data.video_in + ) + video_abspath = get_exec_file_path( + execution_context.graph_exec_id, local_video_path ) - video_abspath = get_exec_file_path(graph_exec_id, local_video_path) # Build output path output_filename = MediaFileType( @@ -142,7 +136,9 @@ class VideoClipBlock(Block): # Ensure correct extension base, _ = os.path.splitext(output_filename) output_filename = MediaFileType(f"{base}.{input_data.output_format}") - output_abspath = get_exec_file_path(graph_exec_id, output_filename) + output_abspath = get_exec_file_path( + execution_context.graph_exec_id, output_filename + ) duration = self._clip_video( video_abspath, @@ -151,13 +147,8 @@ class VideoClipBlock(Block): input_data.end_time, ) - # Return as data URI or path - video_out = await self._store_output_video( - graph_exec_id, - output_filename, - user_id, - input_data.output_return_type == "data_uri", - ) + # Return as workspace path or data URI based on context + video_out = await self._store_output_video(execution_context, output_filename) yield "video_out", video_out yield "duration", duration diff --git a/autogpt_platform/backend/backend/blocks/video/concat.py b/autogpt_platform/backend/backend/blocks/video/concat.py index 1e74b1e820..298227625e 100644 --- a/autogpt_platform/backend/backend/blocks/video/concat.py +++ b/autogpt_platform/backend/backend/blocks/video/concat.py @@ -14,6 +14,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.exceptions import BlockExecutionError from backend.util.file import MediaFileType, get_exec_file_path, store_media_file @@ -38,10 +39,6 @@ class VideoConcatBlock(Block): output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField( description="Output format", default="mp4", advanced=True ) - output_return_type: Literal["file_path", "data_uri"] = SchemaField( - description="Return the output as a relative path or base64 data URI.", - default="file_path", - ) class Output(BlockSchemaOutput): video_out: MediaFileType = SchemaField( @@ -66,29 +63,23 @@ class VideoConcatBlock(Block): ) async def _store_input_video( - self, graph_exec_id: str, file: MediaFileType, user_id: str + self, execution_context: ExecutionContext, file: MediaFileType ) -> MediaFileType: """Store input video. Extracted for testability.""" return await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) async def _store_output_video( - self, - graph_exec_id: str, - file: MediaFileType, - user_id: str, - return_content: bool, + self, execution_context: ExecutionContext, file: MediaFileType ) -> MediaFileType: """Store output video. Extracted for testability.""" return await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=return_content, + execution_context=execution_context, + return_format="for_block_output", ) def _concat_videos( @@ -150,9 +141,8 @@ class VideoConcatBlock(Block): self, input_data: Input, *, + execution_context: ExecutionContext, node_exec_id: str, - graph_exec_id: str, - user_id: str, **kwargs, ) -> BlockOutput: # Validate minimum clips @@ -164,19 +154,23 @@ class VideoConcatBlock(Block): ) try: + assert execution_context.graph_exec_id is not None + # Store all input videos locally video_abspaths = [] for video in input_data.videos: - local_path = await self._store_input_video( - graph_exec_id, video, user_id + local_path = await self._store_input_video(execution_context, video) + video_abspaths.append( + get_exec_file_path(execution_context.graph_exec_id, local_path) ) - video_abspaths.append(get_exec_file_path(graph_exec_id, local_path)) # Build output path output_filename = MediaFileType( f"{node_exec_id}_concat.{input_data.output_format}" ) - output_abspath = get_exec_file_path(graph_exec_id, output_filename) + output_abspath = get_exec_file_path( + execution_context.graph_exec_id, output_filename + ) total_duration = self._concat_videos( video_abspaths, @@ -185,13 +179,8 @@ class VideoConcatBlock(Block): input_data.transition_duration, ) - # Return as data URI or path - video_out = await self._store_output_video( - graph_exec_id, - output_filename, - user_id, - input_data.output_return_type == "data_uri", - ) + # Return as workspace path or data URI based on context + video_out = await self._store_output_video(execution_context, output_filename) yield "video_out", video_out yield "total_duration", total_duration diff --git a/autogpt_platform/backend/backend/blocks/video/download.py b/autogpt_platform/backend/backend/blocks/video/download.py index c45ed3dac3..d9c5fc4afb 100644 --- a/autogpt_platform/backend/backend/blocks/video/download.py +++ b/autogpt_platform/backend/backend/blocks/video/download.py @@ -16,6 +16,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.exceptions import BlockExecutionError from backend.util.file import MediaFileType, get_exec_file_path, store_media_file @@ -35,10 +36,6 @@ class VideoDownloadBlock(Block): output_format: Literal["mp4", "webm", "mkv"] = SchemaField( description="Output video format", default="mp4", advanced=True ) - output_return_type: Literal["file_path", "data_uri"] = SchemaField( - description="Return the output as a relative path or base64 data URI.", - default="file_path", - ) class Output(BlockSchemaOutput): video_file: MediaFileType = SchemaField( @@ -72,18 +69,13 @@ class VideoDownloadBlock(Block): ) async def _store_output_video( - self, - graph_exec_id: str, - file: MediaFileType, - user_id: str, - return_content: bool, + self, execution_context: ExecutionContext, file: MediaFileType ) -> MediaFileType: """Store output video. Extracted for testability.""" return await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=return_content, + execution_context=execution_context, + return_format="for_block_output", ) def _get_format_string(self, quality: str) -> str: @@ -138,14 +130,15 @@ class VideoDownloadBlock(Block): self, input_data: Input, *, + execution_context: ExecutionContext, node_exec_id: str, - graph_exec_id: str, - user_id: str, **kwargs, ) -> BlockOutput: try: + assert execution_context.graph_exec_id is not None + # Get the exec file directory - output_dir = get_exec_file_path(graph_exec_id, "") + output_dir = get_exec_file_path(execution_context.graph_exec_id, "") os.makedirs(output_dir, exist_ok=True) filename, duration, title = self._download_video( @@ -156,12 +149,9 @@ class VideoDownloadBlock(Block): node_exec_id, ) - # Return as data URI or path + # Return as workspace path or data URI based on context video_out = await self._store_output_video( - graph_exec_id, - MediaFileType(filename), - user_id, - input_data.output_return_type == "data_uri", + execution_context, MediaFileType(filename) ) yield "video_file", video_out diff --git a/autogpt_platform/backend/backend/blocks/video/duration.py b/autogpt_platform/backend/backend/blocks/video/duration.py deleted file mode 100644 index f4182c9784..0000000000 --- a/autogpt_platform/backend/backend/blocks/video/duration.py +++ /dev/null @@ -1,71 +0,0 @@ -"""MediaDurationBlock - Get the duration of a media file.""" - -from moviepy.audio.io.AudioFileClip import AudioFileClip -from moviepy.video.io.VideoFileClip import VideoFileClip - -from backend.data.block import ( - Block, - BlockCategory, - BlockOutput, - BlockSchemaInput, - BlockSchemaOutput, -) -from backend.data.model import SchemaField -from backend.util.file import MediaFileType, get_exec_file_path, store_media_file - - -class MediaDurationBlock(Block): - """Get the duration of a media file.""" - - class Input(BlockSchemaInput): - media_in: MediaFileType = SchemaField( - description="Media input (URL, data URI, or local path)." - ) - is_video: bool = SchemaField( - description="Whether the media is a video (True) or audio (False).", - default=True, - ) - - class Output(BlockSchemaOutput): - duration: float = SchemaField( - description="Duration of the media file (in seconds)." - ) - - def __init__(self): - super().__init__( - id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6", - description="Block to get the duration of a media file.", - categories={BlockCategory.MULTIMEDIA}, - input_schema=MediaDurationBlock.Input, - output_schema=MediaDurationBlock.Output, - ) - - async def run( - self, - input_data: Input, - *, - graph_exec_id: str, - user_id: str, - **kwargs, - ) -> BlockOutput: - # 1) Store the input media locally - local_media_path = await store_media_file( - graph_exec_id=graph_exec_id, - file=input_data.media_in, - user_id=user_id, - return_content=False, - ) - media_abspath = get_exec_file_path(graph_exec_id, local_media_path) - - # 2) Load the clip - clip = None - try: - if input_data.is_video: - clip = VideoFileClip(media_abspath) - else: - clip = AudioFileClip(media_abspath) - - yield "duration", clip.duration - finally: - if clip: - clip.close() diff --git a/autogpt_platform/backend/backend/blocks/video/loop.py b/autogpt_platform/backend/backend/blocks/video/loop.py deleted file mode 100644 index 4dd2ed8444..0000000000 --- a/autogpt_platform/backend/backend/blocks/video/loop.py +++ /dev/null @@ -1,116 +0,0 @@ -"""LoopVideoBlock - Loop a video to a given duration or number of repeats.""" - -import os -from typing import Literal, Optional - -from moviepy.video.fx.Loop import Loop -from moviepy.video.io.VideoFileClip import VideoFileClip - -from backend.blocks.video._utils import get_video_codecs -from backend.data.block import ( - Block, - BlockCategory, - BlockOutput, - BlockSchemaInput, - BlockSchemaOutput, -) -from backend.data.model import SchemaField -from backend.util.file import MediaFileType, get_exec_file_path, store_media_file - - -class LoopVideoBlock(Block): - """Loop (repeat) a video clip until a given duration or number of loops.""" - - class Input(BlockSchemaInput): - video_in: MediaFileType = SchemaField( - description="The input video (can be a URL, data URI, or local path)." - ) - duration: Optional[float] = SchemaField( - description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.", - default=None, - ge=0.0, - ) - n_loops: Optional[int] = SchemaField( - description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).", - default=None, - ge=1, - ) - output_return_type: Literal["file_path", "data_uri"] = SchemaField( - description="How to return the output video. Either a relative path or base64 data URI.", - default="file_path", - ) - - class Output(BlockSchemaOutput): - video_out: str = SchemaField( - description="Looped video returned either as a relative path or a data URI." - ) - - def __init__(self): - super().__init__( - id="8bf9eef6-5451-4213-b265-25306446e94b", - description="Block to loop a video to a given duration or number of repeats.", - categories={BlockCategory.MULTIMEDIA}, - input_schema=LoopVideoBlock.Input, - output_schema=LoopVideoBlock.Output, - ) - - async def run( - self, - input_data: Input, - *, - node_exec_id: str, - graph_exec_id: str, - user_id: str, - **kwargs, - ) -> BlockOutput: - # 1) Store the input video locally - local_video_path = await store_media_file( - graph_exec_id=graph_exec_id, - file=input_data.video_in, - user_id=user_id, - return_content=False, - ) - input_abspath = get_exec_file_path(graph_exec_id, local_video_path) - - clip: VideoFileClip | None = None - looped_clip: VideoFileClip | None = None - try: - # 2) Load the clip - clip = VideoFileClip(input_abspath) - - # 3) Apply the loop effect - # Note: Loop effect handles both video and audio looping automatically - if input_data.duration: - looped_clip = clip.with_effects([Loop(duration=input_data.duration)]) # type: ignore[arg-type] Clip implements shallow copy that loses type info - elif input_data.n_loops: - looped_clip = clip.with_effects([Loop(n=input_data.n_loops)]) # type: ignore[arg-type] Clip implements shallow copy that loses type info - else: - raise ValueError("Either 'duration' or 'n_loops' must be provided.") - - # 4) Save the looped output - output_filename = MediaFileType( - f"{node_exec_id}_looped_{os.path.basename(local_video_path)}" - ) - output_abspath = get_exec_file_path(graph_exec_id, output_filename) - - assert looped_clip is not None - - video_codec, audio_codec = get_video_codecs(output_abspath) - looped_clip.write_videofile( - output_abspath, codec=video_codec, audio_codec=audio_codec - ) - - # Return as data URI or path - video_out = await store_media_file( - graph_exec_id=graph_exec_id, - file=output_filename, - user_id=user_id, - return_content=input_data.output_return_type == "data_uri", - ) - - yield "video_out", video_out - finally: - if looped_clip is not None: - looped_clip.close() - if clip is not None: - clip.close() diff --git a/autogpt_platform/backend/backend/blocks/video/narration.py b/autogpt_platform/backend/backend/blocks/video/narration.py index c569ea6ca5..aebf1d89cc 100644 --- a/autogpt_platform/backend/backend/blocks/video/narration.py +++ b/autogpt_platform/backend/backend/blocks/video/narration.py @@ -22,6 +22,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import CredentialsField, SchemaField from backend.util.exceptions import BlockExecutionError from backend.util.file import MediaFileType, get_exec_file_path, store_media_file @@ -68,10 +69,6 @@ class VideoNarrationBlock(Block): le=1.0, advanced=True, ) - output_return_type: Literal["file_path", "data_uri"] = SchemaField( - description="Return the output as a relative path or base64 data URI.", - default="file_path", - ) class Output(BlockSchemaOutput): video_out: MediaFileType = SchemaField( @@ -104,29 +101,23 @@ class VideoNarrationBlock(Block): ) async def _store_input_video( - self, graph_exec_id: str, file: MediaFileType, user_id: str + self, execution_context: ExecutionContext, file: MediaFileType ) -> MediaFileType: """Store input video. Extracted for testability.""" return await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) async def _store_output_video( - self, - graph_exec_id: str, - file: MediaFileType, - user_id: str, - return_content: bool, + self, execution_context: ExecutionContext, file: MediaFileType ) -> MediaFileType: """Store output video. Extracted for testability.""" return await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=return_content, + execution_context=execution_context, + return_format="for_block_output", ) def _generate_narration_audio( @@ -204,17 +195,20 @@ class VideoNarrationBlock(Block): input_data: Input, *, credentials: ElevenLabsCredentials, + execution_context: ExecutionContext, node_exec_id: str, - graph_exec_id: str, - user_id: str, **kwargs, ) -> BlockOutput: try: + assert execution_context.graph_exec_id is not None + # Store the input video locally local_video_path = await self._store_input_video( - graph_exec_id, input_data.video_in, user_id + execution_context, input_data.video_in + ) + video_abspath = get_exec_file_path( + execution_context.graph_exec_id, local_video_path ) - video_abspath = get_exec_file_path(graph_exec_id, local_video_path) # Generate narration audio via ElevenLabs audio_content = self._generate_narration_audio( @@ -226,7 +220,9 @@ class VideoNarrationBlock(Block): # Save audio to exec file path audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3") - audio_abspath = get_exec_file_path(graph_exec_id, audio_filename) + audio_abspath = get_exec_file_path( + execution_context.graph_exec_id, audio_filename + ) os.makedirs(os.path.dirname(audio_abspath), exist_ok=True) with open(audio_abspath, "wb") as f: f.write(audio_content) @@ -235,7 +231,9 @@ class VideoNarrationBlock(Block): output_filename = MediaFileType( f"{node_exec_id}_narrated_{os.path.basename(local_video_path)}" ) - output_abspath = get_exec_file_path(graph_exec_id, output_filename) + output_abspath = get_exec_file_path( + execution_context.graph_exec_id, output_filename + ) self._add_narration_to_video( video_abspath, @@ -246,16 +244,9 @@ class VideoNarrationBlock(Block): input_data.original_volume, ) - # Return as data URI or path - return_as_data_uri = input_data.output_return_type == "data_uri" - - video_out = await self._store_output_video( - graph_exec_id, output_filename, user_id, return_as_data_uri - ) - - audio_out = await self._store_output_video( - graph_exec_id, audio_filename, user_id, return_as_data_uri - ) + # Return as workspace path or data URI based on context + video_out = await self._store_output_video(execution_context, output_filename) + audio_out = await self._store_output_video(execution_context, audio_filename) yield "video_out", video_out yield "audio_file", audio_out diff --git a/autogpt_platform/backend/backend/blocks/video/text_overlay.py b/autogpt_platform/backend/backend/blocks/video/text_overlay.py index 20e9737807..50e54641e9 100644 --- a/autogpt_platform/backend/backend/blocks/video/text_overlay.py +++ b/autogpt_platform/backend/backend/blocks/video/text_overlay.py @@ -14,6 +14,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.exceptions import BlockExecutionError from backend.util.file import MediaFileType, get_exec_file_path, store_media_file @@ -57,10 +58,6 @@ class VideoTextOverlayBlock(Block): default=None, advanced=True, ) - output_return_type: Literal["file_path", "data_uri"] = SchemaField( - description="Return the output as a relative path or base64 data URI.", - default="file_path", - ) class Output(BlockSchemaOutput): video_out: MediaFileType = SchemaField( @@ -84,29 +81,23 @@ class VideoTextOverlayBlock(Block): ) async def _store_input_video( - self, graph_exec_id: str, file: MediaFileType, user_id: str + self, execution_context: ExecutionContext, file: MediaFileType ) -> MediaFileType: """Store input video. Extracted for testability.""" return await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) async def _store_output_video( - self, - graph_exec_id: str, - file: MediaFileType, - user_id: str, - return_content: bool, + self, execution_context: ExecutionContext, file: MediaFileType ) -> MediaFileType: """Store output video. Extracted for testability.""" return await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=return_content, + execution_context=execution_context, + return_format="for_block_output", ) def _add_text_overlay( @@ -172,9 +163,8 @@ class VideoTextOverlayBlock(Block): self, input_data: Input, *, + execution_context: ExecutionContext, node_exec_id: str, - graph_exec_id: str, - user_id: str, **kwargs, ) -> BlockOutput: # Validate time range if both are provided @@ -190,17 +180,23 @@ class VideoTextOverlayBlock(Block): ) try: + assert execution_context.graph_exec_id is not None + # Store the input video locally local_video_path = await self._store_input_video( - graph_exec_id, input_data.video_in, user_id + execution_context, input_data.video_in + ) + video_abspath = get_exec_file_path( + execution_context.graph_exec_id, local_video_path ) - video_abspath = get_exec_file_path(graph_exec_id, local_video_path) # Build output path output_filename = MediaFileType( f"{node_exec_id}_overlay_{os.path.basename(local_video_path)}" ) - output_abspath = get_exec_file_path(graph_exec_id, output_filename) + output_abspath = get_exec_file_path( + execution_context.graph_exec_id, output_filename + ) self._add_text_overlay( video_abspath, @@ -214,13 +210,8 @@ class VideoTextOverlayBlock(Block): input_data.bg_color, ) - # Return as data URI or path - video_out = await self._store_output_video( - graph_exec_id, - output_filename, - user_id, - input_data.output_return_type == "data_uri", - ) + # Return as workspace path or data URI based on context + video_out = await self._store_output_video(execution_context, output_filename) yield "video_out", video_out diff --git a/autogpt_platform/backend/backend/conftest.py b/autogpt_platform/backend/backend/conftest.py index b0b7f0cc67..57481e4b85 100644 --- a/autogpt_platform/backend/backend/conftest.py +++ b/autogpt_platform/backend/backend/conftest.py @@ -1,7 +1,7 @@ import logging import os -import pytest +import pytest_asyncio from dotenv import load_dotenv from backend.util.logging import configure_logging @@ -19,7 +19,7 @@ if not os.getenv("PRISMA_DEBUG"): prisma_logger.setLevel(logging.INFO) -@pytest.fixture(scope="session") +@pytest_asyncio.fixture(scope="session", loop_scope="session") async def server(): from backend.util.test import SpinTestServer @@ -27,7 +27,7 @@ async def server(): yield server -@pytest.fixture(scope="session", autouse=True) +@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True) async def graph_cleanup(server): created_graph_ids = [] original_create_graph = server.agent_server.test_create_graph diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 4bfa3892e2..8d9ecfff4c 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -441,6 +441,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): static_output: bool = False, block_type: BlockType = BlockType.STANDARD, webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None, + is_sensitive_action: bool = False, ): """ Initialize the block with the given schema. @@ -473,8 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): self.static_output = static_output self.block_type = block_type self.webhook_config = webhook_config + self.is_sensitive_action = is_sensitive_action self.execution_stats: NodeExecutionStats = NodeExecutionStats() - self.is_sensitive_action: bool = False if self.webhook_config: if isinstance(self.webhook_config, BlockWebhookConfig): @@ -622,6 +623,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): input_data: BlockInput, *, user_id: str, + node_id: str, node_exec_id: str, graph_exec_id: str, graph_id: str, @@ -648,11 +650,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): decision = await HITLReviewHelper.handle_review_decision( input_data=input_data, user_id=user_id, + node_id=node_id, node_exec_id=node_exec_id, graph_exec_id=graph_exec_id, graph_id=graph_id, graph_version=graph_version, - execution_context=execution_context, block_name=self.name, editable=True, ) diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index 3c1fd25c51..afb8c70538 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -83,12 +83,29 @@ class ExecutionContext(BaseModel): model_config = {"extra": "ignore"} + # Execution identity + user_id: Optional[str] = None + graph_id: Optional[str] = None + graph_exec_id: Optional[str] = None + graph_version: Optional[int] = None + node_id: Optional[str] = None + node_exec_id: Optional[str] = None + + # Safety settings human_in_the_loop_safe_mode: bool = True sensitive_action_safe_mode: bool = False + + # User settings user_timezone: str = "UTC" + + # Execution hierarchy root_execution_id: Optional[str] = None parent_execution_id: Optional[str] = None + # Workspace + workspace_id: Optional[str] = None + session_id: Optional[str] = None + # -------------------------- Models -------------------------- # diff --git a/autogpt_platform/backend/backend/data/human_review.py b/autogpt_platform/backend/backend/data/human_review.py index de7a30759e..f198043a38 100644 --- a/autogpt_platform/backend/backend/data/human_review.py +++ b/autogpt_platform/backend/backend/data/human_review.py @@ -6,10 +6,10 @@ Handles all database operations for pending human reviews. import asyncio import logging from datetime import datetime, timezone -from typing import Optional +from typing import TYPE_CHECKING, Optional from prisma.enums import ReviewStatus -from prisma.models import PendingHumanReview +from prisma.models import AgentNodeExecution, PendingHumanReview from prisma.types import PendingHumanReviewUpdateInput from pydantic import BaseModel @@ -17,8 +17,12 @@ from backend.api.features.executions.review.model import ( PendingHumanReviewModel, SafeJsonData, ) +from backend.data.execution import get_graph_execution_meta from backend.util.json import SafeJson +if TYPE_CHECKING: + pass + logger = logging.getLogger(__name__) @@ -32,6 +36,125 @@ class ReviewResult(BaseModel): node_exec_id: str +def get_auto_approve_key(graph_exec_id: str, node_id: str) -> str: + """Generate the special nodeExecId key for auto-approval records.""" + return f"auto_approve_{graph_exec_id}_{node_id}" + + +async def check_approval( + node_exec_id: str, + graph_exec_id: str, + node_id: str, + user_id: str, + input_data: SafeJsonData | None = None, +) -> Optional[ReviewResult]: + """ + Check if there's an existing approval for this node execution. + + Checks both: + 1. Normal approval by node_exec_id (previous run of the same node execution) + 2. Auto-approval by special key pattern "auto_approve_{graph_exec_id}_{node_id}" + + Args: + node_exec_id: ID of the node execution + graph_exec_id: ID of the graph execution + node_id: ID of the node definition (not execution) + user_id: ID of the user (for data isolation) + input_data: Current input data (used for auto-approvals to avoid stale data) + + Returns: + ReviewResult if approval found (either normal or auto), None otherwise + """ + auto_approve_key = get_auto_approve_key(graph_exec_id, node_id) + + # Check for either normal approval or auto-approval in a single query + existing_review = await PendingHumanReview.prisma().find_first( + where={ + "OR": [ + {"nodeExecId": node_exec_id}, + {"nodeExecId": auto_approve_key}, + ], + "status": ReviewStatus.APPROVED, + "userId": user_id, + }, + ) + + if existing_review: + is_auto_approval = existing_review.nodeExecId == auto_approve_key + logger.info( + f"Found {'auto-' if is_auto_approval else ''}approval for node {node_id} " + f"(exec: {node_exec_id}) in execution {graph_exec_id}" + ) + # For auto-approvals, use current input_data to avoid replaying stale payload + # For normal approvals, use the stored payload (which may have been edited) + return ReviewResult( + data=( + input_data + if is_auto_approval and input_data is not None + else existing_review.payload + ), + status=ReviewStatus.APPROVED, + message=( + "Auto-approved (user approved all future actions for this node)" + if is_auto_approval + else existing_review.reviewMessage or "" + ), + processed=True, + node_exec_id=existing_review.nodeExecId, + ) + + return None + + +async def create_auto_approval_record( + user_id: str, + graph_exec_id: str, + graph_id: str, + graph_version: int, + node_id: str, + payload: SafeJsonData, +) -> None: + """ + Create an auto-approval record for a node in this execution. + + This is stored as a PendingHumanReview with a special nodeExecId pattern + and status=APPROVED, so future executions of the same node can skip review. + + Raises: + ValueError: If the graph execution doesn't belong to the user + """ + # Validate that the graph execution belongs to this user (defense in depth) + graph_exec = await get_graph_execution_meta( + user_id=user_id, execution_id=graph_exec_id + ) + if not graph_exec: + raise ValueError( + f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}" + ) + + auto_approve_key = get_auto_approve_key(graph_exec_id, node_id) + + await PendingHumanReview.prisma().upsert( + where={"nodeExecId": auto_approve_key}, + data={ + "create": { + "nodeExecId": auto_approve_key, + "userId": user_id, + "graphExecId": graph_exec_id, + "graphId": graph_id, + "graphVersion": graph_version, + "payload": SafeJson(payload), + "instructions": "Auto-approval record", + "editable": False, + "status": ReviewStatus.APPROVED, + "processed": True, + "reviewedAt": datetime.now(timezone.utc), + }, + "update": {}, # Already exists, no update needed + }, + ) + + async def get_or_create_human_review( user_id: str, node_exec_id: str, @@ -108,6 +231,89 @@ async def get_or_create_human_review( ) +async def get_pending_review_by_node_exec_id( + node_exec_id: str, user_id: str +) -> Optional["PendingHumanReviewModel"]: + """ + Get a pending review by its node execution ID. + + Args: + node_exec_id: The node execution ID to look up + user_id: User ID for authorization (only returns if review belongs to this user) + + Returns: + The pending review if found and belongs to user, None otherwise + """ + review = await PendingHumanReview.prisma().find_first( + where={ + "nodeExecId": node_exec_id, + "userId": user_id, + "status": ReviewStatus.WAITING, + } + ) + + if not review: + return None + + # Local import to avoid event loop conflicts in tests + from backend.data.execution import get_node_execution + + node_exec = await get_node_execution(review.nodeExecId) + node_id = node_exec.node_id if node_exec else review.nodeExecId + return PendingHumanReviewModel.from_db(review, node_id=node_id) + + +async def get_reviews_by_node_exec_ids( + node_exec_ids: list[str], user_id: str +) -> dict[str, "PendingHumanReviewModel"]: + """ + Get multiple reviews by their node execution IDs regardless of status. + + Unlike get_pending_reviews_by_node_exec_ids, this returns reviews in any status + (WAITING, APPROVED, REJECTED). Used for validation in idempotent operations. + + Args: + node_exec_ids: List of node execution IDs to look up + user_id: User ID for authorization (only returns reviews belonging to this user) + + Returns: + Dictionary mapping node_exec_id -> PendingHumanReviewModel for found reviews + """ + if not node_exec_ids: + return {} + + reviews = await PendingHumanReview.prisma().find_many( + where={ + "nodeExecId": {"in": node_exec_ids}, + "userId": user_id, + } + ) + + if not reviews: + return {} + + # Batch fetch all node executions to avoid N+1 queries + node_exec_ids_to_fetch = [review.nodeExecId for review in reviews] + node_execs = await AgentNodeExecution.prisma().find_many( + where={"id": {"in": node_exec_ids_to_fetch}}, + include={"Node": True}, + ) + + # Create mapping from node_exec_id to node_id + node_exec_id_to_node_id = { + node_exec.id: node_exec.agentNodeId for node_exec in node_execs + } + + result = {} + for review in reviews: + node_id = node_exec_id_to_node_id.get(review.nodeExecId, review.nodeExecId) + result[review.nodeExecId] = PendingHumanReviewModel.from_db( + review, node_id=node_id + ) + + return result + + async def has_pending_reviews_for_graph_exec(graph_exec_id: str) -> bool: """ Check if a graph execution has any pending reviews. @@ -137,8 +343,11 @@ async def get_pending_reviews_for_user( page_size: Number of reviews per page Returns: - List of pending review models + List of pending review models with node_id included """ + # Local import to avoid event loop conflicts in tests + from backend.data.execution import get_node_execution + # Calculate offset for pagination offset = (page - 1) * page_size @@ -149,7 +358,14 @@ async def get_pending_reviews_for_user( take=page_size, ) - return [PendingHumanReviewModel.from_db(review) for review in reviews] + # Fetch node_id for each review from NodeExecution + result = [] + for review in reviews: + node_exec = await get_node_execution(review.nodeExecId) + node_id = node_exec.node_id if node_exec else review.nodeExecId + result.append(PendingHumanReviewModel.from_db(review, node_id=node_id)) + + return result async def get_pending_reviews_for_execution( @@ -163,8 +379,11 @@ async def get_pending_reviews_for_execution( user_id: User ID for security validation Returns: - List of pending review models + List of pending review models with node_id included """ + # Local import to avoid event loop conflicts in tests + from backend.data.execution import get_node_execution + reviews = await PendingHumanReview.prisma().find_many( where={ "userId": user_id, @@ -174,7 +393,14 @@ async def get_pending_reviews_for_execution( order={"createdAt": "asc"}, ) - return [PendingHumanReviewModel.from_db(review) for review in reviews] + # Fetch node_id for each review from NodeExecution + result = [] + for review in reviews: + node_exec = await get_node_execution(review.nodeExecId) + node_id = node_exec.node_id if node_exec else review.nodeExecId + result.append(PendingHumanReviewModel.from_db(review, node_id=node_id)) + + return result async def process_all_reviews_for_execution( @@ -183,38 +409,68 @@ async def process_all_reviews_for_execution( ) -> dict[str, PendingHumanReviewModel]: """Process all pending reviews for an execution with approve/reject decisions. + Handles race conditions gracefully: if a review was already processed with the + same decision by a concurrent request, it's treated as success rather than error. + Args: user_id: User ID for ownership validation review_decisions: Map of node_exec_id -> (status, reviewed_data, message) Returns: - Dict of node_exec_id -> updated review model + Dict of node_exec_id -> updated review model (includes already-processed reviews) """ if not review_decisions: return {} node_exec_ids = list(review_decisions.keys()) - # Get all reviews for validation - reviews = await PendingHumanReview.prisma().find_many( + # Get all reviews (both WAITING and already processed) for the user + all_reviews = await PendingHumanReview.prisma().find_many( where={ "nodeExecId": {"in": node_exec_ids}, "userId": user_id, - "status": ReviewStatus.WAITING, }, ) - # Validate all reviews can be processed - if len(reviews) != len(node_exec_ids): - missing_ids = set(node_exec_ids) - {review.nodeExecId for review in reviews} + # Separate into pending and already-processed reviews + reviews_to_process = [] + already_processed = [] + for review in all_reviews: + if review.status == ReviewStatus.WAITING: + reviews_to_process.append(review) + else: + already_processed.append(review) + + # Check for truly missing reviews (not found at all) + found_ids = {review.nodeExecId for review in all_reviews} + missing_ids = set(node_exec_ids) - found_ids + if missing_ids: raise ValueError( - f"Reviews not found, access denied, or not in WAITING status: {', '.join(missing_ids)}" + f"Reviews not found or access denied: {', '.join(missing_ids)}" ) - # Create parallel update tasks + # Validate already-processed reviews have compatible status (same decision) + # This handles race conditions where another request processed the same reviews + for review in already_processed: + requested_status = review_decisions[review.nodeExecId][0] + if review.status != requested_status: + raise ValueError( + f"Review {review.nodeExecId} was already processed with status " + f"{review.status}, cannot change to {requested_status}" + ) + + # Log if we're handling a race condition (some reviews already processed) + if already_processed: + already_processed_ids = [r.nodeExecId for r in already_processed] + logger.info( + f"Race condition handled: {len(already_processed)} review(s) already " + f"processed by concurrent request: {already_processed_ids}" + ) + + # Create parallel update tasks for reviews that still need processing update_tasks = [] - for review in reviews: + for review in reviews_to_process: new_status, reviewed_data, message = review_decisions[review.nodeExecId] has_data_changes = reviewed_data is not None and reviewed_data != review.payload @@ -239,16 +495,27 @@ async def process_all_reviews_for_execution( update_tasks.append(task) # Execute all updates in parallel and get updated reviews - updated_reviews = await asyncio.gather(*update_tasks) + updated_reviews = await asyncio.gather(*update_tasks) if update_tasks else [] # Note: Execution resumption is now handled at the API layer after ALL reviews # for an execution are processed (both approved and rejected) - # Return as dict for easy access - return { - review.nodeExecId: PendingHumanReviewModel.from_db(review) - for review in updated_reviews - } + # Fetch node_id for each review and return as dict for easy access + # Local import to avoid event loop conflicts in tests + from backend.data.execution import get_node_execution + + # Combine updated reviews with already-processed ones (for idempotent response) + all_result_reviews = list(updated_reviews) + already_processed + + result = {} + for review in all_result_reviews: + node_exec = await get_node_execution(review.nodeExecId) + node_id = node_exec.node_id if node_exec else review.nodeExecId + result[review.nodeExecId] = PendingHumanReviewModel.from_db( + review, node_id=node_id + ) + + return result async def update_review_processed_status(node_exec_id: str, processed: bool) -> None: @@ -256,3 +523,44 @@ async def update_review_processed_status(node_exec_id: str, processed: bool) -> await PendingHumanReview.prisma().update( where={"nodeExecId": node_exec_id}, data={"processed": processed} ) + + +async def cancel_pending_reviews_for_execution(graph_exec_id: str, user_id: str) -> int: + """ + Cancel all pending reviews for a graph execution (e.g., when execution is stopped). + + Marks all WAITING reviews as REJECTED with a message indicating the execution was stopped. + + Args: + graph_exec_id: The graph execution ID + user_id: User ID who owns the execution (for security validation) + + Returns: + Number of reviews cancelled + + Raises: + ValueError: If the graph execution doesn't belong to the user + """ + # Validate user ownership before cancelling reviews + graph_exec = await get_graph_execution_meta( + user_id=user_id, execution_id=graph_exec_id + ) + if not graph_exec: + raise ValueError( + f"Graph execution {graph_exec_id} not found or doesn't belong to user {user_id}" + ) + + result = await PendingHumanReview.prisma().update_many( + where={ + "graphExecId": graph_exec_id, + "userId": user_id, + "status": ReviewStatus.WAITING, + }, + data={ + "status": ReviewStatus.REJECTED, + "reviewMessage": "Execution was stopped by user", + "processed": True, + "reviewedAt": datetime.now(timezone.utc), + }, + ) + return result diff --git a/autogpt_platform/backend/backend/data/human_review_test.py b/autogpt_platform/backend/backend/data/human_review_test.py index c349fdde46..baa5c0c0c4 100644 --- a/autogpt_platform/backend/backend/data/human_review_test.py +++ b/autogpt_platform/backend/backend/data/human_review_test.py @@ -36,7 +36,7 @@ def sample_db_review(): return mock_review -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_get_or_create_human_review_new( mocker: pytest_mock.MockFixture, sample_db_review, @@ -46,8 +46,8 @@ async def test_get_or_create_human_review_new( sample_db_review.status = ReviewStatus.WAITING sample_db_review.processed = False - mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma") - mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review) + mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma") + mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review) result = await get_or_create_human_review( user_id="test-user-123", @@ -64,7 +64,7 @@ async def test_get_or_create_human_review_new( assert result is None -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_get_or_create_human_review_approved( mocker: pytest_mock.MockFixture, sample_db_review, @@ -75,8 +75,8 @@ async def test_get_or_create_human_review_approved( sample_db_review.processed = False sample_db_review.reviewMessage = "Looks good" - mock_upsert = mocker.patch("backend.data.human_review.PendingHumanReview.prisma") - mock_upsert.return_value.upsert = AsyncMock(return_value=sample_db_review) + mock_prisma = mocker.patch("backend.data.human_review.PendingHumanReview.prisma") + mock_prisma.return_value.upsert = AsyncMock(return_value=sample_db_review) result = await get_or_create_human_review( user_id="test-user-123", @@ -96,7 +96,7 @@ async def test_get_or_create_human_review_approved( assert result.message == "Looks good" -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_has_pending_reviews_for_graph_exec_true( mocker: pytest_mock.MockFixture, ): @@ -109,7 +109,7 @@ async def test_has_pending_reviews_for_graph_exec_true( assert result is True -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_has_pending_reviews_for_graph_exec_false( mocker: pytest_mock.MockFixture, ): @@ -122,7 +122,7 @@ async def test_has_pending_reviews_for_graph_exec_false( assert result is False -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_get_pending_reviews_for_user( mocker: pytest_mock.MockFixture, sample_db_review, @@ -131,10 +131,19 @@ async def test_get_pending_reviews_for_user( mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma") mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review]) + # Mock get_node_execution to return node with node_id (async function) + mock_node_exec = Mock() + mock_node_exec.node_id = "test_node_def_789" + mocker.patch( + "backend.data.execution.get_node_execution", + new=AsyncMock(return_value=mock_node_exec), + ) + result = await get_pending_reviews_for_user("test_user", page=2, page_size=10) assert len(result) == 1 assert result[0].node_exec_id == "test_node_123" + assert result[0].node_id == "test_node_def_789" # Verify pagination parameters call_args = mock_find_many.return_value.find_many.call_args @@ -142,7 +151,7 @@ async def test_get_pending_reviews_for_user( assert call_args.kwargs["take"] == 10 -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_get_pending_reviews_for_execution( mocker: pytest_mock.MockFixture, sample_db_review, @@ -151,12 +160,21 @@ async def test_get_pending_reviews_for_execution( mock_find_many = mocker.patch("backend.data.human_review.PendingHumanReview.prisma") mock_find_many.return_value.find_many = AsyncMock(return_value=[sample_db_review]) + # Mock get_node_execution to return node with node_id (async function) + mock_node_exec = Mock() + mock_node_exec.node_id = "test_node_def_789" + mocker.patch( + "backend.data.execution.get_node_execution", + new=AsyncMock(return_value=mock_node_exec), + ) + result = await get_pending_reviews_for_execution( "test_graph_exec_456", "test-user-123" ) assert len(result) == 1 assert result[0].graph_exec_id == "test_graph_exec_456" + assert result[0].node_id == "test_node_def_789" # Verify it filters by execution and user call_args = mock_find_many.return_value.find_many.call_args @@ -166,7 +184,7 @@ async def test_get_pending_reviews_for_execution( assert where_clause["status"] == ReviewStatus.WAITING -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_process_all_reviews_for_execution_success( mocker: pytest_mock.MockFixture, sample_db_review, @@ -201,6 +219,14 @@ async def test_process_all_reviews_for_execution_success( new=AsyncMock(return_value=[updated_review]), ) + # Mock get_node_execution to return node with node_id (async function) + mock_node_exec = Mock() + mock_node_exec.node_id = "test_node_def_789" + mocker.patch( + "backend.data.execution.get_node_execution", + new=AsyncMock(return_value=mock_node_exec), + ) + result = await process_all_reviews_for_execution( user_id="test-user-123", review_decisions={ @@ -211,9 +237,10 @@ async def test_process_all_reviews_for_execution_success( assert len(result) == 1 assert "test_node_123" in result assert result["test_node_123"].status == ReviewStatus.APPROVED + assert result["test_node_123"].node_id == "test_node_def_789" -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_process_all_reviews_for_execution_validation_errors( mocker: pytest_mock.MockFixture, ): @@ -233,7 +260,7 @@ async def test_process_all_reviews_for_execution_validation_errors( ) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_process_all_reviews_edit_permission_error( mocker: pytest_mock.MockFixture, sample_db_review, @@ -259,7 +286,7 @@ async def test_process_all_reviews_edit_permission_error( ) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="function") async def test_process_all_reviews_mixed_approval_rejection( mocker: pytest_mock.MockFixture, sample_db_review, @@ -329,6 +356,14 @@ async def test_process_all_reviews_mixed_approval_rejection( new=AsyncMock(return_value=[approved_review, rejected_review]), ) + # Mock get_node_execution to return node with node_id (async function) + mock_node_exec = Mock() + mock_node_exec.node_id = "test_node_def_789" + mocker.patch( + "backend.data.execution.get_node_execution", + new=AsyncMock(return_value=mock_node_exec), + ) + result = await process_all_reviews_for_execution( user_id="test-user-123", review_decisions={ @@ -340,3 +375,5 @@ async def test_process_all_reviews_mixed_approval_rejection( assert len(result) == 2 assert "test_node_123" in result assert "test_node_456" in result + assert result["test_node_123"].node_id == "test_node_def_789" + assert result["test_node_456"].node_id == "test_node_def_789" diff --git a/autogpt_platform/backend/backend/data/onboarding.py b/autogpt_platform/backend/backend/data/onboarding.py index 6a842d1022..4af8e8dffd 100644 --- a/autogpt_platform/backend/backend/data/onboarding.py +++ b/autogpt_platform/backend/backend/data/onboarding.py @@ -41,6 +41,7 @@ FrontendOnboardingStep = Literal[ OnboardingStep.AGENT_NEW_RUN, OnboardingStep.AGENT_INPUT, OnboardingStep.CONGRATS, + OnboardingStep.VISIT_COPILOT, OnboardingStep.MARKETPLACE_VISIT, OnboardingStep.BUILDER_OPEN, ] @@ -122,6 +123,9 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate): async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep): reward = 0 match step: + # Welcome bonus for visiting copilot ($5 = 500 credits) + case OnboardingStep.VISIT_COPILOT: + reward = 500 # Reward user when they clicked New Run during onboarding # This is because they need credits before scheduling a run (next step) # This is seen as a reward for the GET_RESULTS step in the wallet diff --git a/autogpt_platform/backend/backend/data/workspace.py b/autogpt_platform/backend/backend/data/workspace.py new file mode 100644 index 0000000000..f3dba0a294 --- /dev/null +++ b/autogpt_platform/backend/backend/data/workspace.py @@ -0,0 +1,276 @@ +""" +Database CRUD operations for User Workspace. + +This module provides functions for managing user workspaces and workspace files. +""" + +import logging +from datetime import datetime, timezone +from typing import Optional + +from prisma.models import UserWorkspace, UserWorkspaceFile +from prisma.types import UserWorkspaceFileWhereInput + +from backend.util.json import SafeJson + +logger = logging.getLogger(__name__) + + +async def get_or_create_workspace(user_id: str) -> UserWorkspace: + """ + Get user's workspace, creating one if it doesn't exist. + + Uses upsert to handle race conditions when multiple concurrent requests + attempt to create a workspace for the same user. + + Args: + user_id: The user's ID + + Returns: + UserWorkspace instance + """ + workspace = await UserWorkspace.prisma().upsert( + where={"userId": user_id}, + data={ + "create": {"userId": user_id}, + "update": {}, # No updates needed if exists + }, + ) + + return workspace + + +async def get_workspace(user_id: str) -> Optional[UserWorkspace]: + """ + Get user's workspace if it exists. + + Args: + user_id: The user's ID + + Returns: + UserWorkspace instance or None + """ + return await UserWorkspace.prisma().find_unique(where={"userId": user_id}) + + +async def create_workspace_file( + workspace_id: str, + file_id: str, + name: str, + path: str, + storage_path: str, + mime_type: str, + size_bytes: int, + checksum: Optional[str] = None, + metadata: Optional[dict] = None, +) -> UserWorkspaceFile: + """ + Create a new workspace file record. + + Args: + workspace_id: The workspace ID + file_id: The file ID (same as used in storage path for consistency) + name: User-visible filename + path: Virtual path (e.g., "/documents/report.pdf") + storage_path: Actual storage path (GCS or local) + mime_type: MIME type of the file + size_bytes: File size in bytes + checksum: Optional SHA256 checksum + metadata: Optional additional metadata + + Returns: + Created UserWorkspaceFile instance + """ + # Normalize path to start with / + if not path.startswith("/"): + path = f"/{path}" + + file = await UserWorkspaceFile.prisma().create( + data={ + "id": file_id, + "workspaceId": workspace_id, + "name": name, + "path": path, + "storagePath": storage_path, + "mimeType": mime_type, + "sizeBytes": size_bytes, + "checksum": checksum, + "metadata": SafeJson(metadata or {}), + } + ) + + logger.info( + f"Created workspace file {file.id} at path {path} " + f"in workspace {workspace_id}" + ) + return file + + +async def get_workspace_file( + file_id: str, + workspace_id: Optional[str] = None, +) -> Optional[UserWorkspaceFile]: + """ + Get a workspace file by ID. + + Args: + file_id: The file ID + workspace_id: Optional workspace ID for validation + + Returns: + UserWorkspaceFile instance or None + """ + where_clause: dict = {"id": file_id, "isDeleted": False} + if workspace_id: + where_clause["workspaceId"] = workspace_id + + return await UserWorkspaceFile.prisma().find_first(where=where_clause) + + +async def get_workspace_file_by_path( + workspace_id: str, + path: str, +) -> Optional[UserWorkspaceFile]: + """ + Get a workspace file by its virtual path. + + Args: + workspace_id: The workspace ID + path: Virtual path + + Returns: + UserWorkspaceFile instance or None + """ + # Normalize path + if not path.startswith("/"): + path = f"/{path}" + + return await UserWorkspaceFile.prisma().find_first( + where={ + "workspaceId": workspace_id, + "path": path, + "isDeleted": False, + } + ) + + +async def list_workspace_files( + workspace_id: str, + path_prefix: Optional[str] = None, + include_deleted: bool = False, + limit: Optional[int] = None, + offset: int = 0, +) -> list[UserWorkspaceFile]: + """ + List files in a workspace. + + Args: + workspace_id: The workspace ID + path_prefix: Optional path prefix to filter (e.g., "/documents/") + include_deleted: Whether to include soft-deleted files + limit: Maximum number of files to return + offset: Number of files to skip + + Returns: + List of UserWorkspaceFile instances + """ + where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id} + + if not include_deleted: + where_clause["isDeleted"] = False + + if path_prefix: + # Normalize prefix + if not path_prefix.startswith("/"): + path_prefix = f"/{path_prefix}" + where_clause["path"] = {"startswith": path_prefix} + + return await UserWorkspaceFile.prisma().find_many( + where=where_clause, + order={"createdAt": "desc"}, + take=limit, + skip=offset, + ) + + +async def count_workspace_files( + workspace_id: str, + path_prefix: Optional[str] = None, + include_deleted: bool = False, +) -> int: + """ + Count files in a workspace. + + Args: + workspace_id: The workspace ID + path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/") + include_deleted: Whether to include soft-deleted files + + Returns: + Number of files + """ + where_clause: dict = {"workspaceId": workspace_id} + if not include_deleted: + where_clause["isDeleted"] = False + + if path_prefix: + # Normalize prefix + if not path_prefix.startswith("/"): + path_prefix = f"/{path_prefix}" + where_clause["path"] = {"startswith": path_prefix} + + return await UserWorkspaceFile.prisma().count(where=where_clause) + + +async def soft_delete_workspace_file( + file_id: str, + workspace_id: Optional[str] = None, +) -> Optional[UserWorkspaceFile]: + """ + Soft-delete a workspace file. + + The path is modified to include a deletion timestamp to free up the original + path for new files while preserving the record for potential recovery. + + Args: + file_id: The file ID + workspace_id: Optional workspace ID for validation + + Returns: + Updated UserWorkspaceFile instance or None if not found + """ + # First verify the file exists and belongs to workspace + file = await get_workspace_file(file_id, workspace_id) + if file is None: + return None + + deleted_at = datetime.now(timezone.utc) + # Modify path to free up the unique constraint for new files at original path + # Format: {original_path}__deleted__{timestamp} + deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}" + + updated = await UserWorkspaceFile.prisma().update( + where={"id": file_id}, + data={ + "isDeleted": True, + "deletedAt": deleted_at, + "path": deleted_path, + }, + ) + + logger.info(f"Soft-deleted workspace file {file_id}") + return updated + + +async def get_workspace_total_size(workspace_id: str) -> int: + """ + Get the total size of all files in a workspace. + + Args: + workspace_id: The workspace ID + + Returns: + Total size in bytes + """ + files = await list_workspace_files(workspace_id) + return sum(file.sizeBytes for file in files) diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/executor/database.py index ac381bbd67..ae7474fc1d 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/executor/database.py @@ -50,6 +50,8 @@ from backend.data.graph import ( validate_graph_execution_permissions, ) from backend.data.human_review import ( + cancel_pending_reviews_for_execution, + check_approval, get_or_create_human_review, has_pending_reviews_for_graph_exec, update_review_processed_status, @@ -190,6 +192,8 @@ class DatabaseManager(AppService): get_user_notification_preference = _(get_user_notification_preference) # Human In The Loop + cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution) + check_approval = _(check_approval) get_or_create_human_review = _(get_or_create_human_review) has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec) update_review_processed_status = _(update_review_processed_status) @@ -313,6 +317,8 @@ class DatabaseManagerAsyncClient(AppServiceClient): set_execution_kv_data = d.set_execution_kv_data # Human In The Loop + cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution + check_approval = d.check_approval get_or_create_human_review = d.get_or_create_human_review update_review_processed_status = d.update_review_processed_status diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 39d4f984eb..8362dae828 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -236,7 +236,14 @@ async def execute_node( input_size = len(input_data_str) log_metadata.debug("Executed node with input", input=input_data_str) + # Create node-specific execution context to avoid race conditions + # (multiple nodes can execute concurrently and would otherwise mutate shared state) + execution_context = execution_context.model_copy( + update={"node_id": node_id, "node_exec_id": node_exec_id} + ) + # Inject extra execution arguments for the blocks via kwargs + # Keep individual kwargs for backwards compatibility with existing blocks extra_exec_kwargs: dict = { "graph_id": graph_id, "graph_version": graph_version, diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index 7771c3751c..fa264c30a7 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError from backend.data import execution as execution_db from backend.data import graph as graph_db +from backend.data import human_review as human_review_db from backend.data import onboarding as onboarding_db from backend.data import user as user_db from backend.data.block import ( @@ -749,9 +750,27 @@ async def stop_graph_execution( if graph_exec.status in [ ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE, + ExecutionStatus.REVIEW, ]: - # If the graph is still on the queue, we can prevent them from being executed - # by setting the status to TERMINATED. + # If the graph is queued/incomplete/paused for review, terminate immediately + # No need to wait for executor since it's not actively running + + # If graph is in REVIEW status, clean up pending reviews before terminating + if graph_exec.status == ExecutionStatus.REVIEW: + # Use human_review_db if Prisma connected, else database manager + review_db = ( + human_review_db + if prisma.is_connected() + else get_database_manager_async_client() + ) + # Mark all pending reviews as rejected/cancelled + cancelled_count = await review_db.cancel_pending_reviews_for_execution( + graph_exec_id, user_id + ) + logger.info( + f"Cancelled {cancelled_count} pending review(s) for stopped execution {graph_exec_id}" + ) + graph_exec.status = ExecutionStatus.TERMINATED await asyncio.gather( @@ -873,11 +892,19 @@ async def add_graph_execution( settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id) execution_context = ExecutionContext( + # Execution identity + user_id=user_id, + graph_id=graph_id, + graph_exec_id=graph_exec.id, + graph_version=graph_exec.graph_version, + # Safety settings human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode, sensitive_action_safe_mode=settings.sensitive_action_safe_mode, + # User settings user_timezone=( user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC" ), + # Execution hierarchy root_execution_id=graph_exec.id, ) @@ -887,9 +914,28 @@ async def add_graph_execution( nodes_to_skip=nodes_to_skip, execution_context=execution_context, ) - logger.info(f"Publishing execution {graph_exec.id} to execution queue") + logger.info(f"Queueing execution {graph_exec.id}") + + # Update execution status to QUEUED BEFORE publishing to prevent race condition + # where two concurrent requests could both publish the same execution + updated_exec = await edb.update_graph_execution_stats( + graph_exec_id=graph_exec.id, + status=ExecutionStatus.QUEUED, + ) + + # Verify the status update succeeded (prevents duplicate queueing in race conditions) + # If another request already updated the status, this execution will not be QUEUED + if not updated_exec or updated_exec.status != ExecutionStatus.QUEUED: + logger.warning( + f"Skipping queue publish for execution {graph_exec.id} - " + f"status update failed or execution already queued by another request" + ) + return graph_exec + + graph_exec.status = ExecutionStatus.QUEUED # Publish to execution queue for executor to pick up + # This happens AFTER status update to ensure only one request publishes exec_queue = await get_async_execution_queue() await exec_queue.publish_message( routing_key=GRAPH_EXECUTION_ROUTING_KEY, @@ -897,13 +943,6 @@ async def add_graph_execution( exchange=GRAPH_EXECUTION_EXCHANGE, ) logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue") - - # Update execution status to QUEUED - graph_exec.status = ExecutionStatus.QUEUED - await edb.update_graph_execution_stats( - graph_exec_id=graph_exec.id, - status=graph_exec.status, - ) except BaseException as e: err = str(e) or type(e).__name__ if not graph_exec: diff --git a/autogpt_platform/backend/backend/executor/utils_test.py b/autogpt_platform/backend/backend/executor/utils_test.py index e6e8fcbf60..db33249583 100644 --- a/autogpt_platform/backend/backend/executor/utils_test.py +++ b/autogpt_platform/backend/backend/executor/utils_test.py @@ -4,6 +4,7 @@ import pytest from pytest_mock import MockerFixture from backend.data.dynamic_fields import merge_execution_input, parse_execution_output +from backend.data.execution import ExecutionStatus from backend.util.mock import MockObject @@ -346,6 +347,8 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture): mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes) mock_graph_exec.id = "execution-id-123" mock_graph_exec.node_executions = [] # Add this to avoid AttributeError + mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check + mock_graph_exec.graph_version = graph_version mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock() # Mock the queue and event bus @@ -432,6 +435,9 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture): # Create a second mock execution for the sanity check mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes) mock_graph_exec_2.id = "execution-id-456" + mock_graph_exec_2.node_executions = [] + mock_graph_exec_2.status = ExecutionStatus.QUEUED + mock_graph_exec_2.graph_version = graph_version mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock() # Reset mocks and set up for second call @@ -611,6 +617,8 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture): mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes) mock_graph_exec.id = "execution-id-123" mock_graph_exec.node_executions = [] + mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check + mock_graph_exec.graph_version = graph_version # Track what's passed to to_graph_execution_entry captured_kwargs = {} @@ -670,3 +678,232 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture): # Verify nodes_to_skip was passed to to_graph_execution_entry assert "nodes_to_skip" in captured_kwargs assert captured_kwargs["nodes_to_skip"] == nodes_to_skip + + +@pytest.mark.asyncio +async def test_stop_graph_execution_in_review_status_cancels_pending_reviews( + mocker: MockerFixture, +): + """Test that stopping an execution in REVIEW status cancels pending reviews.""" + from backend.data.execution import ExecutionStatus, GraphExecutionMeta + from backend.executor.utils import stop_graph_execution + + user_id = "test-user" + graph_exec_id = "test-exec-123" + + # Mock graph execution in REVIEW status + mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta) + mock_graph_exec.id = graph_exec_id + mock_graph_exec.status = ExecutionStatus.REVIEW + + # Mock dependencies + mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue") + mock_queue_client = mocker.AsyncMock() + mock_get_queue.return_value = mock_queue_client + + mock_prisma = mocker.patch("backend.executor.utils.prisma") + mock_prisma.is_connected.return_value = True + + mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db") + mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock( + return_value=2 # 2 reviews cancelled + ) + + mock_execution_db = mocker.patch("backend.executor.utils.execution_db") + mock_execution_db.get_graph_execution_meta = mocker.AsyncMock( + return_value=mock_graph_exec + ) + mock_execution_db.update_graph_execution_stats = mocker.AsyncMock() + + mock_get_event_bus = mocker.patch( + "backend.executor.utils.get_async_execution_event_bus" + ) + mock_event_bus = mocker.MagicMock() + mock_event_bus.publish = mocker.AsyncMock() + mock_get_event_bus.return_value = mock_event_bus + + mock_get_child_executions = mocker.patch( + "backend.executor.utils._get_child_executions" + ) + mock_get_child_executions.return_value = [] # No children + + # Call stop_graph_execution with timeout to allow status check + await stop_graph_execution( + user_id=user_id, + graph_exec_id=graph_exec_id, + wait_timeout=1.0, # Wait to allow status check + cascade=True, + ) + + # Verify pending reviews were cancelled + mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with( + graph_exec_id, user_id + ) + + # Verify execution status was updated to TERMINATED + mock_execution_db.update_graph_execution_stats.assert_called_once() + call_kwargs = mock_execution_db.update_graph_execution_stats.call_args[1] + assert call_kwargs["graph_exec_id"] == graph_exec_id + assert call_kwargs["status"] == ExecutionStatus.TERMINATED + + +@pytest.mark.asyncio +async def test_stop_graph_execution_with_database_manager_when_prisma_disconnected( + mocker: MockerFixture, +): + """Test that stop uses database manager when Prisma is not connected.""" + from backend.data.execution import ExecutionStatus, GraphExecutionMeta + from backend.executor.utils import stop_graph_execution + + user_id = "test-user" + graph_exec_id = "test-exec-456" + + # Mock graph execution in REVIEW status + mock_graph_exec = mocker.MagicMock(spec=GraphExecutionMeta) + mock_graph_exec.id = graph_exec_id + mock_graph_exec.status = ExecutionStatus.REVIEW + + # Mock dependencies + mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue") + mock_queue_client = mocker.AsyncMock() + mock_get_queue.return_value = mock_queue_client + + # Prisma is NOT connected + mock_prisma = mocker.patch("backend.executor.utils.prisma") + mock_prisma.is_connected.return_value = False + + # Mock database manager client + mock_get_db_manager = mocker.patch( + "backend.executor.utils.get_database_manager_async_client" + ) + mock_db_manager = mocker.AsyncMock() + mock_db_manager.get_graph_execution_meta = mocker.AsyncMock( + return_value=mock_graph_exec + ) + mock_db_manager.cancel_pending_reviews_for_execution = mocker.AsyncMock( + return_value=3 # 3 reviews cancelled + ) + mock_db_manager.update_graph_execution_stats = mocker.AsyncMock() + mock_get_db_manager.return_value = mock_db_manager + + mock_get_event_bus = mocker.patch( + "backend.executor.utils.get_async_execution_event_bus" + ) + mock_event_bus = mocker.MagicMock() + mock_event_bus.publish = mocker.AsyncMock() + mock_get_event_bus.return_value = mock_event_bus + + mock_get_child_executions = mocker.patch( + "backend.executor.utils._get_child_executions" + ) + mock_get_child_executions.return_value = [] # No children + + # Call stop_graph_execution with timeout + await stop_graph_execution( + user_id=user_id, + graph_exec_id=graph_exec_id, + wait_timeout=1.0, + cascade=True, + ) + + # Verify database manager was used for cancel_pending_reviews + mock_db_manager.cancel_pending_reviews_for_execution.assert_called_once_with( + graph_exec_id, user_id + ) + + # Verify execution status was updated via database manager + mock_db_manager.update_graph_execution_stats.assert_called_once() + + +@pytest.mark.asyncio +async def test_stop_graph_execution_cascades_to_child_with_reviews( + mocker: MockerFixture, +): + """Test that stopping parent execution cascades to children and cancels their reviews.""" + from backend.data.execution import ExecutionStatus, GraphExecutionMeta + from backend.executor.utils import stop_graph_execution + + user_id = "test-user" + parent_exec_id = "parent-exec" + child_exec_id = "child-exec" + + # Mock parent execution in RUNNING status + mock_parent_exec = mocker.MagicMock(spec=GraphExecutionMeta) + mock_parent_exec.id = parent_exec_id + mock_parent_exec.status = ExecutionStatus.RUNNING + + # Mock child execution in REVIEW status + mock_child_exec = mocker.MagicMock(spec=GraphExecutionMeta) + mock_child_exec.id = child_exec_id + mock_child_exec.status = ExecutionStatus.REVIEW + + # Mock dependencies + mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue") + mock_queue_client = mocker.AsyncMock() + mock_get_queue.return_value = mock_queue_client + + mock_prisma = mocker.patch("backend.executor.utils.prisma") + mock_prisma.is_connected.return_value = True + + mock_human_review_db = mocker.patch("backend.executor.utils.human_review_db") + mock_human_review_db.cancel_pending_reviews_for_execution = mocker.AsyncMock( + return_value=1 # 1 child review cancelled + ) + + # Mock execution_db to return different status based on which execution is queried + mock_execution_db = mocker.patch("backend.executor.utils.execution_db") + + # Track call count to simulate status transition + call_count = {"count": 0} + + async def get_exec_meta_side_effect(execution_id, user_id): + call_count["count"] += 1 + if execution_id == parent_exec_id: + # After a few calls (child processing happens), transition parent to TERMINATED + # This simulates the executor service processing the stop request + if call_count["count"] > 3: + mock_parent_exec.status = ExecutionStatus.TERMINATED + return mock_parent_exec + elif execution_id == child_exec_id: + return mock_child_exec + return None + + mock_execution_db.get_graph_execution_meta = mocker.AsyncMock( + side_effect=get_exec_meta_side_effect + ) + mock_execution_db.update_graph_execution_stats = mocker.AsyncMock() + + mock_get_event_bus = mocker.patch( + "backend.executor.utils.get_async_execution_event_bus" + ) + mock_event_bus = mocker.MagicMock() + mock_event_bus.publish = mocker.AsyncMock() + mock_get_event_bus.return_value = mock_event_bus + + # Mock _get_child_executions to return the child + mock_get_child_executions = mocker.patch( + "backend.executor.utils._get_child_executions" + ) + + def get_children_side_effect(parent_id): + if parent_id == parent_exec_id: + return [mock_child_exec] + return [] + + mock_get_child_executions.side_effect = get_children_side_effect + + # Call stop_graph_execution on parent with cascade=True + await stop_graph_execution( + user_id=user_id, + graph_exec_id=parent_exec_id, + wait_timeout=1.0, + cascade=True, + ) + + # Verify child reviews were cancelled + mock_human_review_db.cancel_pending_reviews_for_execution.assert_called_once_with( + child_exec_id, user_id + ) + + # Verify both parent and child status updates + assert mock_execution_db.update_graph_execution_stats.call_count >= 1 diff --git a/autogpt_platform/backend/backend/util/cloud_storage.py b/autogpt_platform/backend/backend/util/cloud_storage.py index 93fb9039ec..28423d003d 100644 --- a/autogpt_platform/backend/backend/util/cloud_storage.py +++ b/autogpt_platform/backend/backend/util/cloud_storage.py @@ -13,6 +13,7 @@ import aiohttp from gcloud.aio import storage as async_gcs_storage from google.cloud import storage as gcs_storage +from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url from backend.util.settings import Config logger = logging.getLogger(__name__) @@ -251,7 +252,7 @@ class CloudStorageHandler: f"in_task: {current_task is not None}" ) - # Parse bucket and blob name from path + # Parse bucket and blob name from path (path already has gcs:// prefix removed) parts = path.split("/", 1) if len(parts) != 2: raise ValueError(f"Invalid GCS path: {path}") @@ -261,50 +262,19 @@ class CloudStorageHandler: # Authorization check self._validate_file_access(blob_name, user_id, graph_exec_id) - # Use a fresh client for each download to avoid session issues - # This is less efficient but more reliable with the executor's event loop - logger.info("[CloudStorage] Creating fresh GCS client for download") - - # Create a new session specifically for this download - session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(limit=10, force_close=True) + logger.info( + f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}" ) - async_client = None try: - # Create a new GCS client with the fresh session - async_client = async_gcs_storage.Storage(session=session) - - logger.info( - f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}" - ) - - # Download content using the fresh client - content = await async_client.download(bucket_name, blob_name) + content = await download_with_fresh_session(bucket_name, blob_name) logger.info( f"[CloudStorage] GCS download successful - size: {len(content)} bytes" ) - - # Clean up - await async_client.close() - await session.close() - return content - + except FileNotFoundError: + raise except Exception as e: - # Always try to clean up - if async_client is not None: - try: - await async_client.close() - except Exception as cleanup_error: - logger.warning( - f"[CloudStorage] Error closing GCS client: {cleanup_error}" - ) - try: - await session.close() - except Exception as cleanup_error: - logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}") - # Log the specific error for debugging logger.error( f"[CloudStorage] GCS download failed - error: {str(e)}, " @@ -319,10 +289,6 @@ class CloudStorageHandler: f"current_task: {current_task}, " f"bucket: {bucket_name}, blob: redacted for privacy" ) - - # Convert gcloud-aio exceptions to standard ones - if "404" in str(e) or "Not Found" in str(e): - raise FileNotFoundError(f"File not found: gcs://{path}") raise def _validate_file_access( @@ -445,8 +411,7 @@ class CloudStorageHandler: graph_exec_id: str | None = None, ) -> str: """Generate signed URL for GCS with authorization.""" - - # Parse bucket and blob name from path + # Parse bucket and blob name from path (path already has gcs:// prefix removed) parts = path.split("/", 1) if len(parts) != 2: raise ValueError(f"Invalid GCS path: {path}") @@ -456,21 +421,11 @@ class CloudStorageHandler: # Authorization check self._validate_file_access(blob_name, user_id, graph_exec_id) - # Use sync client for signed URLs since gcloud-aio doesn't support them sync_client = self._get_sync_gcs_client() - bucket = sync_client.bucket(bucket_name) - blob = bucket.blob(blob_name) - - # Generate signed URL asynchronously using sync client - url = await asyncio.to_thread( - blob.generate_signed_url, - version="v4", - expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours), - method="GET", + return await generate_signed_url( + sync_client, bucket_name, blob_name, expiration_hours * 3600 ) - return url - async def delete_expired_files(self, provider: str = "gcs") -> int: """ Delete files that have passed their expiration time. diff --git a/autogpt_platform/backend/backend/util/exceptions.py b/autogpt_platform/backend/backend/util/exceptions.py index 6d0192c0e5..ffda783873 100644 --- a/autogpt_platform/backend/backend/util/exceptions.py +++ b/autogpt_platform/backend/backend/util/exceptions.py @@ -135,6 +135,12 @@ class GraphValidationError(ValueError): ) +class InvalidInputError(ValueError): + """Raised when user input validation fails (e.g., search term too long)""" + + pass + + class DatabaseError(Exception): """Raised when there is an error interacting with the database""" diff --git a/autogpt_platform/backend/backend/util/file.py b/autogpt_platform/backend/backend/util/file.py index dc8f86ea41..baa9225629 100644 --- a/autogpt_platform/backend/backend/util/file.py +++ b/autogpt_platform/backend/backend/util/file.py @@ -5,13 +5,26 @@ import shutil import tempfile import uuid from pathlib import Path +from typing import TYPE_CHECKING, Literal from urllib.parse import urlparse from backend.util.cloud_storage import get_cloud_storage_handler from backend.util.request import Requests +from backend.util.settings import Config from backend.util.type import MediaFileType from backend.util.virus_scanner import scan_content_safe +if TYPE_CHECKING: + from backend.data.execution import ExecutionContext + +# Return format options for store_media_file +# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc. +# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs +# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs +MediaReturnFormat = Literal[ + "for_local_processing", "for_external_api", "for_block_output" +] + TEMP_DIR = Path(tempfile.gettempdir()).resolve() # Maximum filename length (conservative limit for most filesystems) @@ -67,42 +80,56 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None: async def store_media_file( - graph_exec_id: str, file: MediaFileType, - user_id: str, - return_content: bool = False, + execution_context: "ExecutionContext", + *, + return_format: MediaReturnFormat, ) -> MediaFileType: """ - Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}), - placing or verifying it under: + Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path + relative to {temp}/exec_file/{exec_id}), placing or verifying it under: {tempdir}/exec_file/{exec_id}/... - If 'return_content=True', return a data URI (data:;base64,). - Otherwise, returns the file media path relative to the exec_id folder. + For each MediaFileType input: + - Data URI: decode and store locally + - URL: download and store locally + - workspace:// reference: read from workspace, store locally + - Local path: verify it exists in exec_file directory - For each MediaFileType type: - - Data URI: - -> decode and store in a new random file in that folder - - URL: - -> download and store in that folder - - Local path: - -> interpret as relative to that folder; verify it exists - (no copying, as it's presumably already there). - We realpath-check so no symlink or '..' can escape the folder. + Return format options: + - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc. + - "for_external_api": Returns data URI (base64) - use when sending to external APIs + - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs - - :param graph_exec_id: The unique ID of the graph execution. - :param file: Data URI, URL, or local (relative) path. - :param return_content: If True, return a data URI of the file content. - If False, return the *relative* path inside the exec_id folder. - :return: The requested result: data URI or relative path of the media. + :param file: Data URI, URL, workspace://, or local (relative) path. + :param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id. + :param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output". + :return: The requested result based on return_format. """ + # Extract values from execution_context + graph_exec_id = execution_context.graph_exec_id + user_id = execution_context.user_id + + if not graph_exec_id: + raise ValueError("execution_context.graph_exec_id is required") + if not user_id: + raise ValueError("execution_context.user_id is required") + + # Create workspace_manager if we have workspace_id (with session scoping) + # Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py) + from backend.util.workspace import WorkspaceManager + + workspace_manager: WorkspaceManager | None = None + if execution_context.workspace_id: + workspace_manager = WorkspaceManager( + user_id, execution_context.workspace_id, execution_context.session_id + ) # Build base path base_path = Path(get_exec_file_path(graph_exec_id, "")) base_path.mkdir(parents=True, exist_ok=True) # Security fix: Add disk space limits to prevent DoS - MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB per file + MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024 MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory # Check total disk usage in base_path @@ -142,9 +169,57 @@ async def store_media_file( """ return str(absolute_path.relative_to(base)) - # Check if this is a cloud storage path + # Get cloud storage handler for checking cloud paths cloud_storage = await get_cloud_storage_handler() - if cloud_storage.is_cloud_path(file): + + # Track if the input came from workspace (don't re-save it) + is_from_workspace = file.startswith("workspace://") + + # Check if this is a workspace file reference + if is_from_workspace: + if workspace_manager is None: + raise ValueError( + "Workspace file reference requires workspace context. " + "This file type is only available in CoPilot sessions." + ) + + # Parse workspace reference + # workspace://abc123 - by file ID + # workspace:///path/to/file.txt - by virtual path + file_ref = file[12:] # Remove "workspace://" + + if file_ref.startswith("/"): + # Path reference + workspace_content = await workspace_manager.read_file(file_ref) + file_info = await workspace_manager.get_file_info_by_path(file_ref) + filename = sanitize_filename( + file_info.name if file_info else f"{uuid.uuid4()}.bin" + ) + else: + # ID reference + workspace_content = await workspace_manager.read_file_by_id(file_ref) + file_info = await workspace_manager.get_file_info(file_ref) + filename = sanitize_filename( + file_info.name if file_info else f"{uuid.uuid4()}.bin" + ) + + try: + target_path = _ensure_inside_base(base_path / filename, base_path) + except OSError as e: + raise ValueError(f"Invalid file path '{filename}': {e}") from e + + # Check file size limit + if len(workspace_content) > MAX_FILE_SIZE_BYTES: + raise ValueError( + f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes" + ) + + # Virus scan the workspace content before writing locally + await scan_content_safe(workspace_content, filename=filename) + target_path.write_bytes(workspace_content) + + # Check if this is a cloud storage path + elif cloud_storage.is_cloud_path(file): # Download from cloud storage and store locally cloud_content = await cloud_storage.retrieve_file( file, user_id=user_id, graph_exec_id=graph_exec_id @@ -159,9 +234,9 @@ async def store_media_file( raise ValueError(f"Invalid file path '{filename}': {e}") from e # Check file size limit - if len(cloud_content) > MAX_FILE_SIZE: + if len(cloud_content) > MAX_FILE_SIZE_BYTES: raise ValueError( - f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes" + f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes" ) # Virus scan the cloud content before writing locally @@ -189,9 +264,9 @@ async def store_media_file( content = base64.b64decode(b64_content) # Check file size limit - if len(content) > MAX_FILE_SIZE: + if len(content) > MAX_FILE_SIZE_BYTES: raise ValueError( - f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes" + f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes" ) # Virus scan the base64 content before writing @@ -199,23 +274,31 @@ async def store_media_file( target_path.write_bytes(content) elif file.startswith(("http://", "https://")): - # URL + # URL - download first to get Content-Type header + resp = await Requests().get(file) + + # Check file size limit + if len(resp.content) > MAX_FILE_SIZE_BYTES: + raise ValueError( + f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes" + ) + + # Extract filename from URL path parsed_url = urlparse(file) filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}") + + # If filename lacks extension, add one from Content-Type header + if "." not in filename: + content_type = resp.headers.get("Content-Type", "").split(";")[0].strip() + if content_type: + ext = _extension_from_mime(content_type) + filename = f"{filename}{ext}" + try: target_path = _ensure_inside_base(base_path / filename, base_path) except OSError as e: raise ValueError(f"Invalid file path '{filename}': {e}") from e - # Download and save - resp = await Requests().get(file) - - # Check file size limit - if len(resp.content) > MAX_FILE_SIZE: - raise ValueError( - f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes" - ) - # Virus scan the downloaded content before writing await scan_content_safe(resp.content, filename=filename) target_path.write_bytes(resp.content) @@ -230,12 +313,44 @@ async def store_media_file( if not target_path.is_file(): raise ValueError(f"Local file does not exist: {target_path}") - # Return result - if return_content: - return MediaFileType(_file_to_data_uri(target_path)) - else: + # Return based on requested format + if return_format == "for_local_processing": + # Use when processing files locally with tools like ffmpeg, MoviePy, PIL + # Returns: relative path in exec_file directory (e.g., "image.png") return MediaFileType(_strip_base_prefix(target_path, base_path)) + elif return_format == "for_external_api": + # Use when sending content to external APIs that need base64 + # Returns: data URI (e.g., "data:image/png;base64,iVBORw0...") + return MediaFileType(_file_to_data_uri(target_path)) + + elif return_format == "for_block_output": + # Use when returning output from a block to user/next block + # Returns: workspace:// ref (CoPilot) or data URI (graph execution) + if workspace_manager is None: + # No workspace available (graph execution without CoPilot) + # Fallback to data URI so the content can still be used/displayed + return MediaFileType(_file_to_data_uri(target_path)) + + # Don't re-save if input was already from workspace + if is_from_workspace: + # Return original workspace reference + return MediaFileType(file) + + # Save new content to workspace + content = target_path.read_bytes() + filename = target_path.name + + file_record = await workspace_manager.write_file( + content=content, + filename=filename, + overwrite=True, + ) + return MediaFileType(f"workspace://{file_record.id}") + + else: + raise ValueError(f"Invalid return_format: {return_format}") + def get_dir_size(path: Path) -> int: """Get total size of directory.""" diff --git a/autogpt_platform/backend/backend/util/file_test.py b/autogpt_platform/backend/backend/util/file_test.py index cd4fc69706..9fe672d155 100644 --- a/autogpt_platform/backend/backend/util/file_test.py +++ b/autogpt_platform/backend/backend/util/file_test.py @@ -7,10 +7,22 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from backend.data.execution import ExecutionContext from backend.util.file import store_media_file from backend.util.type import MediaFileType +def make_test_context( + graph_exec_id: str = "test-exec-123", + user_id: str = "test-user-123", +) -> ExecutionContext: + """Helper to create test ExecutionContext.""" + return ExecutionContext( + user_id=user_id, + graph_exec_id=graph_exec_id, + ) + + class TestFileCloudIntegration: """Test cases for cloud storage integration in file utilities.""" @@ -70,10 +82,9 @@ class TestFileCloudIntegration: mock_path_class.side_effect = path_constructor result = await store_media_file( - graph_exec_id, - MediaFileType(cloud_path), - "test-user-123", - return_content=False, + file=MediaFileType(cloud_path), + execution_context=make_test_context(graph_exec_id=graph_exec_id), + return_format="for_local_processing", ) # Verify cloud storage operations @@ -144,10 +155,9 @@ class TestFileCloudIntegration: mock_path_obj.name = "image.png" with patch("backend.util.file.Path", return_value=mock_path_obj): result = await store_media_file( - graph_exec_id, - MediaFileType(cloud_path), - "test-user-123", - return_content=True, + file=MediaFileType(cloud_path), + execution_context=make_test_context(graph_exec_id=graph_exec_id), + return_format="for_external_api", ) # Verify result is a data URI @@ -198,10 +208,9 @@ class TestFileCloudIntegration: mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt") await store_media_file( - graph_exec_id, - MediaFileType(data_uri), - "test-user-123", - return_content=False, + file=MediaFileType(data_uri), + execution_context=make_test_context(graph_exec_id=graph_exec_id), + return_format="for_local_processing", ) # Verify cloud handler was checked but not used for retrieval @@ -234,5 +243,7 @@ class TestFileCloudIntegration: FileNotFoundError, match="File not found in cloud storage" ): await store_media_file( - graph_exec_id, MediaFileType(cloud_path), "test-user-123" + file=MediaFileType(cloud_path), + execution_context=make_test_context(graph_exec_id=graph_exec_id), + return_format="for_local_processing", ) diff --git a/autogpt_platform/backend/backend/util/gcs_utils.py b/autogpt_platform/backend/backend/util/gcs_utils.py new file mode 100644 index 0000000000..3f91f21897 --- /dev/null +++ b/autogpt_platform/backend/backend/util/gcs_utils.py @@ -0,0 +1,108 @@ +""" +Shared GCS utilities for workspace and cloud storage backends. + +This module provides common functionality for working with Google Cloud Storage, +including path parsing, client management, and signed URL generation. +""" + +import asyncio +import logging +from datetime import datetime, timedelta, timezone + +import aiohttp +from gcloud.aio import storage as async_gcs_storage +from google.cloud import storage as gcs_storage + +logger = logging.getLogger(__name__) + + +def parse_gcs_path(path: str) -> tuple[str, str]: + """ + Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob). + + Args: + path: GCS path string (e.g., "gcs://my-bucket/path/to/file") + + Returns: + Tuple of (bucket_name, blob_name) + + Raises: + ValueError: If the path format is invalid + """ + if not path.startswith("gcs://"): + raise ValueError(f"Invalid GCS path: {path}") + + path_without_prefix = path[6:] # Remove "gcs://" + parts = path_without_prefix.split("/", 1) + if len(parts) != 2: + raise ValueError(f"Invalid GCS path format: {path}") + + return parts[0], parts[1] + + +async def download_with_fresh_session(bucket: str, blob: str) -> bytes: + """ + Download file content using a fresh session. + + This approach avoids event loop issues that can occur when reusing + sessions across different async contexts (e.g., in executors). + + Args: + bucket: GCS bucket name + blob: Blob path within the bucket + + Returns: + File content as bytes + + Raises: + FileNotFoundError: If the file doesn't exist + """ + session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=10, force_close=True) + ) + client: async_gcs_storage.Storage | None = None + try: + client = async_gcs_storage.Storage(session=session) + content = await client.download(bucket, blob) + return content + except Exception as e: + if "404" in str(e) or "Not Found" in str(e): + raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}") + raise + finally: + if client: + try: + await client.close() + except Exception: + pass # Best-effort cleanup + await session.close() + + +async def generate_signed_url( + sync_client: gcs_storage.Client, + bucket_name: str, + blob_name: str, + expires_in: int, +) -> str: + """ + Generate a signed URL for temporary access to a GCS file. + + Uses asyncio.to_thread() to run the sync operation without blocking. + + Args: + sync_client: Sync GCS client with service account credentials + bucket_name: GCS bucket name + blob_name: Blob path within the bucket + expires_in: URL expiration time in seconds + + Returns: + Signed URL string + """ + bucket = sync_client.bucket(bucket_name) + blob = bucket.blob(blob_name) + return await asyncio.to_thread( + blob.generate_signed_url, + version="v4", + expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in), + method="GET", + ) diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index d3c3d041d0..50b7428160 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -263,6 +263,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): description="The name of the Google Cloud Storage bucket for media files", ) + workspace_storage_dir: str = Field( + default="", + description="Local directory for workspace file storage when GCS is not configured. " + "If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.", + ) + reddit_user_agent: str = Field( default="web:AutoGPT:v0.6.0 (by /u/autogpt)", description="The user agent for the Reddit API", @@ -350,6 +356,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): description="Whether to mark failed scans as clean or not", ) + agentgenerator_host: str = Field( + default="", + description="The host for the Agent Generator service (empty to use built-in)", + ) + agentgenerator_port: int = Field( + default=8000, + description="The port for the Agent Generator service", + ) + agentgenerator_timeout: int = Field( + default=600, + description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)", + ) + enable_example_blocks: bool = Field( default=False, description="Whether to enable example blocks in production", @@ -376,6 +395,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): description="Maximum file size in MB for file uploads (1-1024 MB)", ) + max_file_size_mb: int = Field( + default=100, + ge=1, + le=1024, + description="Maximum file size in MB for workspace files (1-1024 MB)", + ) + # AutoMod configuration automod_enabled: bool = Field( default=False, @@ -667,6 +693,12 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings): default="https://cloud.langfuse.com", description="Langfuse host URL" ) + # PostHog analytics + posthog_api_key: str = Field(default="", description="PostHog API key") + posthog_host: str = Field( + default="https://eu.i.posthog.com", description="PostHog host URL" + ) + # Add more secret fields as needed model_config = SettingsConfigDict( env_file=".env", diff --git a/autogpt_platform/backend/backend/util/test.py b/autogpt_platform/backend/backend/util/test.py index 1e8244ff8e..23d7c24147 100644 --- a/autogpt_platform/backend/backend/util/test.py +++ b/autogpt_platform/backend/backend/util/test.py @@ -1,3 +1,4 @@ +import asyncio import inspect import logging import time @@ -58,6 +59,11 @@ class SpinTestServer: self.db_api.__exit__(exc_type, exc_val, exc_tb) self.notif_manager.__exit__(exc_type, exc_val, exc_tb) + # Give services time to fully shut down + # This prevents event loop issues where services haven't fully cleaned up + # before the next test starts + await asyncio.sleep(0.5) + def setup_dependency_overrides(self): # Override get_user_id for testing self.agent_server.set_test_dependency_overrides( @@ -134,14 +140,29 @@ async def execute_block_test(block: Block): setattr(block, mock_name, mock_obj) # Populate credentials argument(s) + # Generate IDs for execution context + graph_id = str(uuid.uuid4()) + node_id = str(uuid.uuid4()) + graph_exec_id = str(uuid.uuid4()) + node_exec_id = str(uuid.uuid4()) + user_id = str(uuid.uuid4()) + graph_version = 1 # Default version for tests + extra_exec_kwargs: dict = { - "graph_id": str(uuid.uuid4()), - "node_id": str(uuid.uuid4()), - "graph_exec_id": str(uuid.uuid4()), - "node_exec_id": str(uuid.uuid4()), - "user_id": str(uuid.uuid4()), - "graph_version": 1, # Default version for tests - "execution_context": ExecutionContext(), + "graph_id": graph_id, + "node_id": node_id, + "graph_exec_id": graph_exec_id, + "node_exec_id": node_exec_id, + "user_id": user_id, + "graph_version": graph_version, + "execution_context": ExecutionContext( + user_id=user_id, + graph_id=graph_id, + graph_exec_id=graph_exec_id, + graph_version=graph_version, + node_id=node_id, + node_exec_id=node_exec_id, + ), } input_model = cast(type[BlockSchema], block.input_schema) diff --git a/autogpt_platform/backend/backend/util/workspace.py b/autogpt_platform/backend/backend/util/workspace.py new file mode 100644 index 0000000000..a2f1a61b9e --- /dev/null +++ b/autogpt_platform/backend/backend/util/workspace.py @@ -0,0 +1,419 @@ +""" +WorkspaceManager for managing user workspace file operations. + +This module provides a high-level interface for workspace file operations, +combining the storage backend and database layer. +""" + +import logging +import mimetypes +import uuid +from typing import Optional + +from prisma.errors import UniqueViolationError +from prisma.models import UserWorkspaceFile + +from backend.data.workspace import ( + count_workspace_files, + create_workspace_file, + get_workspace_file, + get_workspace_file_by_path, + list_workspace_files, + soft_delete_workspace_file, +) +from backend.util.settings import Config +from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage + +logger = logging.getLogger(__name__) + + +class WorkspaceManager: + """ + Manages workspace file operations. + + Combines storage backend operations with database record management. + Supports session-scoped file segmentation where files are stored in + session-specific virtual paths: /sessions/{session_id}/{filename} + """ + + def __init__( + self, user_id: str, workspace_id: str, session_id: Optional[str] = None + ): + """ + Initialize WorkspaceManager. + + Args: + user_id: The user's ID + workspace_id: The workspace ID + session_id: Optional session ID for session-scoped file access + """ + self.user_id = user_id + self.workspace_id = workspace_id + self.session_id = session_id + # Session path prefix for file isolation + self.session_path = f"/sessions/{session_id}" if session_id else "" + + def _resolve_path(self, path: str) -> str: + """ + Resolve a path, defaulting to session folder if session_id is set. + + Cross-session access is allowed by explicitly using /sessions/other-session-id/... + + Args: + path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt") + + Returns: + Resolved path with session prefix if applicable + """ + # If path explicitly references a session folder, use it as-is + if path.startswith("/sessions/"): + return path + + # If we have a session context, prepend session path + if self.session_path: + # Normalize the path + if not path.startswith("/"): + path = f"/{path}" + return f"{self.session_path}{path}" + + # No session context, use path as-is + return path if path.startswith("/") else f"/{path}" + + def _get_effective_path( + self, path: Optional[str], include_all_sessions: bool + ) -> Optional[str]: + """ + Get effective path for list/count operations based on session context. + + Args: + path: Optional path prefix to filter + include_all_sessions: If True, don't apply session scoping + + Returns: + Effective path prefix for database query + """ + if include_all_sessions: + # Normalize path to ensure leading slash (stored paths are normalized) + if path is not None and not path.startswith("/"): + return f"/{path}" + return path + elif path is not None: + # Resolve the provided path with session scoping + return self._resolve_path(path) + elif self.session_path: + # Default to session folder with trailing slash to prevent prefix collisions + # e.g., "/sessions/abc" should not match "/sessions/abc123" + return self.session_path.rstrip("/") + "/" + else: + # No session context, use path as-is + return path + + async def read_file(self, path: str) -> bytes: + """ + Read file from workspace by virtual path. + + When session_id is set, paths are resolved relative to the session folder + unless they explicitly reference /sessions/... + + Args: + path: Virtual path (e.g., "/documents/report.pdf") + + Returns: + File content as bytes + + Raises: + FileNotFoundError: If file doesn't exist + """ + resolved_path = self._resolve_path(path) + file = await get_workspace_file_by_path(self.workspace_id, resolved_path) + if file is None: + raise FileNotFoundError(f"File not found at path: {resolved_path}") + + storage = await get_workspace_storage() + return await storage.retrieve(file.storagePath) + + async def read_file_by_id(self, file_id: str) -> bytes: + """ + Read file from workspace by file ID. + + Args: + file_id: The file's ID + + Returns: + File content as bytes + + Raises: + FileNotFoundError: If file doesn't exist + """ + file = await get_workspace_file(file_id, self.workspace_id) + if file is None: + raise FileNotFoundError(f"File not found: {file_id}") + + storage = await get_workspace_storage() + return await storage.retrieve(file.storagePath) + + async def write_file( + self, + content: bytes, + filename: str, + path: Optional[str] = None, + mime_type: Optional[str] = None, + overwrite: bool = False, + ) -> UserWorkspaceFile: + """ + Write file to workspace. + + When session_id is set, files are written to /sessions/{session_id}/... + by default. Use explicit /sessions/... paths for cross-session access. + + Args: + content: File content as bytes + filename: Filename for the file + path: Virtual path (defaults to "/{filename}", session-scoped if session_id set) + mime_type: MIME type (auto-detected if not provided) + overwrite: Whether to overwrite existing file at path + + Returns: + Created UserWorkspaceFile instance + + Raises: + ValueError: If file exceeds size limit or path already exists + """ + # Enforce file size limit + max_file_size = Config().max_file_size_mb * 1024 * 1024 + if len(content) > max_file_size: + raise ValueError( + f"File too large: {len(content)} bytes exceeds " + f"{Config().max_file_size_mb}MB limit" + ) + + # Determine path with session scoping + if path is None: + path = f"/{filename}" + elif not path.startswith("/"): + path = f"/{path}" + + # Resolve path with session prefix + path = self._resolve_path(path) + + # Check if file exists at path (only error for non-overwrite case) + # For overwrite=True, we let the write proceed and handle via UniqueViolationError + # This ensures the new file is written to storage BEFORE the old one is deleted, + # preventing data loss if the new write fails + if not overwrite: + existing = await get_workspace_file_by_path(self.workspace_id, path) + if existing is not None: + raise ValueError(f"File already exists at path: {path}") + + # Auto-detect MIME type if not provided + if mime_type is None: + mime_type, _ = mimetypes.guess_type(filename) + mime_type = mime_type or "application/octet-stream" + + # Compute checksum + checksum = compute_file_checksum(content) + + # Generate unique file ID for storage + file_id = str(uuid.uuid4()) + + # Store file in storage backend + storage = await get_workspace_storage() + storage_path = await storage.store( + workspace_id=self.workspace_id, + file_id=file_id, + filename=filename, + content=content, + ) + + # Create database record - handle race condition where another request + # created a file at the same path between our check and create + try: + file = await create_workspace_file( + workspace_id=self.workspace_id, + file_id=file_id, + name=filename, + path=path, + storage_path=storage_path, + mime_type=mime_type, + size_bytes=len(content), + checksum=checksum, + ) + except UniqueViolationError: + # Race condition: another request created a file at this path + if overwrite: + # Re-fetch and delete the conflicting file, then retry + existing = await get_workspace_file_by_path(self.workspace_id, path) + if existing: + await self.delete_file(existing.id) + # Retry the create - if this also fails, clean up storage file + try: + file = await create_workspace_file( + workspace_id=self.workspace_id, + file_id=file_id, + name=filename, + path=path, + storage_path=storage_path, + mime_type=mime_type, + size_bytes=len(content), + checksum=checksum, + ) + except Exception: + # Clean up orphaned storage file on retry failure + try: + await storage.delete(storage_path) + except Exception as e: + logger.warning(f"Failed to clean up orphaned storage file: {e}") + raise + else: + # Clean up the orphaned storage file before raising + try: + await storage.delete(storage_path) + except Exception as e: + logger.warning(f"Failed to clean up orphaned storage file: {e}") + raise ValueError(f"File already exists at path: {path}") + except Exception: + # Any other database error (connection, validation, etc.) - clean up storage + try: + await storage.delete(storage_path) + except Exception as e: + logger.warning(f"Failed to clean up orphaned storage file: {e}") + raise + + logger.info( + f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} " + f"at path {path}, size={len(content)} bytes" + ) + + return file + + async def list_files( + self, + path: Optional[str] = None, + limit: Optional[int] = None, + offset: int = 0, + include_all_sessions: bool = False, + ) -> list[UserWorkspaceFile]: + """ + List files in workspace. + + When session_id is set and include_all_sessions is False (default), + only files in the current session's folder are listed. + + Args: + path: Optional path prefix to filter (e.g., "/documents/") + limit: Maximum number of files to return + offset: Number of files to skip + include_all_sessions: If True, list files from all sessions. + If False (default), only list current session's files. + + Returns: + List of UserWorkspaceFile instances + """ + effective_path = self._get_effective_path(path, include_all_sessions) + + return await list_workspace_files( + workspace_id=self.workspace_id, + path_prefix=effective_path, + limit=limit, + offset=offset, + ) + + async def delete_file(self, file_id: str) -> bool: + """ + Delete a file (soft-delete). + + Args: + file_id: The file's ID + + Returns: + True if deleted, False if not found + """ + file = await get_workspace_file(file_id, self.workspace_id) + if file is None: + return False + + # Delete from storage + storage = await get_workspace_storage() + try: + await storage.delete(file.storagePath) + except Exception as e: + logger.warning(f"Failed to delete file from storage: {e}") + # Continue with database soft-delete even if storage delete fails + + # Soft-delete database record + result = await soft_delete_workspace_file(file_id, self.workspace_id) + return result is not None + + async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str: + """ + Get download URL for a file. + + Args: + file_id: The file's ID + expires_in: URL expiration in seconds (default 1 hour) + + Returns: + Download URL (signed URL for GCS, API endpoint for local) + + Raises: + FileNotFoundError: If file doesn't exist + """ + file = await get_workspace_file(file_id, self.workspace_id) + if file is None: + raise FileNotFoundError(f"File not found: {file_id}") + + storage = await get_workspace_storage() + return await storage.get_download_url(file.storagePath, expires_in) + + async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]: + """ + Get file metadata. + + Args: + file_id: The file's ID + + Returns: + UserWorkspaceFile instance or None + """ + return await get_workspace_file(file_id, self.workspace_id) + + async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]: + """ + Get file metadata by path. + + When session_id is set, paths are resolved relative to the session folder + unless they explicitly reference /sessions/... + + Args: + path: Virtual path + + Returns: + UserWorkspaceFile instance or None + """ + resolved_path = self._resolve_path(path) + return await get_workspace_file_by_path(self.workspace_id, resolved_path) + + async def get_file_count( + self, + path: Optional[str] = None, + include_all_sessions: bool = False, + ) -> int: + """ + Get number of files in workspace. + + When session_id is set and include_all_sessions is False (default), + only counts files in the current session's folder. + + Args: + path: Optional path prefix to filter (e.g., "/documents/") + include_all_sessions: If True, count all files in workspace. + If False (default), only count current session's files. + + Returns: + Number of files + """ + effective_path = self._get_effective_path(path, include_all_sessions) + + return await count_workspace_files( + self.workspace_id, path_prefix=effective_path + ) diff --git a/autogpt_platform/backend/backend/util/workspace_storage.py b/autogpt_platform/backend/backend/util/workspace_storage.py new file mode 100644 index 0000000000..2f4c8ae2b5 --- /dev/null +++ b/autogpt_platform/backend/backend/util/workspace_storage.py @@ -0,0 +1,398 @@ +""" +Workspace storage backend abstraction for supporting both cloud and local deployments. + +This module provides a unified interface for storing workspace files, with implementations +for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments). +""" + +import asyncio +import hashlib +import logging +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +import aiofiles +import aiohttp +from gcloud.aio import storage as async_gcs_storage +from google.cloud import storage as gcs_storage + +from backend.util.data import get_data_path +from backend.util.gcs_utils import ( + download_with_fresh_session, + generate_signed_url, + parse_gcs_path, +) +from backend.util.settings import Config + +logger = logging.getLogger(__name__) + + +class WorkspaceStorageBackend(ABC): + """Abstract interface for workspace file storage.""" + + @abstractmethod + async def store( + self, + workspace_id: str, + file_id: str, + filename: str, + content: bytes, + ) -> str: + """ + Store file content, return storage path. + + Args: + workspace_id: The workspace ID + file_id: Unique file ID for storage + filename: Original filename + content: File content as bytes + + Returns: + Storage path string (cloud path or local path) + """ + pass + + @abstractmethod + async def retrieve(self, storage_path: str) -> bytes: + """ + Retrieve file content from storage. + + Args: + storage_path: The storage path returned from store() + + Returns: + File content as bytes + """ + pass + + @abstractmethod + async def delete(self, storage_path: str) -> None: + """ + Delete file from storage. + + Args: + storage_path: The storage path to delete + """ + pass + + @abstractmethod + async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str: + """ + Get URL for downloading the file. + + Args: + storage_path: The storage path + expires_in: URL expiration time in seconds (default 1 hour) + + Returns: + Download URL (signed URL for GCS, direct API path for local) + """ + pass + + +class GCSWorkspaceStorage(WorkspaceStorageBackend): + """Google Cloud Storage implementation for workspace storage.""" + + def __init__(self, bucket_name: str): + self.bucket_name = bucket_name + self._async_client: Optional[async_gcs_storage.Storage] = None + self._sync_client: Optional[gcs_storage.Client] = None + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_async_client(self) -> async_gcs_storage.Storage: + """Get or create async GCS client.""" + if self._async_client is None: + self._session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=100, force_close=False) + ) + self._async_client = async_gcs_storage.Storage(session=self._session) + return self._async_client + + def _get_sync_client(self) -> gcs_storage.Client: + """Get or create sync GCS client (for signed URLs).""" + if self._sync_client is None: + self._sync_client = gcs_storage.Client() + return self._sync_client + + async def close(self) -> None: + """Close all client connections.""" + if self._async_client is not None: + try: + await self._async_client.close() + except Exception as e: + logger.warning(f"Error closing GCS client: {e}") + self._async_client = None + + if self._session is not None: + try: + await self._session.close() + except Exception as e: + logger.warning(f"Error closing session: {e}") + self._session = None + + def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str: + """Build the blob path for workspace files.""" + return f"workspaces/{workspace_id}/{file_id}/{filename}" + + async def store( + self, + workspace_id: str, + file_id: str, + filename: str, + content: bytes, + ) -> str: + """Store file in GCS.""" + client = await self._get_async_client() + blob_name = self._build_blob_name(workspace_id, file_id, filename) + + # Upload with metadata + upload_time = datetime.now(timezone.utc) + await client.upload( + self.bucket_name, + blob_name, + content, + metadata={ + "uploaded_at": upload_time.isoformat(), + "workspace_id": workspace_id, + "file_id": file_id, + }, + ) + + return f"gcs://{self.bucket_name}/{blob_name}" + + async def retrieve(self, storage_path: str) -> bytes: + """Retrieve file from GCS.""" + bucket_name, blob_name = parse_gcs_path(storage_path) + return await download_with_fresh_session(bucket_name, blob_name) + + async def delete(self, storage_path: str) -> None: + """Delete file from GCS.""" + bucket_name, blob_name = parse_gcs_path(storage_path) + client = await self._get_async_client() + + try: + await client.delete(bucket_name, blob_name) + except Exception as e: + if "404" not in str(e) and "Not Found" not in str(e): + raise + # File already deleted, that's fine + + async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str: + """ + Generate download URL for GCS file. + + Attempts to generate a signed URL if running with service account credentials. + Falls back to an API proxy endpoint if signed URL generation fails + (e.g., when running locally with user OAuth credentials). + """ + bucket_name, blob_name = parse_gcs_path(storage_path) + + # Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename} + blob_parts = blob_name.split("/") + file_id = blob_parts[2] if len(blob_parts) >= 3 else None + + # Try to generate signed URL (requires service account credentials) + try: + sync_client = self._get_sync_client() + return await generate_signed_url( + sync_client, bucket_name, blob_name, expires_in + ) + except AttributeError as e: + # Signed URL generation requires service account with private key. + # When running with user OAuth credentials, fall back to API proxy. + if "private key" in str(e) and file_id: + logger.debug( + "Cannot generate signed URL (no service account credentials), " + "falling back to API proxy endpoint" + ) + return f"/api/workspace/files/{file_id}/download" + raise + + +class LocalWorkspaceStorage(WorkspaceStorageBackend): + """Local filesystem implementation for workspace storage (self-hosted deployments).""" + + def __init__(self, base_dir: Optional[str] = None): + """ + Initialize local storage backend. + + Args: + base_dir: Base directory for workspace storage. + If None, defaults to {app_data}/workspaces + """ + if base_dir: + self.base_dir = Path(base_dir) + else: + self.base_dir = Path(get_data_path()) / "workspaces" + + # Ensure base directory exists + self.base_dir.mkdir(parents=True, exist_ok=True) + + def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path: + """Build the local file path with path traversal protection.""" + # Import here to avoid circular import + # (file.py imports workspace.py which imports workspace_storage.py) + from backend.util.file import sanitize_filename + + # Sanitize filename to prevent path traversal (removes / and \ among others) + safe_filename = sanitize_filename(filename) + file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve() + + # Verify the resolved path is still under base_dir + if not file_path.is_relative_to(self.base_dir.resolve()): + raise ValueError("Invalid filename: path traversal detected") + + return file_path + + def _parse_storage_path(self, storage_path: str) -> Path: + """Parse local storage path to filesystem path.""" + if storage_path.startswith("local://"): + relative_path = storage_path[8:] # Remove "local://" + else: + relative_path = storage_path + + full_path = (self.base_dir / relative_path).resolve() + + # Security check: ensure path is under base_dir + # Use is_relative_to() for robust path containment check + # (handles case-insensitive filesystems and edge cases) + if not full_path.is_relative_to(self.base_dir.resolve()): + raise ValueError("Invalid storage path: path traversal detected") + + return full_path + + async def store( + self, + workspace_id: str, + file_id: str, + filename: str, + content: bytes, + ) -> str: + """Store file locally.""" + file_path = self._build_file_path(workspace_id, file_id, filename) + + # Create parent directories + file_path.parent.mkdir(parents=True, exist_ok=True) + + # Write file asynchronously + async with aiofiles.open(file_path, "wb") as f: + await f.write(content) + + # Return relative path as storage path + relative_path = file_path.relative_to(self.base_dir) + return f"local://{relative_path}" + + async def retrieve(self, storage_path: str) -> bytes: + """Retrieve file from local storage.""" + file_path = self._parse_storage_path(storage_path) + + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {storage_path}") + + async with aiofiles.open(file_path, "rb") as f: + return await f.read() + + async def delete(self, storage_path: str) -> None: + """Delete file from local storage.""" + file_path = self._parse_storage_path(storage_path) + + if file_path.exists(): + # Remove file + file_path.unlink() + + # Clean up empty parent directories + parent = file_path.parent + while parent != self.base_dir: + try: + if parent.exists() and not any(parent.iterdir()): + parent.rmdir() + else: + break + except OSError: + break + parent = parent.parent + + async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str: + """ + Get download URL for local file. + + For local storage, this returns an API endpoint path. + The actual serving is handled by the API layer. + """ + # Parse the storage path to get the components + if storage_path.startswith("local://"): + relative_path = storage_path[8:] + else: + relative_path = storage_path + + # Return the API endpoint for downloading + # The file_id is extracted from the path: {workspace_id}/{file_id}/{filename} + parts = relative_path.split("/") + if len(parts) >= 2: + file_id = parts[1] # Second component is file_id + return f"/api/workspace/files/{file_id}/download" + else: + raise ValueError(f"Invalid storage path format: {storage_path}") + + +# Global storage backend instance +_workspace_storage: Optional[WorkspaceStorageBackend] = None +_storage_lock = asyncio.Lock() + + +async def get_workspace_storage() -> WorkspaceStorageBackend: + """ + Get the workspace storage backend instance. + + Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage. + """ + global _workspace_storage + + if _workspace_storage is None: + async with _storage_lock: + if _workspace_storage is None: + config = Config() + + if config.media_gcs_bucket_name: + logger.info( + f"Using GCS workspace storage: {config.media_gcs_bucket_name}" + ) + _workspace_storage = GCSWorkspaceStorage( + config.media_gcs_bucket_name + ) + else: + storage_dir = ( + config.workspace_storage_dir + if config.workspace_storage_dir + else None + ) + logger.info( + f"Using local workspace storage: {storage_dir or 'default'}" + ) + _workspace_storage = LocalWorkspaceStorage(storage_dir) + + return _workspace_storage + + +async def shutdown_workspace_storage() -> None: + """ + Properly shutdown the global workspace storage backend. + + Closes aiohttp sessions and other resources for GCS backend. + Should be called during application shutdown. + """ + global _workspace_storage + + if _workspace_storage is not None: + async with _storage_lock: + if _workspace_storage is not None: + if isinstance(_workspace_storage, GCSWorkspaceStorage): + await _workspace_storage.close() + _workspace_storage = None + + +def compute_file_checksum(content: bytes) -> str: + """Compute SHA256 checksum of file content.""" + return hashlib.sha256(content).hexdigest() diff --git a/autogpt_platform/backend/migrations/20260121200000_remove_node_execution_fk_from_pending_human_review/migration.sql b/autogpt_platform/backend/migrations/20260121200000_remove_node_execution_fk_from_pending_human_review/migration.sql new file mode 100644 index 0000000000..c43cb0b1e0 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260121200000_remove_node_execution_fk_from_pending_human_review/migration.sql @@ -0,0 +1,7 @@ +-- Remove NodeExecution foreign key from PendingHumanReview +-- The nodeExecId column remains as the primary key, but we remove the FK constraint +-- to AgentNodeExecution since PendingHumanReview records can persist after node +-- execution records are deleted. + +-- Drop foreign key constraint that linked PendingHumanReview.nodeExecId to AgentNodeExecution.id +ALTER TABLE "PendingHumanReview" DROP CONSTRAINT IF EXISTS "PendingHumanReview_nodeExecId_fkey"; diff --git a/autogpt_platform/backend/migrations/20260127211502_add_visit_copilot_onboarding_step/migration.sql b/autogpt_platform/backend/migrations/20260127211502_add_visit_copilot_onboarding_step/migration.sql new file mode 100644 index 0000000000..6a08d9231b --- /dev/null +++ b/autogpt_platform/backend/migrations/20260127211502_add_visit_copilot_onboarding_step/migration.sql @@ -0,0 +1,2 @@ +-- AlterEnum +ALTER TYPE "OnboardingStep" ADD VALUE 'VISIT_COPILOT'; diff --git a/autogpt_platform/backend/migrations/20260127230419_add_user_workspace/migration.sql b/autogpt_platform/backend/migrations/20260127230419_add_user_workspace/migration.sql new file mode 100644 index 0000000000..bb63dccb33 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260127230419_add_user_workspace/migration.sql @@ -0,0 +1,52 @@ +-- CreateEnum +CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT'); + +-- CreateTable +CREATE TABLE "UserWorkspace" ( + "id" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + "userId" TEXT NOT NULL, + + CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "UserWorkspaceFile" ( + "id" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + "workspaceId" TEXT NOT NULL, + "name" TEXT NOT NULL, + "path" TEXT NOT NULL, + "storagePath" TEXT NOT NULL, + "mimeType" TEXT NOT NULL, + "sizeBytes" BIGINT NOT NULL, + "checksum" TEXT, + "isDeleted" BOOLEAN NOT NULL DEFAULT false, + "deletedAt" TIMESTAMP(3), + "source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD', + "sourceExecId" TEXT, + "sourceSessionId" TEXT, + "metadata" JSONB NOT NULL DEFAULT '{}', + + CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId"); + +-- CreateIndex +CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId"); + +-- CreateIndex +CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted"); + +-- CreateIndex +CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path"); + +-- AddForeignKey +ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/autogpt_platform/backend/migrations/20260129011611_remove_workspace_file_source/migration.sql b/autogpt_platform/backend/migrations/20260129011611_remove_workspace_file_source/migration.sql new file mode 100644 index 0000000000..2709bc8484 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260129011611_remove_workspace_file_source/migration.sql @@ -0,0 +1,16 @@ +/* + Warnings: + + - You are about to drop the column `source` on the `UserWorkspaceFile` table. All the data in the column will be lost. + - You are about to drop the column `sourceExecId` on the `UserWorkspaceFile` table. All the data in the column will be lost. + - You are about to drop the column `sourceSessionId` on the `UserWorkspaceFile` table. All the data in the column will be lost. + +*/ + +-- AlterTable +ALTER TABLE "UserWorkspaceFile" DROP COLUMN "source", +DROP COLUMN "sourceExecId", +DROP COLUMN "sourceSessionId"; + +-- DropEnum +DROP TYPE "WorkspaceFileSource"; diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index 45d445b609..91ac358ade 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -1169,29 +1169,6 @@ attrs = ">=21.3.0" e2b = ">=1.5.4,<2.0.0" httpx = ">=0.20.0,<1.0.0" -[[package]] -name = "elevenlabs" -version = "1.59.0" -description = "" -optional = false -python-versions = "<4.0,>=3.8" -groups = ["main"] -files = [ - {file = "elevenlabs-1.59.0-py3-none-any.whl", hash = "sha256:468145db81a0bc867708b4a8619699f75583e9481b395ec1339d0b443da771ed"}, - {file = "elevenlabs-1.59.0.tar.gz", hash = "sha256:16e735bd594e86d415dd445d249c8cc28b09996cfd627fbc10102c0a84698859"}, -] - -[package.dependencies] -httpx = ">=0.21.2" -pydantic = ">=1.9.2" -pydantic-core = ">=2.18.2,<3.0.0" -requests = ">=2.20" -typing_extensions = ">=4.0.0" -websockets = ">=11.0" - -[package.extras] -pyaudio = ["pyaudio (>=0.2.14)"] - [[package]] name = "email-validator" version = "2.2.0" @@ -4227,14 +4204,14 @@ strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""} [[package]] name = "posthog" -version = "6.1.1" +version = "7.6.0" description = "Integrate PostHog into any python application." optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] files = [ - {file = "posthog-6.1.1-py3-none-any.whl", hash = "sha256:329fd3d06b4d54cec925f47235bd8e327c91403c2f9ec38f1deb849535934dba"}, - {file = "posthog-6.1.1.tar.gz", hash = "sha256:b453f54c4a2589da859fd575dd3bf86fcb40580727ec399535f268b1b9f318b8"}, + {file = "posthog-7.6.0-py3-none-any.whl", hash = "sha256:c4dd78cf77c4fecceb965f86066e5ac37886ef867d68ffe75a1db5d681d7d9ad"}, + {file = "posthog-7.6.0.tar.gz", hash = "sha256:941dfd278ee427c9b14640f09b35b5bb52a71bdf028d7dbb7307e1838fd3002e"}, ] [package.dependencies] @@ -4248,7 +4225,7 @@ typing-extensions = ">=4.2.0" [package.extras] dev = ["django-stubs", "lxml", "mypy", "mypy-baseline", "packaging", "pre-commit", "pydantic", "ruff", "setuptools", "tomli", "tomli_w", "twine", "types-mock", "types-python-dateutil", "types-requests", "types-setuptools", "types-six", "wheel"] langchain = ["langchain (>=0.2.0)"] -test = ["anthropic", "coverage", "django", "freezegun (==1.5.1)", "google-genai", "langchain-anthropic (>=0.3.15)", "langchain-community (>=0.3.25)", "langchain-core (>=0.3.65)", "langchain-openai (>=0.3.22)", "langgraph (>=0.4.8)", "mock (>=2.0.0)", "openai", "parameterized (>=0.8.1)", "pydantic", "pytest", "pytest-asyncio", "pytest-timeout"] +test = ["anthropic (>=0.72)", "coverage", "django", "freezegun (==1.5.1)", "google-genai", "langchain-anthropic (>=1.0)", "langchain-community (>=0.4)", "langchain-core (>=1.0)", "langchain-openai (>=1.0)", "langgraph (>=1.0)", "mock (>=2.0.0)", "openai (>=2.0)", "parameterized (>=0.8.1)", "pydantic", "pytest", "pytest-asyncio", "pytest-timeout"] [[package]] name = "postmarker" @@ -7384,28 +7361,6 @@ files = [ defusedxml = ">=0.7.1,<0.8.0" requests = "*" -[[package]] -name = "yt-dlp" -version = "2025.12.8" -description = "A feature-rich command-line audio/video downloader" -optional = false -python-versions = ">=3.10" -groups = ["main"] -files = [ - {file = "yt_dlp-2025.12.8-py3-none-any.whl", hash = "sha256:36e2584342e409cfbfa0b5e61448a1c5189e345cf4564294456ee509e7d3e065"}, - {file = "yt_dlp-2025.12.8.tar.gz", hash = "sha256:b773c81bb6b71cb2c111cfb859f453c7a71cf2ef44eff234ff155877184c3e4f"}, -] - -[package.extras] -build = ["build", "hatchling (>=1.27.0)", "pip", "setuptools (>=71.0.2)", "wheel"] -curl-cffi = ["curl-cffi (>=0.5.10,<0.6.dev0 || >=0.10.dev0,<0.14) ; implementation_name == \"cpython\""] -default = ["brotli ; implementation_name == \"cpython\"", "brotlicffi ; implementation_name != \"cpython\"", "certifi", "mutagen", "pycryptodomex", "requests (>=2.32.2,<3)", "urllib3 (>=2.0.2,<3)", "websockets (>=13.0)", "yt-dlp-ejs (==0.3.2)"] -dev = ["autopep8 (>=2.0,<3.0)", "pre-commit", "pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)", "ruff (>=0.14.0,<0.15.0)"] -pyinstaller = ["pyinstaller (>=6.17.0)"] -secretstorage = ["cffi", "secretstorage"] -static-analysis = ["autopep8 (>=2.0,<3.0)", "ruff (>=0.14.0,<0.15.0)"] -test = ["pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)"] - [[package]] name = "zerobouncesdk" version = "1.1.2" @@ -7557,4 +7512,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "a82dc5db159eb332ef6ae27d392dc1dfdeb2b70ef3595482829e51fdb9e3ffe2" +content-hash = "ee5742dc1a9df50dfc06d4b26a1682cbb2b25cab6b79ce5625ec272f93e4f4bf" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index a35d1660ce..24aea39f33 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -87,6 +87,7 @@ exa-py = "^1.14.20" croniter = "^6.0.0" stagehand = "^0.5.1" gravitas-md2gdocs = "^0.1.0" +posthog = "^7.6.0" [tool.poetry.group.dev.dependencies] aiohappyeyeballs = "^2.6.1" diff --git a/autogpt_platform/backend/schema.prisma b/autogpt_platform/backend/schema.prisma index 4a2a7b583a..2da898a7ce 100644 --- a/autogpt_platform/backend/schema.prisma +++ b/autogpt_platform/backend/schema.prisma @@ -63,6 +63,7 @@ model User { IntegrationWebhooks IntegrationWebhook[] NotificationBatches UserNotificationBatch[] PendingHumanReviews PendingHumanReview[] + Workspace UserWorkspace? // OAuth Provider relations OAuthApplications OAuthApplication[] @@ -81,6 +82,7 @@ enum OnboardingStep { AGENT_INPUT CONGRATS // First Wins + VISIT_COPILOT GET_RESULTS MARKETPLACE_VISIT MARKETPLACE_ADD_AGENT @@ -136,6 +138,53 @@ model CoPilotUnderstanding { @@index([userId]) } +//////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////// +//////////////// USER WORKSPACE TABLES ///////////////// +//////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////// + +// User's persistent file storage workspace +model UserWorkspace { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + userId String @unique + User User @relation(fields: [userId], references: [id], onDelete: Cascade) + + Files UserWorkspaceFile[] + + @@index([userId]) +} + +// Individual files in a user's workspace +model UserWorkspaceFile { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + workspaceId String + Workspace UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade) + + // File metadata + name String // User-visible filename + path String // Virtual path (e.g., "/documents/report.pdf") + storagePath String // Actual GCS or local storage path + mimeType String + sizeBytes BigInt + checksum String? // SHA256 for integrity + + // File state + isDeleted Boolean @default(false) + deletedAt DateTime? + + metadata Json @default("{}") + + @@unique([workspaceId, path]) + @@index([workspaceId, isDeleted]) +} + model BuilderSearchHistory { id String @id @default(uuid()) createdAt DateTime @default(now()) @@ -517,8 +566,6 @@ model AgentNodeExecution { stats Json? - PendingHumanReview PendingHumanReview? - @@index([agentGraphExecutionId, agentNodeId, executionStatus]) @@index([agentNodeId, executionStatus]) @@index([addedTime, queuedTime]) @@ -567,6 +614,7 @@ enum ReviewStatus { } // Pending human reviews for Human-in-the-loop blocks +// Also stores auto-approval records with special nodeExecId patterns (e.g., "auto_approve_{graph_exec_id}_{node_id}") model PendingHumanReview { nodeExecId String @id userId String @@ -585,7 +633,6 @@ model PendingHumanReview { reviewedAt DateTime? User User @relation(fields: [userId], references: [id], onDelete: Cascade) - NodeExecution AgentNodeExecution @relation(fields: [nodeExecId], references: [id], onDelete: Cascade) GraphExecution AgentGraphExecution @relation(fields: [graphExecId], references: [id], onDelete: Cascade) @@unique([nodeExecId]) // One pending review per node execution diff --git a/autogpt_platform/backend/test/agent_generator/__init__.py b/autogpt_platform/backend/test/agent_generator/__init__.py new file mode 100644 index 0000000000..8fcde1fa0f --- /dev/null +++ b/autogpt_platform/backend/test/agent_generator/__init__.py @@ -0,0 +1 @@ +"""Tests for agent generator module.""" diff --git a/autogpt_platform/backend/test/agent_generator/test_core_integration.py b/autogpt_platform/backend/test/agent_generator/test_core_integration.py new file mode 100644 index 0000000000..bdcc24ba79 --- /dev/null +++ b/autogpt_platform/backend/test/agent_generator/test_core_integration.py @@ -0,0 +1,273 @@ +""" +Tests for the Agent Generator core module. + +This test suite verifies that the core functions correctly delegate to +the external Agent Generator service. +""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.api.features.chat.tools.agent_generator import core +from backend.api.features.chat.tools.agent_generator.core import ( + AgentGeneratorNotConfiguredError, +) + + +class TestServiceNotConfigured: + """Test that functions raise AgentGeneratorNotConfiguredError when service is not configured.""" + + @pytest.mark.asyncio + async def test_decompose_goal_raises_when_not_configured(self): + """Test that decompose_goal raises error when service not configured.""" + with patch.object(core, "is_external_service_configured", return_value=False): + with pytest.raises(AgentGeneratorNotConfiguredError): + await core.decompose_goal("Build a chatbot") + + @pytest.mark.asyncio + async def test_generate_agent_raises_when_not_configured(self): + """Test that generate_agent raises error when service not configured.""" + with patch.object(core, "is_external_service_configured", return_value=False): + with pytest.raises(AgentGeneratorNotConfiguredError): + await core.generate_agent({"steps": []}) + + @pytest.mark.asyncio + async def test_generate_agent_patch_raises_when_not_configured(self): + """Test that generate_agent_patch raises error when service not configured.""" + with patch.object(core, "is_external_service_configured", return_value=False): + with pytest.raises(AgentGeneratorNotConfiguredError): + await core.generate_agent_patch("Add a node", {"nodes": []}) + + +class TestDecomposeGoal: + """Test decompose_goal function service delegation.""" + + @pytest.mark.asyncio + async def test_calls_external_service(self): + """Test that decompose_goal calls the external service.""" + expected_result = {"type": "instructions", "steps": ["Step 1"]} + + with patch.object( + core, "is_external_service_configured", return_value=True + ), patch.object( + core, "decompose_goal_external", new_callable=AsyncMock + ) as mock_external: + mock_external.return_value = expected_result + + result = await core.decompose_goal("Build a chatbot") + + mock_external.assert_called_once_with("Build a chatbot", "") + assert result == expected_result + + @pytest.mark.asyncio + async def test_passes_context_to_external_service(self): + """Test that decompose_goal passes context to external service.""" + expected_result = {"type": "instructions", "steps": ["Step 1"]} + + with patch.object( + core, "is_external_service_configured", return_value=True + ), patch.object( + core, "decompose_goal_external", new_callable=AsyncMock + ) as mock_external: + mock_external.return_value = expected_result + + await core.decompose_goal("Build a chatbot", "Use Python") + + mock_external.assert_called_once_with("Build a chatbot", "Use Python") + + @pytest.mark.asyncio + async def test_returns_none_on_service_failure(self): + """Test that decompose_goal returns None when external service fails.""" + with patch.object( + core, "is_external_service_configured", return_value=True + ), patch.object( + core, "decompose_goal_external", new_callable=AsyncMock + ) as mock_external: + mock_external.return_value = None + + result = await core.decompose_goal("Build a chatbot") + + assert result is None + + +class TestGenerateAgent: + """Test generate_agent function service delegation.""" + + @pytest.mark.asyncio + async def test_calls_external_service(self): + """Test that generate_agent calls the external service.""" + expected_result = {"name": "Test Agent", "nodes": [], "links": []} + + with patch.object( + core, "is_external_service_configured", return_value=True + ), patch.object( + core, "generate_agent_external", new_callable=AsyncMock + ) as mock_external: + mock_external.return_value = expected_result + + instructions = {"type": "instructions", "steps": ["Step 1"]} + result = await core.generate_agent(instructions) + + mock_external.assert_called_once_with(instructions) + # Result should have id, version, is_active added if not present + assert result is not None + assert result["name"] == "Test Agent" + assert "id" in result + assert result["version"] == 1 + assert result["is_active"] is True + + @pytest.mark.asyncio + async def test_preserves_existing_id_and_version(self): + """Test that external service result preserves existing id and version.""" + expected_result = { + "id": "existing-id", + "version": 3, + "is_active": False, + "name": "Test Agent", + } + + with patch.object( + core, "is_external_service_configured", return_value=True + ), patch.object( + core, "generate_agent_external", new_callable=AsyncMock + ) as mock_external: + mock_external.return_value = expected_result.copy() + + result = await core.generate_agent({"steps": []}) + + assert result is not None + assert result["id"] == "existing-id" + assert result["version"] == 3 + assert result["is_active"] is False + + @pytest.mark.asyncio + async def test_returns_none_when_external_service_fails(self): + """Test that generate_agent returns None when external service fails.""" + with patch.object( + core, "is_external_service_configured", return_value=True + ), patch.object( + core, "generate_agent_external", new_callable=AsyncMock + ) as mock_external: + mock_external.return_value = None + + result = await core.generate_agent({"steps": []}) + + assert result is None + + +class TestGenerateAgentPatch: + """Test generate_agent_patch function service delegation.""" + + @pytest.mark.asyncio + async def test_calls_external_service(self): + """Test that generate_agent_patch calls the external service.""" + expected_result = {"name": "Updated Agent", "nodes": [], "links": []} + + with patch.object( + core, "is_external_service_configured", return_value=True + ), patch.object( + core, "generate_agent_patch_external", new_callable=AsyncMock + ) as mock_external: + mock_external.return_value = expected_result + + current_agent = {"nodes": [], "links": []} + result = await core.generate_agent_patch("Add a node", current_agent) + + mock_external.assert_called_once_with("Add a node", current_agent) + assert result == expected_result + + @pytest.mark.asyncio + async def test_returns_clarifying_questions(self): + """Test that generate_agent_patch returns clarifying questions.""" + expected_result = { + "type": "clarifying_questions", + "questions": [{"question": "What type of node?"}], + } + + with patch.object( + core, "is_external_service_configured", return_value=True + ), patch.object( + core, "generate_agent_patch_external", new_callable=AsyncMock + ) as mock_external: + mock_external.return_value = expected_result + + result = await core.generate_agent_patch("Add a node", {"nodes": []}) + + assert result == expected_result + + @pytest.mark.asyncio + async def test_returns_none_when_external_service_fails(self): + """Test that generate_agent_patch returns None when service fails.""" + with patch.object( + core, "is_external_service_configured", return_value=True + ), patch.object( + core, "generate_agent_patch_external", new_callable=AsyncMock + ) as mock_external: + mock_external.return_value = None + + result = await core.generate_agent_patch("Add a node", {"nodes": []}) + + assert result is None + + +class TestJsonToGraph: + """Test json_to_graph function.""" + + def test_converts_agent_json_to_graph(self): + """Test conversion of agent JSON to Graph model.""" + agent_json = { + "id": "test-id", + "version": 2, + "is_active": True, + "name": "Test Agent", + "description": "A test agent", + "nodes": [ + { + "id": "node1", + "block_id": "block1", + "input_default": {"key": "value"}, + "metadata": {"x": 100}, + } + ], + "links": [ + { + "id": "link1", + "source_id": "node1", + "sink_id": "output", + "source_name": "result", + "sink_name": "input", + "is_static": False, + } + ], + } + + graph = core.json_to_graph(agent_json) + + assert graph.id == "test-id" + assert graph.version == 2 + assert graph.is_active is True + assert graph.name == "Test Agent" + assert graph.description == "A test agent" + assert len(graph.nodes) == 1 + assert graph.nodes[0].id == "node1" + assert graph.nodes[0].block_id == "block1" + assert len(graph.links) == 1 + assert graph.links[0].source_id == "node1" + + def test_generates_ids_if_missing(self): + """Test that missing IDs are generated.""" + agent_json = { + "name": "Test Agent", + "nodes": [{"block_id": "block1"}], + "links": [], + } + + graph = core.json_to_graph(agent_json) + + assert graph.id is not None + assert graph.nodes[0].id is not None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/autogpt_platform/backend/test/agent_generator/test_service.py b/autogpt_platform/backend/test/agent_generator/test_service.py new file mode 100644 index 0000000000..fe7a1a7fdd --- /dev/null +++ b/autogpt_platform/backend/test/agent_generator/test_service.py @@ -0,0 +1,437 @@ +""" +Tests for the Agent Generator external service client. + +This test suite verifies the external Agent Generator service integration, +including service detection, API calls, and error handling. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from backend.api.features.chat.tools.agent_generator import service + + +class TestServiceConfiguration: + """Test service configuration detection.""" + + def setup_method(self): + """Reset settings singleton before each test.""" + service._settings = None + service._client = None + + def test_external_service_not_configured_when_host_empty(self): + """Test that external service is not configured when host is empty.""" + mock_settings = MagicMock() + mock_settings.config.agentgenerator_host = "" + + with patch.object(service, "_get_settings", return_value=mock_settings): + assert service.is_external_service_configured() is False + + def test_external_service_configured_when_host_set(self): + """Test that external service is configured when host is set.""" + mock_settings = MagicMock() + mock_settings.config.agentgenerator_host = "agent-generator.local" + + with patch.object(service, "_get_settings", return_value=mock_settings): + assert service.is_external_service_configured() is True + + def test_get_base_url(self): + """Test base URL construction.""" + mock_settings = MagicMock() + mock_settings.config.agentgenerator_host = "agent-generator.local" + mock_settings.config.agentgenerator_port = 8000 + + with patch.object(service, "_get_settings", return_value=mock_settings): + url = service._get_base_url() + assert url == "http://agent-generator.local:8000" + + +class TestDecomposeGoalExternal: + """Test decompose_goal_external function.""" + + def setup_method(self): + """Reset client singleton before each test.""" + service._settings = None + service._client = None + + @pytest.mark.asyncio + async def test_decompose_goal_returns_instructions(self): + """Test successful decomposition returning instructions.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "type": "instructions", + "steps": ["Step 1", "Step 2"], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.decompose_goal_external("Build a chatbot") + + assert result == {"type": "instructions", "steps": ["Step 1", "Step 2"]} + mock_client.post.assert_called_once_with( + "/api/decompose-description", json={"description": "Build a chatbot"} + ) + + @pytest.mark.asyncio + async def test_decompose_goal_returns_clarifying_questions(self): + """Test decomposition returning clarifying questions.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "type": "clarifying_questions", + "questions": ["What platform?", "What language?"], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.decompose_goal_external("Build something") + + assert result == { + "type": "clarifying_questions", + "questions": ["What platform?", "What language?"], + } + + @pytest.mark.asyncio + async def test_decompose_goal_with_context(self): + """Test decomposition with additional context.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "type": "instructions", + "steps": ["Step 1"], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + await service.decompose_goal_external( + "Build a chatbot", context="Use Python" + ) + + mock_client.post.assert_called_once_with( + "/api/decompose-description", + json={"description": "Build a chatbot", "user_instruction": "Use Python"}, + ) + + @pytest.mark.asyncio + async def test_decompose_goal_returns_unachievable_goal(self): + """Test decomposition returning unachievable goal response.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "type": "unachievable_goal", + "reason": "Cannot do X", + "suggested_goal": "Try Y instead", + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.decompose_goal_external("Do something impossible") + + assert result == { + "type": "unachievable_goal", + "reason": "Cannot do X", + "suggested_goal": "Try Y instead", + } + + @pytest.mark.asyncio + async def test_decompose_goal_handles_http_error(self): + """Test decomposition handles HTTP errors gracefully.""" + mock_response = MagicMock() + mock_response.status_code = 500 + mock_client = AsyncMock() + mock_client.post.side_effect = httpx.HTTPStatusError( + "Server error", request=MagicMock(), response=mock_response + ) + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.decompose_goal_external("Build a chatbot") + + assert result is not None + assert result.get("type") == "error" + assert result.get("error_type") == "http_error" + assert "Server error" in result.get("error", "") + + @pytest.mark.asyncio + async def test_decompose_goal_handles_request_error(self): + """Test decomposition handles request errors gracefully.""" + mock_client = AsyncMock() + mock_client.post.side_effect = httpx.RequestError("Connection failed") + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.decompose_goal_external("Build a chatbot") + + assert result is not None + assert result.get("type") == "error" + assert result.get("error_type") == "connection_error" + assert "Connection failed" in result.get("error", "") + + @pytest.mark.asyncio + async def test_decompose_goal_handles_service_error(self): + """Test decomposition handles service returning error.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "success": False, + "error": "Internal error", + "error_type": "internal_error", + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.decompose_goal_external("Build a chatbot") + + assert result is not None + assert result.get("type") == "error" + assert result.get("error") == "Internal error" + assert result.get("error_type") == "internal_error" + + +class TestGenerateAgentExternal: + """Test generate_agent_external function.""" + + def setup_method(self): + """Reset client singleton before each test.""" + service._settings = None + service._client = None + + @pytest.mark.asyncio + async def test_generate_agent_success(self): + """Test successful agent generation.""" + agent_json = { + "name": "Test Agent", + "nodes": [], + "links": [], + } + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "agent_json": agent_json, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + instructions = {"type": "instructions", "steps": ["Step 1"]} + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.generate_agent_external(instructions) + + assert result == agent_json + mock_client.post.assert_called_once_with( + "/api/generate-agent", json={"instructions": instructions} + ) + + @pytest.mark.asyncio + async def test_generate_agent_handles_error(self): + """Test agent generation handles errors gracefully.""" + mock_client = AsyncMock() + mock_client.post.side_effect = httpx.RequestError("Connection failed") + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.generate_agent_external({"steps": []}) + + assert result is not None + assert result.get("type") == "error" + assert result.get("error_type") == "connection_error" + assert "Connection failed" in result.get("error", "") + + +class TestGenerateAgentPatchExternal: + """Test generate_agent_patch_external function.""" + + def setup_method(self): + """Reset client singleton before each test.""" + service._settings = None + service._client = None + + @pytest.mark.asyncio + async def test_generate_patch_returns_updated_agent(self): + """Test successful patch generation returning updated agent.""" + updated_agent = { + "name": "Updated Agent", + "nodes": [{"id": "1", "block_id": "test"}], + "links": [], + } + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "agent_json": updated_agent, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + current_agent = {"name": "Old Agent", "nodes": [], "links": []} + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.generate_agent_patch_external( + "Add a new node", current_agent + ) + + assert result == updated_agent + mock_client.post.assert_called_once_with( + "/api/update-agent", + json={ + "update_request": "Add a new node", + "current_agent_json": current_agent, + }, + ) + + @pytest.mark.asyncio + async def test_generate_patch_returns_clarifying_questions(self): + """Test patch generation returning clarifying questions.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "type": "clarifying_questions", + "questions": ["What type of node?"], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.generate_agent_patch_external( + "Add something", {"nodes": []} + ) + + assert result == { + "type": "clarifying_questions", + "questions": ["What type of node?"], + } + + +class TestHealthCheck: + """Test health_check function.""" + + def setup_method(self): + """Reset singletons before each test.""" + service._settings = None + service._client = None + + @pytest.mark.asyncio + async def test_health_check_returns_false_when_not_configured(self): + """Test health check returns False when service not configured.""" + with patch.object( + service, "is_external_service_configured", return_value=False + ): + result = await service.health_check() + assert result is False + + @pytest.mark.asyncio + async def test_health_check_returns_true_when_healthy(self): + """Test health check returns True when service is healthy.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "status": "healthy", + "blocks_loaded": True, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + with patch.object(service, "is_external_service_configured", return_value=True): + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.health_check() + + assert result is True + mock_client.get.assert_called_once_with("/health") + + @pytest.mark.asyncio + async def test_health_check_returns_false_when_not_healthy(self): + """Test health check returns False when service is not healthy.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "status": "unhealthy", + "blocks_loaded": False, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + with patch.object(service, "is_external_service_configured", return_value=True): + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.health_check() + + assert result is False + + @pytest.mark.asyncio + async def test_health_check_returns_false_on_error(self): + """Test health check returns False on connection error.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.RequestError("Connection failed") + + with patch.object(service, "is_external_service_configured", return_value=True): + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.health_check() + + assert result is False + + +class TestGetBlocksExternal: + """Test get_blocks_external function.""" + + def setup_method(self): + """Reset client singleton before each test.""" + service._settings = None + service._client = None + + @pytest.mark.asyncio + async def test_get_blocks_success(self): + """Test successful blocks retrieval.""" + blocks = [ + {"id": "block1", "name": "Block 1"}, + {"id": "block2", "name": "Block 2"}, + ] + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "blocks": blocks, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.get_blocks_external() + + assert result == blocks + mock_client.get.assert_called_once_with("/api/blocks") + + @pytest.mark.asyncio + async def test_get_blocks_handles_error(self): + """Test blocks retrieval handles errors gracefully.""" + mock_client = AsyncMock() + mock_client.get.side_effect = httpx.RequestError("Connection failed") + + with patch.object(service, "_get_client", return_value=mock_client): + result = await service.get_blocks_external() + + assert result is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/autogpt_platform/frontend/.env.default b/autogpt_platform/frontend/.env.default index 197a37e8bb..7a9d81e39e 100644 --- a/autogpt_platform/frontend/.env.default +++ b/autogpt_platform/frontend/.env.default @@ -30,3 +30,10 @@ NEXT_PUBLIC_TURNSTILE=disabled # PR previews NEXT_PUBLIC_PREVIEW_STEALING_DEV= + +# PostHog Analytics +NEXT_PUBLIC_POSTHOG_KEY= +NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com + +# OpenAI (for voice transcription) +OPENAI_API_KEY= diff --git a/autogpt_platform/frontend/CLAUDE.md b/autogpt_platform/frontend/CLAUDE.md new file mode 100644 index 0000000000..b58f1ad6aa --- /dev/null +++ b/autogpt_platform/frontend/CLAUDE.md @@ -0,0 +1,76 @@ +# CLAUDE.md - Frontend + +This file provides guidance to Claude Code when working with the frontend. + +## Essential Commands + +```bash +# Install dependencies +pnpm i + +# Generate API client from OpenAPI spec +pnpm generate:api + +# Start development server +pnpm dev + +# Run E2E tests +pnpm test + +# Run Storybook for component development +pnpm storybook + +# Build production +pnpm build + +# Format and lint +pnpm format + +# Type checking +pnpm types +``` + +### Code Style + +- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI` +- Use function declarations (not arrow functions) for components/handlers + +## Architecture + +- **Framework**: Next.js 15 App Router (client-first approach) +- **Data Fetching**: Type-safe generated API hooks via Orval + React Query +- **State Management**: React Query for server state, co-located UI state in components/hooks +- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks) +- **Workflow Builder**: Visual graph editor using @xyflow/react +- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling +- **Icons**: Phosphor Icons only +- **Feature Flags**: LaunchDarkly integration +- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions +- **Testing**: Playwright for E2E, Storybook for component development + +## Environment Configuration + +`.env.default` (defaults) → `.env` (user overrides) + +## Feature Development + +See @CONTRIBUTING.md for complete patterns. Quick reference: + +1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx` + - Extract component logic into custom hooks grouped by concern, not by component. Each hook should represent a cohesive domain of functionality (e.g., useSearch, useFilters, usePagination) rather than bundling all state into one useComponentState hook. + - Put each hook in its own `.ts` file + - Put sub-components in local `components/` folder + - Component props should be `type Props = { ... }` (not exported) unless it needs to be used outside the component +2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts` + - Use design system components from `src/components/` (atoms, molecules, organisms) + - Never use `src/components/__legacy__/*` +3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/` + - Regenerate with `pnpm generate:api` + - Pattern: `use{Method}{Version}{OperationName}` +4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only +5. **Testing**: Add Storybook stories for new components, Playwright for E2E +6. **Code conventions**: + - Use function declarations (not arrow functions) for components/handlers + - Do not use `useCallback` or `useMemo` unless asked to optimise a given function + - Do not type hook returns, let Typescript infer as much as possible + - Never type with `any` unless a variable/attribute can ACTUALLY be of any type diff --git a/autogpt_platform/frontend/package.json b/autogpt_platform/frontend/package.json index bc1e2d7443..f22a182d20 100644 --- a/autogpt_platform/frontend/package.json +++ b/autogpt_platform/frontend/package.json @@ -34,6 +34,7 @@ "@hookform/resolvers": "5.2.2", "@next/third-parties": "15.4.6", "@phosphor-icons/react": "2.1.10", + "@posthog/react": "1.7.0", "@radix-ui/react-accordion": "1.2.12", "@radix-ui/react-alert-dialog": "1.1.15", "@radix-ui/react-avatar": "1.1.10", @@ -91,6 +92,7 @@ "next-themes": "0.4.6", "nuqs": "2.7.2", "party-js": "2.2.0", + "posthog-js": "1.334.1", "react": "18.3.1", "react-currency-input-field": "4.0.3", "react-day-picker": "9.11.1", @@ -120,7 +122,6 @@ }, "devDependencies": { "@chromatic-com/storybook": "4.1.2", - "happy-dom": "20.3.4", "@opentelemetry/instrumentation": "0.209.0", "@playwright/test": "1.56.1", "@storybook/addon-a11y": "9.1.5", @@ -148,6 +149,7 @@ "eslint": "8.57.1", "eslint-config-next": "15.5.7", "eslint-plugin-storybook": "9.1.5", + "happy-dom": "20.3.4", "import-in-the-middle": "2.0.2", "msw": "2.11.6", "msw-storybook-addon": "2.0.6", diff --git a/autogpt_platform/frontend/pnpm-lock.yaml b/autogpt_platform/frontend/pnpm-lock.yaml index 8e83289f03..db891ccf3f 100644 --- a/autogpt_platform/frontend/pnpm-lock.yaml +++ b/autogpt_platform/frontend/pnpm-lock.yaml @@ -23,6 +23,9 @@ importers: '@phosphor-icons/react': specifier: 2.1.10 version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@posthog/react': + specifier: 1.7.0 + version: 1.7.0(@types/react@18.3.17)(posthog-js@1.334.1)(react@18.3.1) '@radix-ui/react-accordion': specifier: 1.2.12 version: 1.2.12(@types/react-dom@18.3.5(@types/react@18.3.17))(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) @@ -194,6 +197,9 @@ importers: party-js: specifier: 2.2.0 version: 2.2.0 + posthog-js: + specifier: 1.334.1 + version: 1.334.1 react: specifier: 18.3.1 version: 18.3.1 @@ -1794,6 +1800,10 @@ packages: '@open-draft/until@2.1.0': resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==} + '@opentelemetry/api-logs@0.208.0': + resolution: {integrity: sha512-CjruKY9V6NMssL/T1kAFgzosF1v9o6oeN+aX5JB/C/xPNtmgIJqcXHG7fA82Ou1zCpWGl4lROQUKwUNE1pMCyg==} + engines: {node: '>=8.0.0'} + '@opentelemetry/api-logs@0.209.0': resolution: {integrity: sha512-xomnUNi7TiAGtOgs0tb54LyrjRZLu9shJGGwkcN7NgtiPYOpNnKLkRJtzZvTjD/w6knSZH9sFZcUSUovYOPg6A==} engines: {node: '>=8.0.0'} @@ -1814,6 +1824,12 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.0.0 <1.10.0' + '@opentelemetry/exporter-logs-otlp-http@0.208.0': + resolution: {integrity: sha512-jOv40Bs9jy9bZVLo/i8FwUiuCvbjWDI+ZW13wimJm4LjnlwJxGgB+N/VWOZUTpM+ah/awXeQqKdNlpLf2EjvYg==} + engines: {node: ^18.19.0 || >=20.6.0} + peerDependencies: + '@opentelemetry/api': ^1.3.0 + '@opentelemetry/instrumentation-amqplib@0.55.0': resolution: {integrity: sha512-5ULoU8p+tWcQw5PDYZn8rySptGSLZHNX/7srqo2TioPnAAcvTy6sQFQXsNPrAnyRRtYGMetXVyZUy5OaX1+IfA==} engines: {node: ^18.19.0 || >=20.6.0} @@ -1952,6 +1968,18 @@ packages: peerDependencies: '@opentelemetry/api': ^1.3.0 + '@opentelemetry/otlp-exporter-base@0.208.0': + resolution: {integrity: sha512-gMd39gIfVb2OgxldxUtOwGJYSH8P1kVFFlJLuut32L6KgUC4gl1dMhn+YC2mGn0bDOiQYSk/uHOdSjuKp58vvA==} + engines: {node: ^18.19.0 || >=20.6.0} + peerDependencies: + '@opentelemetry/api': ^1.3.0 + + '@opentelemetry/otlp-transformer@0.208.0': + resolution: {integrity: sha512-DCFPY8C6lAQHUNkzcNT9R+qYExvsk6C5Bto2pbNxgicpcSWbe2WHShLxkOxIdNcBiYPdVHv/e7vH7K6TI+C+fQ==} + engines: {node: ^18.19.0 || >=20.6.0} + peerDependencies: + '@opentelemetry/api': ^1.3.0 + '@opentelemetry/redis-common@0.38.2': resolution: {integrity: sha512-1BCcU93iwSRZvDAgwUxC/DV4T/406SkMfxGqu5ojc3AvNI+I9GhV7v0J1HljsczuuhcnFLYqD5VmwVXfCGHzxA==} engines: {node: ^18.19.0 || >=20.6.0} @@ -1962,6 +1990,18 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.3.0 <1.10.0' + '@opentelemetry/sdk-logs@0.208.0': + resolution: {integrity: sha512-QlAyL1jRpOeaqx7/leG1vJMp84g0xKP6gJmfELBpnI4O/9xPX+Hu5m1POk9Kl+veNkyth5t19hRlN6tNY1sjbA==} + engines: {node: ^18.19.0 || >=20.6.0} + peerDependencies: + '@opentelemetry/api': '>=1.4.0 <1.10.0' + + '@opentelemetry/sdk-metrics@2.2.0': + resolution: {integrity: sha512-G5KYP6+VJMZzpGipQw7Giif48h6SGQ2PFKEYCybeXJsOCB4fp8azqMAAzE5lnnHK3ZVwYQrgmFbsUJO/zOnwGw==} + engines: {node: ^18.19.0 || >=20.6.0} + peerDependencies: + '@opentelemetry/api': '>=1.9.0 <1.10.0' + '@opentelemetry/sdk-trace-base@2.2.0': resolution: {integrity: sha512-xWQgL0Bmctsalg6PaXExmzdedSp3gyKV8mQBwK/j9VGdCDu2fmXIb2gAehBKbkXCpJ4HPkgv3QfoJWRT4dHWbw==} engines: {node: ^18.19.0 || >=20.6.0} @@ -2050,11 +2090,57 @@ packages: webpack-plugin-serve: optional: true + '@posthog/core@1.13.0': + resolution: {integrity: sha512-knjncrk7qRmssFRbGzBl1Tunt21GRpe0Wv+uVelyL0Rh7PdQUsgguulzXFTps8hA6wPwTU4kq85qnbAJ3eH6Wg==} + + '@posthog/react@1.7.0': + resolution: {integrity: sha512-pM7GL7z/rKjiIwosbRiQA3buhLI6vUo+wg+T/ZrVZC7O5bVU07TfgNZTcuOj8E9dx7vDbfNrc1kjDN7PKMM8ug==} + peerDependencies: + '@types/react': '>=16.8.0' + posthog-js: '>=1.257.2' + react: '>=16.8.0' + peerDependenciesMeta: + '@types/react': + optional: true + + '@posthog/types@1.334.1': + resolution: {integrity: sha512-ypFnwTO7qbV7icylLbujbamPdQXbJq0a61GUUBnJAeTbBw/qYPIss5IRYICcbCj0uunQrwD7/CGxVb5TOYKWgA==} + '@prisma/instrumentation@6.19.0': resolution: {integrity: sha512-QcuYy25pkXM8BJ37wVFBO7Zh34nyRV1GOb2n3lPkkbRYfl4hWl3PTcImP41P0KrzVXfa/45p6eVCos27x3exIg==} peerDependencies: '@opentelemetry/api': ^1.8 + '@protobufjs/aspromise@1.1.2': + resolution: {integrity: sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==} + + '@protobufjs/base64@1.1.2': + resolution: {integrity: sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==} + + '@protobufjs/codegen@2.0.4': + resolution: {integrity: sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg==} + + '@protobufjs/eventemitter@1.1.0': + resolution: {integrity: sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q==} + + '@protobufjs/fetch@1.1.0': + resolution: {integrity: sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ==} + + '@protobufjs/float@1.0.2': + resolution: {integrity: sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==} + + '@protobufjs/inquire@1.1.0': + resolution: {integrity: sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q==} + + '@protobufjs/path@1.1.2': + resolution: {integrity: sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==} + + '@protobufjs/pool@1.1.0': + resolution: {integrity: sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==} + + '@protobufjs/utf8@1.1.0': + resolution: {integrity: sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==} + '@radix-ui/number@1.1.1': resolution: {integrity: sha512-MkKCwxlXTgz6CFoJx3pCwn07GKp36+aZyu/u2Ln2VrA5DcdyCZkASEDBTd8x5whTQQL5CiYf4prXKLcgQdv29g==} @@ -3401,6 +3487,9 @@ packages: '@types/tedious@4.0.14': resolution: {integrity: sha512-KHPsfX/FoVbUGbyYvk1q9MMQHLPeRZhRJZdO45Q4YjvFkv4hMNghCWTvy7rdKessBsmtz4euWCWAB6/tVpI1Iw==} + '@types/trusted-types@2.0.7': + resolution: {integrity: sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==} + '@types/unist@2.0.11': resolution: {integrity: sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA==} @@ -4278,6 +4367,9 @@ packages: core-js-pure@3.47.0: resolution: {integrity: sha512-BcxeDbzUrRnXGYIVAGFtcGQVNpFcUhVjr6W7F8XktvQW2iJP9e66GP6xdKotCRFlrxBvNIBrhwKteRXqMV86Nw==} + core-js@3.48.0: + resolution: {integrity: sha512-zpEHTy1fjTMZCKLHUZoVeylt9XrzaIN2rbPXEt0k+q7JE5CkCZdo6bNq55bn24a69CH7ErAVLKijxJja4fw+UQ==} + core-util-is@1.0.3: resolution: {integrity: sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==} @@ -4569,6 +4661,9 @@ packages: resolution: {integrity: sha512-GrwoxYN+uWlzO8uhUXRl0P+kHE4GtVPfYzVLcUxPL7KNdHKj66vvlhiweIHqYYXWlw+T8iLMp42Lm67ghw4WMQ==} engines: {node: '>= 4'} + dompurify@3.3.1: + resolution: {integrity: sha512-qkdCKzLNtrgPFP1Vo+98FRzJnBRGe4ffyCea9IwHB1fyxPOeNTHpLKYGd4Uk9xvNoH0ZoOjwZxNptyMwqrId1Q==} + domutils@2.8.0: resolution: {integrity: sha512-w96Cjofp72M5IIhpjgobBimYEfoPjx1Vx0BSX9P30WBdZW2WIKU0T1Bd0kz2eNZ9ikjKgHbEyKx8BB6H1L3h3A==} @@ -4939,6 +5034,9 @@ packages: picomatch: optional: true + fflate@0.4.8: + resolution: {integrity: sha512-FJqqoDBR00Mdj9ppamLa/Y7vxm+PRmNWA67N846RvsoYVMKB4q3y/de5PA7gUmRMYK/8CMz2GDZQmCRN1wBcWA==} + file-entry-cache@6.0.1: resolution: {integrity: sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==} engines: {node: ^10.12.0 || >=12.0.0} @@ -5745,6 +5843,9 @@ packages: resolution: {integrity: sha512-HgMmCqIJSAKqo68l0rS2AanEWfkxaZ5wNiEFb5ggm08lDs9Xl2KxBlX3PTcaD2chBM1gXAYf491/M2Rv8Jwayg==} engines: {node: '>= 0.6.0'} + long@5.3.2: + resolution: {integrity: sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA==} + longest-streak@3.1.0: resolution: {integrity: sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==} @@ -6534,6 +6635,12 @@ packages: resolution: {integrity: sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==} engines: {node: '>=0.10.0'} + posthog-js@1.334.1: + resolution: {integrity: sha512-5cDzLICr2afnwX/cR9fwoLC0vN0Nb5gP5HiCigzHkgHdO+E3WsYefla3EFMQz7U4r01CBPZ+nZ9/srkzeACxtQ==} + + preact@10.28.2: + resolution: {integrity: sha512-lbteaWGzGHdlIuiJ0l2Jq454m6kcpI1zNje6d8MlGAFlYvP2GO4ibnat7P74Esfz4sPTdM6UxtTwh/d3pwM9JA==} + prelude-ls@1.2.1: resolution: {integrity: sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==} engines: {node: '>= 0.8.0'} @@ -6622,6 +6729,10 @@ packages: property-information@7.1.0: resolution: {integrity: sha512-TwEZ+X+yCJmYfL7TPUOcvBZ4QfoT5YenQiJuX//0th53DE6w0xxLEtfK3iyryQFddXuvkIk51EEgrJQ0WJkOmQ==} + protobufjs@7.5.4: + resolution: {integrity: sha512-CvexbZtbov6jW2eXAvLukXjXUW1TzFaivC46BpWc/3BpcCysb5Vffu+B3XHMm8lVEuy2Mm4XGex8hBSg1yapPg==} + engines: {node: '>=12.0.0'} + proxy-from-env@1.1.0: resolution: {integrity: sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==} @@ -6643,6 +6754,9 @@ packages: resolution: {integrity: sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==} engines: {node: '>=0.6'} + query-selector-shadow-dom@1.0.1: + resolution: {integrity: sha512-lT5yCqEBgfoMYpf3F2xQRK7zEr1rhIIZuceDK6+xRkJQ4NMbHTwXqk4NkwDwQMNqXgG9r9fyHnzwNVs6zV5KRw==} + querystring-es3@0.2.1: resolution: {integrity: sha512-773xhDQnZBMFobEiztv8LIl70ch5MSF/jUQVlhwFyBILqq96anmoctVIYz+ZRp0qbCKATTn6ev02M3r7Ga5vqA==} engines: {node: '>=0.4.x'} @@ -7821,6 +7935,9 @@ packages: web-namespaces@2.0.1: resolution: {integrity: sha512-bKr1DkiNa2krS7qxNtdrtHAmzuYGFQLiQ13TsorsdT6ULTkPLKuu5+GsFpDlg6JFjUTwX2DyhMPG2be8uPrqsQ==} + web-vitals@5.1.0: + resolution: {integrity: sha512-ArI3kx5jI0atlTtmV0fWU3fjpLmq/nD3Zr1iFFlJLaqa5wLBkUSzINwBPySCX/8jRyjlmy1Volw1kz1g9XE4Jg==} + webidl-conversions@3.0.1: resolution: {integrity: sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==} @@ -9420,6 +9537,10 @@ snapshots: '@open-draft/until@2.1.0': {} + '@opentelemetry/api-logs@0.208.0': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/api-logs@0.209.0': dependencies: '@opentelemetry/api': 1.9.0 @@ -9435,6 +9556,15 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/semantic-conventions': 1.38.0 + '@opentelemetry/exporter-logs-otlp-http@0.208.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/api-logs': 0.208.0 + '@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0) + '@opentelemetry/otlp-exporter-base': 0.208.0(@opentelemetry/api@1.9.0) + '@opentelemetry/otlp-transformer': 0.208.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-logs': 0.208.0(@opentelemetry/api@1.9.0) + '@opentelemetry/instrumentation-amqplib@0.55.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -9629,6 +9759,23 @@ snapshots: transitivePeerDependencies: - supports-color + '@opentelemetry/otlp-exporter-base@0.208.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0) + '@opentelemetry/otlp-transformer': 0.208.0(@opentelemetry/api@1.9.0) + + '@opentelemetry/otlp-transformer@0.208.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/api-logs': 0.208.0 + '@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 2.2.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-logs': 0.208.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-metrics': 2.2.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-trace-base': 2.2.0(@opentelemetry/api@1.9.0) + protobufjs: 7.5.4 + '@opentelemetry/redis-common@0.38.2': {} '@opentelemetry/resources@2.2.0(@opentelemetry/api@1.9.0)': @@ -9637,6 +9784,19 @@ snapshots: '@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0) '@opentelemetry/semantic-conventions': 1.38.0 + '@opentelemetry/sdk-logs@0.208.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/api-logs': 0.208.0 + '@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 2.2.0(@opentelemetry/api@1.9.0) + + '@opentelemetry/sdk-metrics@2.2.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/core': 2.2.0(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 2.2.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -9801,6 +9961,19 @@ snapshots: type-fest: 4.41.0 webpack-hot-middleware: 2.26.1 + '@posthog/core@1.13.0': + dependencies: + cross-spawn: 7.0.6 + + '@posthog/react@1.7.0(@types/react@18.3.17)(posthog-js@1.334.1)(react@18.3.1)': + dependencies: + posthog-js: 1.334.1 + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.17 + + '@posthog/types@1.334.1': {} + '@prisma/instrumentation@6.19.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -9808,6 +9981,29 @@ snapshots: transitivePeerDependencies: - supports-color + '@protobufjs/aspromise@1.1.2': {} + + '@protobufjs/base64@1.1.2': {} + + '@protobufjs/codegen@2.0.4': {} + + '@protobufjs/eventemitter@1.1.0': {} + + '@protobufjs/fetch@1.1.0': + dependencies: + '@protobufjs/aspromise': 1.1.2 + '@protobufjs/inquire': 1.1.0 + + '@protobufjs/float@1.0.2': {} + + '@protobufjs/inquire@1.1.0': {} + + '@protobufjs/path@1.1.2': {} + + '@protobufjs/pool@1.1.0': {} + + '@protobufjs/utf8@1.1.0': {} + '@radix-ui/number@1.1.1': {} '@radix-ui/primitive@1.1.3': {} @@ -11426,6 +11622,9 @@ snapshots: dependencies: '@types/node': 24.10.0 + '@types/trusted-types@2.0.7': + optional: true + '@types/unist@2.0.11': {} '@types/unist@3.0.3': {} @@ -12327,6 +12526,8 @@ snapshots: core-js-pure@3.47.0: {} + core-js@3.48.0: {} + core-util-is@1.0.3: {} cosmiconfig@7.1.0: @@ -12636,6 +12837,10 @@ snapshots: dependencies: domelementtype: 2.3.0 + dompurify@3.3.1: + optionalDependencies: + '@types/trusted-types': 2.0.7 + domutils@2.8.0: dependencies: dom-serializer: 1.4.1 @@ -13205,6 +13410,8 @@ snapshots: optionalDependencies: picomatch: 4.0.3 + fflate@0.4.8: {} + file-entry-cache@6.0.1: dependencies: flat-cache: 3.2.0 @@ -14092,6 +14299,8 @@ snapshots: loglevel@1.9.2: {} + long@5.3.2: {} + longest-streak@3.1.0: {} loose-envify@1.4.0: @@ -15154,6 +15363,24 @@ snapshots: dependencies: xtend: 4.0.2 + posthog-js@1.334.1: + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/api-logs': 0.208.0 + '@opentelemetry/exporter-logs-otlp-http': 0.208.0(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 2.2.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-logs': 0.208.0(@opentelemetry/api@1.9.0) + '@posthog/core': 1.13.0 + '@posthog/types': 1.334.1 + core-js: 3.48.0 + dompurify: 3.3.1 + fflate: 0.4.8 + preact: 10.28.2 + query-selector-shadow-dom: 1.0.1 + web-vitals: 5.1.0 + + preact@10.28.2: {} + prelude-ls@1.2.1: {} prettier-plugin-tailwindcss@0.7.1(prettier@3.6.2): @@ -15187,6 +15414,21 @@ snapshots: property-information@7.1.0: {} + protobufjs@7.5.4: + dependencies: + '@protobufjs/aspromise': 1.1.2 + '@protobufjs/base64': 1.1.2 + '@protobufjs/codegen': 2.0.4 + '@protobufjs/eventemitter': 1.1.0 + '@protobufjs/fetch': 1.1.0 + '@protobufjs/float': 1.0.2 + '@protobufjs/inquire': 1.1.0 + '@protobufjs/path': 1.1.2 + '@protobufjs/pool': 1.1.0 + '@protobufjs/utf8': 1.1.0 + '@types/node': 24.10.0 + long: 5.3.2 + proxy-from-env@1.1.0: {} public-encrypt@4.0.3: @@ -15208,6 +15450,8 @@ snapshots: dependencies: side-channel: 1.1.0 + query-selector-shadow-dom@1.0.1: {} + querystring-es3@0.2.1: {} queue-microtask@1.2.3: {} @@ -16619,6 +16863,8 @@ snapshots: web-namespaces@2.0.1: {} + web-vitals@5.1.0: {} + webidl-conversions@3.0.1: {} webidl-conversions@8.0.1: diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx index 1ebfe6b87b..70d9783ccd 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx +++ b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx @@ -2,8 +2,9 @@ import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { useRouter } from "next/navigation"; import { useEffect } from "react"; -import { resolveResponse, shouldShowOnboarding } from "@/app/api/helpers"; +import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers"; import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding"; +import { getHomepageRoute } from "@/lib/constants"; export default function OnboardingPage() { const router = useRouter(); @@ -11,10 +12,13 @@ export default function OnboardingPage() { useEffect(() => { async function redirectToStep() { try { - // Check if onboarding is enabled - const isEnabled = await shouldShowOnboarding(); - if (!isEnabled) { - router.replace("/"); + // Check if onboarding is enabled (also gets chat flag for redirect) + const { shouldShowOnboarding, isChatEnabled } = + await getOnboardingStatus(); + const homepageRoute = getHomepageRoute(isChatEnabled); + + if (!shouldShowOnboarding) { + router.replace(homepageRoute); return; } @@ -22,7 +26,7 @@ export default function OnboardingPage() { // Handle completed onboarding if (onboarding.completedSteps.includes("GET_RESULTS")) { - router.replace("/"); + router.replace(homepageRoute); return; } diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts index a6a07a703f..15be137f63 100644 --- a/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts +++ b/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts @@ -1,8 +1,9 @@ import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase"; +import { getHomepageRoute } from "@/lib/constants"; import BackendAPI from "@/lib/autogpt-server-api"; import { NextResponse } from "next/server"; import { revalidatePath } from "next/cache"; -import { shouldShowOnboarding } from "@/app/api/helpers"; +import { getOnboardingStatus } from "@/app/api/helpers"; // Handle the callback to complete the user session login export async function GET(request: Request) { @@ -25,11 +26,15 @@ export async function GET(request: Request) { const api = new BackendAPI(); await api.createUser(); - if (await shouldShowOnboarding()) { + // Get onboarding status from backend (includes chat flag evaluated for this user) + const { shouldShowOnboarding, isChatEnabled } = + await getOnboardingStatus(); + if (shouldShowOnboarding) { next = "/onboarding"; revalidatePath("/onboarding", "layout"); } else { - revalidatePath("/", "layout"); + next = getHomepageRoute(isChatEnabled); + revalidatePath(next, "layout"); } } catch (createUserError) { console.error("Error creating user:", createUserError); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx index cfea5d9452..8ec1ba8be3 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/AgentOutputs/AgentOutputs.tsx @@ -38,8 +38,12 @@ export const AgentOutputs = ({ flowID }: { flowID: string | null }) => { return outputNodes .map((node) => { - const executionResult = node.data.nodeExecutionResult; - const outputData = executionResult?.output_data?.output; + const executionResults = node.data.nodeExecutionResults || []; + const latestResult = + executionResults.length > 0 + ? executionResults[executionResults.length - 1] + : undefined; + const outputData = latestResult?.output_data?.output; const renderer = globalRegistry.getRenderer(outputData); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts index 0eba6e8188..629d4662a9 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/BuilderActions/components/RunInputDialog/useRunInputDialog.ts @@ -153,6 +153,9 @@ export const useRunInputDialog = ({ Object.entries(credentialValues).filter(([_, cred]) => cred && cred.id), ); + useNodeStore.getState().clearAllNodeExecutionResults(); + useNodeStore.getState().cleanNodesStatuses(); + await executeGraph({ graphId: flowID ?? "", graphVersion: flowVersion || null, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FloatingSafeModeToogle.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FloatingSafeModeToogle.tsx index 6c8cbb1a86..227d892fff 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FloatingSafeModeToogle.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FloatingSafeModeToogle.tsx @@ -86,7 +86,6 @@ export function FloatingSafeModeToggle({ const { currentHITLSafeMode, showHITLToggle, - isHITLStateUndetermined, handleHITLToggle, currentSensitiveActionSafeMode, showSensitiveActionToggle, @@ -99,16 +98,9 @@ export function FloatingSafeModeToggle({ return null; } - const showHITL = showHITLToggle && !isHITLStateUndetermined; - const showSensitive = showSensitiveActionToggle; - - if (!showHITL && !showSensitive) { - return null; - } - return (
- {showHITL && ( + {showHITLToggle && ( )} - {showSensitive && ( + {showSensitiveActionToggle && ( > = React.memo( (value) => value !== null && value !== undefined && value !== "", ); - const outputData = data.nodeExecutionResult?.output_data; + const latestResult = + data.nodeExecutionResults && data.nodeExecutionResults.length > 0 + ? data.nodeExecutionResults[data.nodeExecutionResults.length - 1] + : undefined; + const outputData = latestResult?.output_data; const hasOutputError = typeof outputData === "object" && outputData !== null && diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/NodeOutput.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/NodeOutput.tsx index 17134ae299..c5df24e0e6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/NodeOutput.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/NodeOutput.tsx @@ -14,10 +14,15 @@ import { useNodeOutput } from "./useNodeOutput"; import { ViewMoreData } from "./components/ViewMoreData"; export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => { - const { outputData, copiedKey, handleCopy, executionResultId, inputData } = - useNodeOutput(nodeId); + const { + latestOutputData, + copiedKey, + handleCopy, + executionResultId, + latestInputData, + } = useNodeOutput(nodeId); - if (Object.keys(outputData).length === 0) { + if (Object.keys(latestOutputData).length === 0) { return null; } @@ -41,18 +46,19 @@ export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
Input - +
- {Object.entries(outputData) + {Object.entries(latestOutputData) .slice(0, 2) - .map(([key, value]) => ( -
-
- - Pin: - - - {beautifyString(key)} - -
-
- - Data: - -
- {value.map((item, index) => ( -
- -
- ))} + .map(([key, value]) => { + return ( +
+
+ + Pin: + + + {beautifyString(key)} + +
+
+ + Data: + +
+ {value.map((item, index) => ( +
+ +
+ ))} -
- - +
+ + +
-
- ))} + ); + })}
- - {Object.keys(outputData).length > 2 && ( - - )} + diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx index 0858db8f0e..680b6bc44a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/NodeDataViewer.tsx @@ -19,22 +19,51 @@ import { CopyIcon, DownloadIcon, } from "@phosphor-icons/react"; -import { FC } from "react"; +import React, { FC } from "react"; import { useNodeDataViewer } from "./useNodeDataViewer"; +import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore"; +import { useShallow } from "zustand/react/shallow"; +import { NodeDataType } from "../../helpers"; -interface NodeDataViewerProps { - data: any; +export interface NodeDataViewerProps { + data?: any; pinName: string; + nodeId?: string; execId?: string; isViewMoreData?: boolean; + dataType?: NodeDataType; } export const NodeDataViewer: FC = ({ data, pinName, + nodeId, execId = "N/A", isViewMoreData = false, + dataType = "output", }) => { + const executionResults = useNodeStore( + useShallow((state) => + nodeId ? state.getNodeExecutionResults(nodeId) : [], + ), + ); + const latestInputData = useNodeStore( + useShallow((state) => + nodeId ? state.getLatestNodeInputData(nodeId) : undefined, + ), + ); + const accumulatedOutputData = useNodeStore( + useShallow((state) => + nodeId ? state.getAccumulatedNodeOutputData(nodeId) : {}, + ), + ); + + const resolvedData = + data ?? + (dataType === "input" + ? (latestInputData ?? {}) + : (accumulatedOutputData[pinName] ?? [])); + const { outputItems, copyExecutionId, @@ -42,7 +71,20 @@ export const NodeDataViewer: FC = ({ handleDownloadItem, dataArray, copiedIndex, - } = useNodeDataViewer(data, pinName, execId); + groupedExecutions, + totalGroupedItems, + handleCopyGroupedItem, + handleDownloadGroupedItem, + copiedKey, + } = useNodeDataViewer( + resolvedData, + pinName, + execId, + executionResults, + dataType, + ); + + const shouldGroupExecutions = groupedExecutions.length > 0; return ( @@ -68,44 +110,141 @@ export const NodeDataViewer: FC = ({
- Full Output Preview + Full {dataType === "input" ? "Input" : "Output"} Preview
- {dataArray.length} item{dataArray.length !== 1 ? "s" : ""} total + {shouldGroupExecutions ? totalGroupedItems : dataArray.length}{" "} + item + {shouldGroupExecutions + ? totalGroupedItems !== 1 + ? "s" + : "" + : dataArray.length !== 1 + ? "s" + : ""}{" "} + total
-
- - Execution ID: - - - {execId} - - -
-
- Pin:{" "} - {beautifyString(pinName)} -
+ {shouldGroupExecutions ? ( +
+ Pin:{" "} + {beautifyString(pinName)} +
+ ) : ( + <> +
+ + Execution ID: + + + {execId} + + +
+
+ Pin:{" "} + + {beautifyString(pinName)} + +
+ + )}
- {dataArray.length > 0 ? ( + {shouldGroupExecutions ? ( +
+ {groupedExecutions.map((execution) => ( +
+
+ + Execution ID: + + + {execution.execId} + +
+
+ {execution.outputItems.length > 0 ? ( + execution.outputItems.map((item, index) => ( +
+
+ +
+ +
+ + +
+
+ )) + ) : ( +
+ No data available +
+ )} +
+
+ ))} +
+ ) : dataArray.length > 0 ? (
{outputItems.map((item, index) => (
diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/useNodeDataViewer.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/useNodeDataViewer.ts index d3c555970c..818d1266c1 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/useNodeDataViewer.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/NodeDataViewer/useNodeDataViewer.ts @@ -1,82 +1,70 @@ -import type { OutputMetadata } from "@/components/contextual/OutputRenderers"; -import { globalRegistry } from "@/components/contextual/OutputRenderers"; import { downloadOutputs } from "@/components/contextual/OutputRenderers/utils/download"; import { useToast } from "@/components/molecules/Toast/use-toast"; import { beautifyString } from "@/lib/utils"; -import React, { useMemo, useState } from "react"; +import { useState } from "react"; +import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult"; +import { + NodeDataType, + createOutputItems, + getExecutionData, + normalizeToArray, + type OutputItem, +} from "../../helpers"; + +export type GroupedExecution = { + execId: string; + outputItems: Array; +}; export const useNodeDataViewer = ( data: any, pinName: string, execId: string, + executionResults?: NodeExecutionResult[], + dataType?: NodeDataType, ) => { const { toast } = useToast(); const [copiedIndex, setCopiedIndex] = useState(null); + const [copiedKey, setCopiedKey] = useState(null); - // Normalize data to array format - const dataArray = useMemo(() => { - return Array.isArray(data) ? data : [data]; - }, [data]); + const dataArray = Array.isArray(data) ? data : [data]; - // Prepare items for the enhanced renderer system - const outputItems = useMemo(() => { - if (!dataArray) return []; - - const items: Array<{ - key: string; - label: string; - value: unknown; - metadata?: OutputMetadata; - renderer: any; - }> = []; - - dataArray.forEach((value, index) => { - const metadata: OutputMetadata = {}; - - // Extract metadata from the value if it's an object - if ( - typeof value === "object" && - value !== null && - !React.isValidElement(value) - ) { - const objValue = value as any; - if (objValue.type) metadata.type = objValue.type; - if (objValue.mimeType) metadata.mimeType = objValue.mimeType; - if (objValue.filename) metadata.filename = objValue.filename; - if (objValue.language) metadata.language = objValue.language; - } - - const renderer = globalRegistry.getRenderer(value, metadata); - if (renderer) { - items.push({ - key: `item-${index}`, + const outputItems = + !dataArray || dataArray.length === 0 + ? [] + : createOutputItems(dataArray).map((item, index) => ({ + ...item, label: index === 0 ? beautifyString(pinName) : "", - value, - metadata, - renderer, - }); - } else { - // Fallback to text renderer - const textRenderer = globalRegistry - .getAllRenderers() - .find((r) => r.name === "TextRenderer"); - if (textRenderer) { - items.push({ - key: `item-${index}`, - label: index === 0 ? beautifyString(pinName) : "", - value: - typeof value === "string" - ? value - : JSON.stringify(value, null, 2), - metadata, - renderer: textRenderer, - }); - } - } - }); + })); - return items; - }, [dataArray, pinName]); + const groupedExecutions = + !executionResults || executionResults.length === 0 + ? [] + : [...executionResults].reverse().map((result) => { + const rawData = getExecutionData( + result, + dataType || "output", + pinName, + ); + let dataArray: unknown[]; + if (dataType === "input") { + dataArray = + rawData !== undefined && rawData !== null ? [rawData] : []; + } else { + dataArray = normalizeToArray(rawData); + } + + const outputItems = createOutputItems(dataArray); + return { + execId: result.node_exec_id, + outputItems, + }; + }); + + const totalGroupedItems = groupedExecutions.reduce( + (total, execution) => total + execution.outputItems.length, + 0, + ); const copyExecutionId = () => { navigator.clipboard.writeText(execId).then(() => { @@ -122,6 +110,45 @@ export const useNodeDataViewer = ( ]); }; + const handleCopyGroupedItem = async ( + execId: string, + index: number, + item: OutputItem, + ) => { + const copyContent = item.renderer.getCopyContent(item.value, item.metadata); + + if (!copyContent) { + return; + } + + try { + let text: string; + if (typeof copyContent.data === "string") { + text = copyContent.data; + } else if (copyContent.fallbackText) { + text = copyContent.fallbackText; + } else { + return; + } + + await navigator.clipboard.writeText(text); + setCopiedKey(`${execId}-${index}`); + setTimeout(() => setCopiedKey(null), 2000); + } catch (error) { + console.error("Failed to copy:", error); + } + }; + + const handleDownloadGroupedItem = (item: OutputItem) => { + downloadOutputs([ + { + value: item.value, + metadata: item.metadata, + renderer: item.renderer, + }, + ]); + }; + return { outputItems, dataArray, @@ -129,5 +156,10 @@ export const useNodeDataViewer = ( handleCopyItem, handleDownloadItem, copiedIndex, + groupedExecutions, + totalGroupedItems, + handleCopyGroupedItem, + handleDownloadGroupedItem, + copiedKey, }; }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/ViewMoreData.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/ViewMoreData.tsx index 7bf026fe43..74d0da06c2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/ViewMoreData.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/components/ViewMoreData.tsx @@ -8,16 +8,28 @@ import { useState } from "react"; import { NodeDataViewer } from "./NodeDataViewer/NodeDataViewer"; import { useToast } from "@/components/molecules/Toast/use-toast"; import { CheckIcon, CopyIcon } from "@phosphor-icons/react"; +import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore"; +import { useShallow } from "zustand/react/shallow"; +import { + NodeDataType, + getExecutionEntries, + normalizeToArray, +} from "../helpers"; export const ViewMoreData = ({ - outputData, - execId, + nodeId, + dataType = "output", }: { - outputData: Record>; - execId?: string; + nodeId: string; + dataType?: NodeDataType; }) => { const [copiedKey, setCopiedKey] = useState(null); const { toast } = useToast(); + const executionResults = useNodeStore( + useShallow((state) => state.getNodeExecutionResults(nodeId)), + ); + + const reversedExecutionResults = [...executionResults].reverse(); const handleCopy = (key: string, value: any) => { const textToCopy = @@ -29,8 +41,8 @@ export const ViewMoreData = ({ setTimeout(() => setCopiedKey(null), 2000); }; - const copyExecutionId = () => { - navigator.clipboard.writeText(execId || "N/A").then(() => { + const copyExecutionId = (executionId: string) => { + navigator.clipboard.writeText(executionId || "N/A").then(() => { toast({ title: "Execution ID copied to clipboard!", duration: 2000, @@ -42,7 +54,7 @@ export const ViewMoreData = ({ -
-
- {Object.entries(outputData).map(([key, value]) => ( -
+ {reversedExecutionResults.map((result) => ( +
+ + Execution ID: + - Pin: - - - {beautifyString(key)} + {result.node_exec_id} +
-
- - Data: - -
- {value.map((item, index) => ( -
- -
- ))} -
- - -
-
+
+ {getExecutionEntries(result, dataType).map( + ([key, value]) => { + const normalizedValue = normalizeToArray(value); + return ( +
+
+ + Pin: + + + {beautifyString(key)} + +
+
+ + Data: + +
+ {normalizedValue.map((item, index) => ( +
+ +
+ ))} + +
+ + +
+
+
+
+ ); + }, + )}
))} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/helpers.ts new file mode 100644 index 0000000000..c75cd83cac --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/helpers.ts @@ -0,0 +1,83 @@ +import type { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult"; +import type { OutputMetadata } from "@/components/contextual/OutputRenderers"; +import { globalRegistry } from "@/components/contextual/OutputRenderers"; +import React from "react"; + +export type NodeDataType = "input" | "output"; + +export type OutputItem = { + key: string; + value: unknown; + metadata?: OutputMetadata; + renderer: any; +}; + +export const normalizeToArray = (value: unknown) => { + if (value === undefined) return []; + return Array.isArray(value) ? value : [value]; +}; + +export const getExecutionData = ( + result: NodeExecutionResult, + dataType: NodeDataType, + pinName: string, +) => { + if (dataType === "input") { + return result.input_data; + } + + return result.output_data?.[pinName]; +}; + +export const createOutputItems = (dataArray: unknown[]): Array => { + const items: Array = []; + + dataArray.forEach((value, index) => { + const metadata: OutputMetadata = {}; + + if ( + typeof value === "object" && + value !== null && + !React.isValidElement(value) + ) { + const objValue = value as any; + if (objValue.type) metadata.type = objValue.type; + if (objValue.mimeType) metadata.mimeType = objValue.mimeType; + if (objValue.filename) metadata.filename = objValue.filename; + if (objValue.language) metadata.language = objValue.language; + } + + const renderer = globalRegistry.getRenderer(value, metadata); + if (renderer) { + items.push({ + key: `item-${index}`, + value, + metadata, + renderer, + }); + } else { + const textRenderer = globalRegistry + .getAllRenderers() + .find((r) => r.name === "TextRenderer"); + if (textRenderer) { + items.push({ + key: `item-${index}`, + value: + typeof value === "string" ? value : JSON.stringify(value, null, 2), + metadata, + renderer: textRenderer, + }); + } + } + }); + + return items; +}; + +export const getExecutionEntries = ( + result: NodeExecutionResult, + dataType: NodeDataType, +) => { + const data = dataType === "input" ? result.input_data : result.output_data; + return Object.entries(data || {}); +}; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/useNodeOutput.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/useNodeOutput.tsx index cfc599c6e4..8ebf1dfaf3 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/useNodeOutput.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeOutput/useNodeOutput.tsx @@ -7,15 +7,18 @@ export const useNodeOutput = (nodeId: string) => { const [copiedKey, setCopiedKey] = useState(null); const { toast } = useToast(); - const nodeExecutionResult = useNodeStore( - useShallow((state) => state.getNodeExecutionResult(nodeId)), + const latestResult = useNodeStore( + useShallow((state) => state.getLatestNodeExecutionResult(nodeId)), ); - const inputData = nodeExecutionResult?.input_data; + const latestInputData = useNodeStore( + useShallow((state) => state.getLatestNodeInputData(nodeId)), + ); + + const latestOutputData: Record> = useNodeStore( + useShallow((state) => state.getLatestNodeOutputData(nodeId) || {}), + ); - const outputData: Record> = { - ...nodeExecutionResult?.output_data, - }; const handleCopy = async (key: string, value: any) => { try { const text = JSON.stringify(value, null, 2); @@ -35,11 +38,12 @@ export const useNodeOutput = (nodeId: string) => { }); } }; + return { - outputData, - inputData, + latestOutputData, + latestInputData, copiedKey, handleCopy, - executionResultId: nodeExecutionResult?.node_exec_id, + executionResultId: latestResult?.node_exec_id, }; }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/SubAgentUpdate/useSubAgentUpdateState.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/SubAgentUpdate/useSubAgentUpdateState.ts index d4ba538172..143cd58509 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/SubAgentUpdate/useSubAgentUpdateState.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/SubAgentUpdate/useSubAgentUpdateState.ts @@ -1,10 +1,7 @@ import { useState, useCallback, useEffect } from "react"; import { useShallow } from "zustand/react/shallow"; import { useGraphStore } from "@/app/(platform)/build/stores/graphStore"; -import { - useNodeStore, - NodeResolutionData, -} from "@/app/(platform)/build/stores/nodeStore"; +import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore"; import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore"; import { useSubAgentUpdate, @@ -13,6 +10,7 @@ import { } from "@/app/(platform)/build/hooks/useSubAgentUpdate"; import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api"; import { CustomNodeData } from "../../CustomNode"; +import { NodeResolutionData } from "@/app/(platform)/build/stores/types"; // Stable empty set to avoid creating new references in selectors const EMPTY_SET: Set = new Set(); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/helpers.ts index 54ddf2a61d..50326a03e6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/helpers.ts @@ -1,5 +1,5 @@ import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus"; -import { NodeResolutionData } from "@/app/(platform)/build/stores/nodeStore"; +import { NodeResolutionData } from "@/app/(platform)/build/stores/types"; import { RJSFSchema } from "@rjsf/utils"; export const nodeStyleBasedOnStatus: Record = { diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/helpers.ts new file mode 100644 index 0000000000..bcdfd4c313 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/helpers.ts @@ -0,0 +1,16 @@ +export const accumulateExecutionData = ( + accumulated: Record, + data: Record | undefined, +) => { + if (!data) return { ...accumulated }; + const next = { ...accumulated }; + Object.entries(data).forEach(([key, values]) => { + const nextValues = Array.isArray(values) ? values : [values]; + if (next[key]) { + next[key] = [...next[key], ...nextValues]; + } else { + next[key] = [...nextValues]; + } + }); + return next; +}; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts index 5502a8780d..f7a52636f3 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/nodeStore.ts @@ -10,6 +10,8 @@ import { import { Node } from "@/app/api/__generated__/models/node"; import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus"; import { NodeExecutionResult } from "@/app/api/__generated__/models/nodeExecutionResult"; +import { NodeExecutionResultInputData } from "@/app/api/__generated__/models/nodeExecutionResultInputData"; +import { NodeExecutionResultOutputData } from "@/app/api/__generated__/models/nodeExecutionResultOutputData"; import { useHistoryStore } from "./historyStore"; import { useEdgeStore } from "./edgeStore"; import { BlockUIType } from "../components/types"; @@ -18,31 +20,10 @@ import { ensurePathExists, parseHandleIdToPath, } from "@/components/renderers/InputRenderer/helpers"; -import { IncompatibilityInfo } from "../hooks/useSubAgentUpdate/types"; +import { accumulateExecutionData } from "./helpers"; +import { NodeResolutionData } from "./types"; -// Resolution mode data stored per node -export type NodeResolutionData = { - incompatibilities: IncompatibilityInfo; - // The NEW schema from the update (what we're updating TO) - pendingUpdate: { - input_schema: Record; - output_schema: Record; - }; - // The OLD schema before the update (what we're updating FROM) - // Needed to merge and show removed inputs during resolution - currentSchema: { - input_schema: Record; - output_schema: Record; - }; - // The full updated hardcoded values to apply when resolution completes - pendingHardcodedValues: Record; -}; - -// Minimum movement (in pixels) required before logging position change to history -// Prevents spamming history with small movements when clicking on inputs inside blocks const MINIMUM_MOVE_BEFORE_LOG = 50; - -// Track initial positions when drag starts (outside store to avoid re-renders) const dragStartPositions: Record = {}; let dragStartState: { nodes: CustomNode[]; edges: CustomEdge[] } | null = null; @@ -52,6 +33,15 @@ type NodeStore = { nodeCounter: number; setNodeCounter: (nodeCounter: number) => void; nodeAdvancedStates: Record; + + latestNodeInputData: Record; + latestNodeOutputData: Record< + string, + NodeExecutionResultOutputData | undefined + >; + accumulatedNodeInputData: Record>; + accumulatedNodeOutputData: Record>; + setNodes: (nodes: CustomNode[]) => void; onNodesChange: (changes: NodeChange[]) => void; addNode: (node: CustomNode) => void; @@ -72,12 +62,26 @@ type NodeStore = { updateNodeStatus: (nodeId: string, status: AgentExecutionStatus) => void; getNodeStatus: (nodeId: string) => AgentExecutionStatus | undefined; + cleanNodesStatuses: () => void; updateNodeExecutionResult: ( nodeId: string, result: NodeExecutionResult, ) => void; - getNodeExecutionResult: (nodeId: string) => NodeExecutionResult | undefined; + getNodeExecutionResults: (nodeId: string) => NodeExecutionResult[]; + getLatestNodeInputData: ( + nodeId: string, + ) => NodeExecutionResultInputData | undefined; + getLatestNodeOutputData: ( + nodeId: string, + ) => NodeExecutionResultOutputData | undefined; + getAccumulatedNodeInputData: (nodeId: string) => Record; + getAccumulatedNodeOutputData: (nodeId: string) => Record; + getLatestNodeExecutionResult: ( + nodeId: string, + ) => NodeExecutionResult | undefined; + clearAllNodeExecutionResults: () => void; + getNodeBlockUIType: (nodeId: string) => BlockUIType; hasWebhookNodes: () => boolean; @@ -122,6 +126,10 @@ export const useNodeStore = create((set, get) => ({ nodeCounter: 0, setNodeCounter: (nodeCounter) => set({ nodeCounter }), nodeAdvancedStates: {}, + latestNodeInputData: {}, + latestNodeOutputData: {}, + accumulatedNodeInputData: {}, + accumulatedNodeOutputData: {}, incrementNodeCounter: () => set((state) => ({ nodeCounter: state.nodeCounter + 1, @@ -317,17 +325,162 @@ export const useNodeStore = create((set, get) => ({ return get().nodes.find((n) => n.id === nodeId)?.data?.status; }, - updateNodeExecutionResult: (nodeId: string, result: NodeExecutionResult) => { + cleanNodesStatuses: () => { set((state) => ({ - nodes: state.nodes.map((n) => - n.id === nodeId - ? { ...n, data: { ...n.data, nodeExecutionResult: result } } - : n, - ), + nodes: state.nodes.map((n) => ({ + ...n, + data: { ...n.data, status: undefined }, + })), })); }, - getNodeExecutionResult: (nodeId: string) => { - return get().nodes.find((n) => n.id === nodeId)?.data?.nodeExecutionResult; + + updateNodeExecutionResult: (nodeId: string, result: NodeExecutionResult) => { + set((state) => { + let latestNodeInputData = state.latestNodeInputData; + let latestNodeOutputData = state.latestNodeOutputData; + let accumulatedNodeInputData = state.accumulatedNodeInputData; + let accumulatedNodeOutputData = state.accumulatedNodeOutputData; + + const nodes = state.nodes.map((n) => { + if (n.id !== nodeId) return n; + + const existingResults = n.data.nodeExecutionResults || []; + const duplicateIndex = existingResults.findIndex( + (r) => r.node_exec_id === result.node_exec_id, + ); + + if (duplicateIndex !== -1) { + const oldResult = existingResults[duplicateIndex]; + const inputDataChanged = + JSON.stringify(oldResult.input_data) !== + JSON.stringify(result.input_data); + const outputDataChanged = + JSON.stringify(oldResult.output_data) !== + JSON.stringify(result.output_data); + + if (!inputDataChanged && !outputDataChanged) { + return n; + } + + const updatedResults = [...existingResults]; + updatedResults[duplicateIndex] = result; + + const recomputedAccumulatedInput = updatedResults.reduce( + (acc, r) => accumulateExecutionData(acc, r.input_data), + {} as Record, + ); + const recomputedAccumulatedOutput = updatedResults.reduce( + (acc, r) => accumulateExecutionData(acc, r.output_data), + {} as Record, + ); + + const mostRecentResult = updatedResults[updatedResults.length - 1]; + latestNodeInputData = { + ...latestNodeInputData, + [nodeId]: mostRecentResult.input_data, + }; + latestNodeOutputData = { + ...latestNodeOutputData, + [nodeId]: mostRecentResult.output_data, + }; + + accumulatedNodeInputData = { + ...accumulatedNodeInputData, + [nodeId]: recomputedAccumulatedInput, + }; + accumulatedNodeOutputData = { + ...accumulatedNodeOutputData, + [nodeId]: recomputedAccumulatedOutput, + }; + + return { + ...n, + data: { + ...n.data, + nodeExecutionResults: updatedResults, + }, + }; + } + + accumulatedNodeInputData = { + ...accumulatedNodeInputData, + [nodeId]: accumulateExecutionData( + accumulatedNodeInputData[nodeId] || {}, + result.input_data, + ), + }; + accumulatedNodeOutputData = { + ...accumulatedNodeOutputData, + [nodeId]: accumulateExecutionData( + accumulatedNodeOutputData[nodeId] || {}, + result.output_data, + ), + }; + + latestNodeInputData = { + ...latestNodeInputData, + [nodeId]: result.input_data, + }; + latestNodeOutputData = { + ...latestNodeOutputData, + [nodeId]: result.output_data, + }; + + return { + ...n, + data: { + ...n.data, + nodeExecutionResults: [...existingResults, result], + }, + }; + }); + + return { + nodes, + latestNodeInputData, + latestNodeOutputData, + accumulatedNodeInputData, + accumulatedNodeOutputData, + }; + }); + }, + getNodeExecutionResults: (nodeId: string) => { + return ( + get().nodes.find((n) => n.id === nodeId)?.data?.nodeExecutionResults || [] + ); + }, + getLatestNodeInputData: (nodeId: string) => { + return get().latestNodeInputData[nodeId]; + }, + getLatestNodeOutputData: (nodeId: string) => { + return get().latestNodeOutputData[nodeId]; + }, + getAccumulatedNodeInputData: (nodeId: string) => { + return get().accumulatedNodeInputData[nodeId] || {}; + }, + getAccumulatedNodeOutputData: (nodeId: string) => { + return get().accumulatedNodeOutputData[nodeId] || {}; + }, + getLatestNodeExecutionResult: (nodeId: string) => { + const results = + get().nodes.find((n) => n.id === nodeId)?.data?.nodeExecutionResults || + []; + return results.length > 0 ? results[results.length - 1] : undefined; + }, + clearAllNodeExecutionResults: () => { + set((state) => ({ + nodes: state.nodes.map((n) => ({ + ...n, + data: { + ...n.data, + nodeExecutionResults: [], + }, + })), + latestNodeInputData: {}, + latestNodeOutputData: {}, + accumulatedNodeInputData: {}, + accumulatedNodeOutputData: {}, + })); }, getNodeBlockUIType: (nodeId: string) => { return ( diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/types.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/types.ts new file mode 100644 index 0000000000..f0ec7e6c1c --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/types.ts @@ -0,0 +1,14 @@ +import { IncompatibilityInfo } from "../hooks/useSubAgentUpdate/types"; + +export type NodeResolutionData = { + incompatibilities: IncompatibilityInfo; + pendingUpdate: { + input_schema: Record; + output_schema: Record; + }; + currentSchema: { + input_schema: Record; + output_schema: Record; + }; + pendingHardcodedValues: Record; +}; diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/NewChatContext.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/NewChatContext.tsx deleted file mode 100644 index 0826637043..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/NewChatContext.tsx +++ /dev/null @@ -1,41 +0,0 @@ -"use client"; - -import { createContext, useContext, useRef, type ReactNode } from "react"; - -interface NewChatContextValue { - onNewChatClick: () => void; - setOnNewChatClick: (handler?: () => void) => void; - performNewChat?: () => void; - setPerformNewChat: (handler?: () => void) => void; -} - -const NewChatContext = createContext(null); - -export function NewChatProvider({ children }: { children: ReactNode }) { - const onNewChatRef = useRef<(() => void) | undefined>(); - const performNewChatRef = useRef<(() => void) | undefined>(); - const contextValueRef = useRef({ - onNewChatClick() { - onNewChatRef.current?.(); - }, - setOnNewChatClick(handler?: () => void) { - onNewChatRef.current = handler; - }, - performNewChat() { - performNewChatRef.current?.(); - }, - setPerformNewChat(handler?: () => void) { - performNewChatRef.current = handler; - }, - }); - - return ( - - {children} - - ); -} - -export function useNewChat() { - return useContext(NewChatContext); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/CopilotShell.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/CopilotShell.tsx index 44e32024a8..3f695da5ed 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/CopilotShell.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/CopilotShell.tsx @@ -1,12 +1,10 @@ "use client"; import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader"; +import { Text } from "@/components/atoms/Text/Text"; import { NAVBAR_HEIGHT_PX } from "@/lib/constants"; import type { ReactNode } from "react"; -import { useEffect } from "react"; -import { useNewChat } from "../../NewChatContext"; import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar"; -import { LoadingState } from "./components/LoadingState/LoadingState"; import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer"; import { MobileHeader } from "./components/MobileHeader/MobileHeader"; import { useCopilotShell } from "./useCopilotShell"; @@ -20,36 +18,21 @@ export function CopilotShell({ children }: Props) { isMobile, isDrawerOpen, isLoading, + isCreatingSession, isLoggedIn, hasActiveSession, sessions, currentSessionId, - handleSelectSession, handleOpenDrawer, handleCloseDrawer, handleDrawerOpenChange, - handleNewChat, + handleNewChatClick, + handleSessionClick, hasNextPage, isFetchingNextPage, fetchNextPage, - isReadyToShowContent, } = useCopilotShell(); - const newChatContext = useNewChat(); - const handleNewChatClickWrapper = - newChatContext?.onNewChatClick || handleNewChat; - - useEffect( - function registerNewChatHandler() { - if (!newChatContext) return; - newChatContext.setPerformNewChat(handleNewChat); - return function cleanup() { - newChatContext.setPerformNewChat(undefined); - }; - }, - [newChatContext, handleNewChat], - ); - if (!isLoggedIn) { return (
@@ -70,9 +53,9 @@ export function CopilotShell({ children }: Props) { isLoading={isLoading} hasNextPage={hasNextPage} isFetchingNextPage={isFetchingNextPage} - onSelectSession={handleSelectSession} + onSelectSession={handleSessionClick} onFetchNextPage={fetchNextPage} - onNewChat={handleNewChatClickWrapper} + onNewChat={handleNewChatClick} hasActiveSession={Boolean(hasActiveSession)} /> )} @@ -80,7 +63,18 @@ export function CopilotShell({ children }: Props) {
{isMobile && }
- {isReadyToShowContent ? children : } + {isCreatingSession ? ( +
+
+ + + Creating your chat... + +
+
+ ) : ( + children + )}
@@ -92,9 +86,9 @@ export function CopilotShell({ children }: Props) { isLoading={isLoading} hasNextPage={hasNextPage} isFetchingNextPage={isFetchingNextPage} - onSelectSession={handleSelectSession} + onSelectSession={handleSessionClick} onFetchNextPage={fetchNextPage} - onNewChat={handleNewChatClickWrapper} + onNewChat={handleNewChatClick} onClose={handleCloseDrawer} onOpenChange={handleDrawerOpenChange} hasActiveSession={Boolean(hasActiveSession)} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/LoadingState/LoadingState.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/LoadingState/LoadingState.tsx deleted file mode 100644 index 21b1663916..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/LoadingState/LoadingState.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { Text } from "@/components/atoms/Text/Text"; -import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader"; - -export function LoadingState() { - return ( -
-
- - - Loading your chats... - -
-
- ); -} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/MobileDrawer/useMobileDrawer.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/MobileDrawer/useMobileDrawer.ts index c9504e49a9..2ef63a4422 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/MobileDrawer/useMobileDrawer.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/MobileDrawer/useMobileDrawer.ts @@ -3,17 +3,17 @@ import { useState } from "react"; export function useMobileDrawer() { const [isDrawerOpen, setIsDrawerOpen] = useState(false); - function handleOpenDrawer() { + const handleOpenDrawer = () => { setIsDrawerOpen(true); - } + }; - function handleCloseDrawer() { + const handleCloseDrawer = () => { setIsDrawerOpen(false); - } + }; - function handleDrawerOpenChange(open: boolean) { + const handleDrawerOpenChange = (open: boolean) => { setIsDrawerOpen(open); - } + }; return { isDrawerOpen, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/SessionsList/useSessionsPagination.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/SessionsList/useSessionsPagination.ts index 8833a419c1..61e3e6f37f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/SessionsList/useSessionsPagination.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/SessionsList/useSessionsPagination.ts @@ -1,7 +1,7 @@ import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat"; import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse"; import { okData } from "@/app/api/helpers"; -import { useEffect, useMemo, useState } from "react"; +import { useEffect, useState } from "react"; const PAGE_SIZE = 50; @@ -11,9 +11,11 @@ export interface UseSessionsPaginationArgs { export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) { const [offset, setOffset] = useState(0); + const [accumulatedSessions, setAccumulatedSessions] = useState< SessionSummaryResponse[] >([]); + const [totalCount, setTotalCount] = useState(null); const { data, isLoading, isFetching, isError } = useGetV2ListSessions( @@ -43,17 +45,14 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) { } }, [data, offset, enabled]); - const hasNextPage = useMemo(() => { - if (totalCount === null) return false; - return accumulatedSessions.length < totalCount; - }, [accumulatedSessions.length, totalCount]); + const hasNextPage = + totalCount !== null && accumulatedSessions.length < totalCount; - const areAllSessionsLoaded = useMemo(() => { - if (totalCount === null) return false; - return ( - accumulatedSessions.length >= totalCount && !isFetching && !isLoading - ); - }, [accumulatedSessions.length, totalCount, isFetching, isLoading]); + const areAllSessionsLoaded = + totalCount !== null && + accumulatedSessions.length >= totalCount && + !isFetching && + !isLoading; useEffect(() => { if ( @@ -67,17 +66,17 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) { } }, [hasNextPage, isFetching, isLoading, isError, totalCount]); - function fetchNextPage() { + const fetchNextPage = () => { if (hasNextPage && !isFetching) { setOffset((prev) => prev + PAGE_SIZE); } - } + }; - function reset() { + const reset = () => { + // Only reset the offset - keep existing sessions visible during refetch + // The effect will replace sessions when new data arrives at offset 0 setOffset(0); - setAccumulatedSessions([]); - setTotalCount(null); - } + }; return { sessions: accumulatedSessions, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/helpers.ts index bf4eb70ccb..ef0d414edf 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/helpers.ts @@ -2,9 +2,7 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse"; import { format, formatDistanceToNow, isToday } from "date-fns"; -export function convertSessionDetailToSummary( - session: SessionDetailResponse, -): SessionSummaryResponse { +export function convertSessionDetailToSummary(session: SessionDetailResponse) { return { id: session.id, created_at: session.created_at, @@ -13,17 +11,25 @@ export function convertSessionDetailToSummary( }; } -export function filterVisibleSessions( - sessions: SessionSummaryResponse[], -): SessionSummaryResponse[] { - return sessions.filter( - (session) => session.updated_at !== session.created_at, - ); +export function filterVisibleSessions(sessions: SessionSummaryResponse[]) { + const fiveMinutesAgo = Date.now() - 5 * 60 * 1000; + return sessions.filter((session) => { + const hasBeenUpdated = session.updated_at !== session.created_at; + + if (hasBeenUpdated) return true; + + const isRecentlyCreated = + new Date(session.created_at).getTime() > fiveMinutesAgo; + + return isRecentlyCreated; + }); } -export function getSessionTitle(session: SessionSummaryResponse): string { +export function getSessionTitle(session: SessionSummaryResponse) { if (session.title) return session.title; + const isNewSession = session.updated_at === session.created_at; + if (isNewSession) { const createdDate = new Date(session.created_at); if (isToday(createdDate)) { @@ -31,12 +37,11 @@ export function getSessionTitle(session: SessionSummaryResponse): string { } return format(createdDate, "MMM d, yyyy"); } + return "Untitled Chat"; } -export function getSessionUpdatedLabel( - session: SessionSummaryResponse, -): string { +export function getSessionUpdatedLabel(session: SessionSummaryResponse) { if (!session.updated_at) return ""; return formatDistanceToNow(new Date(session.updated_at), { addSuffix: true }); } @@ -45,8 +50,10 @@ export function mergeCurrentSessionIntoList( accumulatedSessions: SessionSummaryResponse[], currentSessionId: string | null, currentSessionData: SessionDetailResponse | null | undefined, -): SessionSummaryResponse[] { + recentlyCreatedSessions?: Map, +) { const filteredSessions: SessionSummaryResponse[] = []; + const addedIds = new Set(); if (accumulatedSessions.length > 0) { const visibleSessions = filterVisibleSessions(accumulatedSessions); @@ -61,105 +68,39 @@ export function mergeCurrentSessionIntoList( ); if (!isInVisible) { filteredSessions.push(currentInAll); + addedIds.add(currentInAll.id); } } } - filteredSessions.push(...visibleSessions); + for (const session of visibleSessions) { + if (!addedIds.has(session.id)) { + filteredSessions.push(session); + addedIds.add(session.id); + } + } } if (currentSessionId && currentSessionData) { - const isCurrentInList = filteredSessions.some( - (s) => s.id === currentSessionId, - ); - if (!isCurrentInList) { + if (!addedIds.has(currentSessionId)) { const summarySession = convertSessionDetailToSummary(currentSessionData); filteredSessions.unshift(summarySession); + addedIds.add(currentSessionId); + } + } + + if (recentlyCreatedSessions) { + for (const [sessionId, sessionData] of recentlyCreatedSessions) { + if (!addedIds.has(sessionId)) { + filteredSessions.unshift(sessionData); + addedIds.add(sessionId); + } } } return filteredSessions; } -export function getCurrentSessionId( - searchParams: URLSearchParams, -): string | null { +export function getCurrentSessionId(searchParams: URLSearchParams) { return searchParams.get("sessionId"); } - -export function shouldAutoSelectSession( - areAllSessionsLoaded: boolean, - hasAutoSelectedSession: boolean, - paramSessionId: string | null, - visibleSessions: SessionSummaryResponse[], - accumulatedSessions: SessionSummaryResponse[], - isLoading: boolean, - totalCount: number | null, -): { - shouldSelect: boolean; - sessionIdToSelect: string | null; - shouldCreate: boolean; -} { - if (!areAllSessionsLoaded || hasAutoSelectedSession) { - return { - shouldSelect: false, - sessionIdToSelect: null, - shouldCreate: false, - }; - } - - if (paramSessionId) { - return { - shouldSelect: false, - sessionIdToSelect: null, - shouldCreate: false, - }; - } - - if (visibleSessions.length > 0) { - return { - shouldSelect: true, - sessionIdToSelect: visibleSessions[0].id, - shouldCreate: false, - }; - } - - if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) { - return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true }; - } - - if (totalCount === 0) { - return { - shouldSelect: false, - sessionIdToSelect: null, - shouldCreate: false, - }; - } - - return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false }; -} - -export function checkReadyToShowContent( - areAllSessionsLoaded: boolean, - paramSessionId: string | null, - accumulatedSessions: SessionSummaryResponse[], - isCurrentSessionLoading: boolean, - currentSessionData: SessionDetailResponse | null | undefined, - hasAutoSelectedSession: boolean, -): boolean { - if (!areAllSessionsLoaded) return false; - - if (paramSessionId) { - const sessionFound = accumulatedSessions.some( - (s) => s.id === paramSessionId, - ); - return ( - sessionFound || - (!isCurrentSessionLoading && - currentSessionData !== undefined && - currentSessionData !== null) - ); - } - - return hasAutoSelectedSession; -} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts index cadd98da3e..74fd663ab2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts @@ -1,26 +1,24 @@ "use client"; import { + getGetV2GetSessionQueryKey, getGetV2ListSessionsQueryKey, useGetV2GetSession, } from "@/app/api/__generated__/endpoints/chat/chat"; import { okData } from "@/app/api/helpers"; +import { useChatStore } from "@/components/contextual/Chat/chat-store"; import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useQueryClient } from "@tanstack/react-query"; -import { usePathname, useRouter, useSearchParams } from "next/navigation"; -import { useEffect, useRef, useState } from "react"; +import { usePathname, useSearchParams } from "next/navigation"; +import { useRef } from "react"; +import { useCopilotStore } from "../../copilot-page-store"; +import { useCopilotSessionId } from "../../useCopilotSessionId"; import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer"; -import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination"; -import { - checkReadyToShowContent, - filterVisibleSessions, - getCurrentSessionId, - mergeCurrentSessionIntoList, -} from "./helpers"; +import { getCurrentSessionId } from "./helpers"; +import { useShellSessionList } from "./useShellSessionList"; export function useCopilotShell() { - const router = useRouter(); const pathname = usePathname(); const searchParams = useSearchParams(); const queryClient = useQueryClient(); @@ -29,6 +27,8 @@ export function useCopilotShell() { const isMobile = breakpoint === "base" || breakpoint === "sm" || breakpoint === "md"; + const { urlSessionId, setUrlSessionId } = useCopilotSessionId(); + const isOnHomepage = pathname === "/copilot"; const paramSessionId = searchParams.get("sessionId"); @@ -41,114 +41,113 @@ export function useCopilotShell() { const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId; - const { - sessions: accumulatedSessions, - isLoading: isSessionsLoading, - isFetching: isSessionsFetching, - hasNextPage, - areAllSessionsLoaded, - fetchNextPage, - reset: resetPagination, - } = useSessionsPagination({ - enabled: paginationEnabled, - }); - const currentSessionId = getCurrentSessionId(searchParams); - const { data: currentSessionData, isLoading: isCurrentSessionLoading } = - useGetV2GetSession(currentSessionId || "", { + const { data: currentSessionData } = useGetV2GetSession( + currentSessionId || "", + { query: { enabled: !!currentSessionId, select: okData, }, - }); - - const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false); - const hasAutoSelectedRef = useRef(false); - - // Mark as auto-selected when sessionId is in URL - useEffect(() => { - if (paramSessionId && !hasAutoSelectedRef.current) { - hasAutoSelectedRef.current = true; - setHasAutoSelectedSession(true); - } - }, [paramSessionId]); - - // On homepage without sessionId, mark as ready immediately - useEffect(() => { - if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) { - hasAutoSelectedRef.current = true; - setHasAutoSelectedSession(true); - } - }, [isOnHomepage, paramSessionId]); - - // Invalidate sessions list when navigating to homepage (to show newly created sessions) - useEffect(() => { - if (isOnHomepage && !paramSessionId) { - queryClient.invalidateQueries({ - queryKey: getGetV2ListSessionsQueryKey(), - }); - } - }, [isOnHomepage, paramSessionId, queryClient]); - - // Reset pagination when query becomes disabled - const prevPaginationEnabledRef = useRef(paginationEnabled); - useEffect(() => { - if (prevPaginationEnabledRef.current && !paginationEnabled) { - resetPagination(); - resetAutoSelect(); - } - prevPaginationEnabledRef.current = paginationEnabled; - }, [paginationEnabled, resetPagination]); - - const sessions = mergeCurrentSessionIntoList( - accumulatedSessions, - currentSessionId, - currentSessionData, + }, ); - const visibleSessions = filterVisibleSessions(sessions); + const { + sessions, + isLoading, + isSessionsFetching, + hasNextPage, + fetchNextPage, + resetPagination, + recentlyCreatedSessionsRef, + } = useShellSessionList({ + paginationEnabled, + currentSessionId, + currentSessionData, + isOnHomepage, + paramSessionId, + }); - const sidebarSelectedSessionId = - isOnHomepage && !paramSessionId ? null : currentSessionId; + const stopStream = useChatStore((s) => s.stopStream); + const onStreamComplete = useChatStore((s) => s.onStreamComplete); + const isStreaming = useCopilotStore((s) => s.isStreaming); + const isCreatingSession = useCopilotStore((s) => s.isCreatingSession); + const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession); + const openInterruptModal = useCopilotStore((s) => s.openInterruptModal); - const isReadyToShowContent = isOnHomepage - ? true - : checkReadyToShowContent( - areAllSessionsLoaded, - paramSessionId, - accumulatedSessions, - isCurrentSessionLoading, - currentSessionData, - hasAutoSelectedSession, - ); + const pendingActionRef = useRef<(() => void) | null>(null); - function handleSelectSession(sessionId: string) { - // Navigate using replaceState to avoid full page reload - window.history.replaceState(null, "", `/copilot?sessionId=${sessionId}`); - // Force a re-render by updating the URL through router - router.replace(`/copilot?sessionId=${sessionId}`); + async function stopCurrentStream() { + if (!currentSessionId) return; + + setIsSwitchingSession(true); + await new Promise((resolve) => { + const unsubscribe = onStreamComplete((completedId) => { + if (completedId === currentSessionId) { + clearTimeout(timeout); + unsubscribe(); + resolve(); + } + }); + const timeout = setTimeout(() => { + unsubscribe(); + resolve(); + }, 3000); + stopStream(currentSessionId); + }); + + queryClient.invalidateQueries({ + queryKey: getGetV2GetSessionQueryKey(currentSessionId), + }); + setIsSwitchingSession(false); + } + + function selectSession(sessionId: string) { + if (sessionId === currentSessionId) return; + if (recentlyCreatedSessionsRef.current.has(sessionId)) { + queryClient.invalidateQueries({ + queryKey: getGetV2GetSessionQueryKey(sessionId), + }); + } + setUrlSessionId(sessionId, { shallow: false }); if (isMobile) handleCloseDrawer(); } - function handleNewChat() { - resetAutoSelect(); + function startNewChat() { resetPagination(); - // Invalidate and refetch sessions list to ensure newly created sessions appear queryClient.invalidateQueries({ queryKey: getGetV2ListSessionsQueryKey(), }); - window.history.replaceState(null, "", "/copilot"); - router.replace("/copilot"); + setUrlSessionId(null, { shallow: false }); if (isMobile) handleCloseDrawer(); } - function resetAutoSelect() { - hasAutoSelectedRef.current = false; - setHasAutoSelectedSession(false); + function handleSessionClick(sessionId: string) { + if (sessionId === currentSessionId) return; + + if (isStreaming) { + pendingActionRef.current = async () => { + await stopCurrentStream(); + selectSession(sessionId); + }; + openInterruptModal(pendingActionRef.current); + } else { + selectSession(sessionId); + } } - const isLoading = isSessionsLoading && accumulatedSessions.length === 0; + function handleNewChatClick() { + if (isStreaming) { + pendingActionRef.current = async () => { + await stopCurrentStream(); + startNewChat(); + }; + openInterruptModal(pendingActionRef.current); + } else { + startNewChat(); + } + } return { isMobile, @@ -156,17 +155,17 @@ export function useCopilotShell() { isLoggedIn, hasActiveSession: Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)), - isLoading, - sessions: visibleSessions, - currentSessionId: sidebarSelectedSessionId, - handleSelectSession, + isLoading: isLoading || isCreatingSession, + isCreatingSession, + sessions, + currentSessionId: urlSessionId, handleOpenDrawer, handleCloseDrawer, handleDrawerOpenChange, - handleNewChat, + handleNewChatClick, + handleSessionClick, hasNextPage, isFetchingNextPage: isSessionsFetching, fetchNextPage, - isReadyToShowContent, }; } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useShellSessionList.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useShellSessionList.ts new file mode 100644 index 0000000000..fb39a11096 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useShellSessionList.ts @@ -0,0 +1,113 @@ +import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat"; +import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; +import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse"; +import { useChatStore } from "@/components/contextual/Chat/chat-store"; +import { useQueryClient } from "@tanstack/react-query"; +import { useEffect, useMemo, useRef } from "react"; +import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination"; +import { + convertSessionDetailToSummary, + filterVisibleSessions, + mergeCurrentSessionIntoList, +} from "./helpers"; + +interface UseShellSessionListArgs { + paginationEnabled: boolean; + currentSessionId: string | null; + currentSessionData: SessionDetailResponse | null | undefined; + isOnHomepage: boolean; + paramSessionId: string | null; +} + +export function useShellSessionList({ + paginationEnabled, + currentSessionId, + currentSessionData, + isOnHomepage, + paramSessionId, +}: UseShellSessionListArgs) { + const queryClient = useQueryClient(); + const onStreamComplete = useChatStore((s) => s.onStreamComplete); + + const { + sessions: accumulatedSessions, + isLoading: isSessionsLoading, + isFetching: isSessionsFetching, + hasNextPage, + fetchNextPage, + reset: resetPagination, + } = useSessionsPagination({ + enabled: paginationEnabled, + }); + + const recentlyCreatedSessionsRef = useRef< + Map + >(new Map()); + + useEffect(() => { + if (isOnHomepage && !paramSessionId) { + queryClient.invalidateQueries({ + queryKey: getGetV2ListSessionsQueryKey(), + }); + } + }, [isOnHomepage, paramSessionId, queryClient]); + + useEffect(() => { + if (currentSessionId && currentSessionData) { + const isNewSession = + currentSessionData.updated_at === currentSessionData.created_at; + const isNotInAccumulated = !accumulatedSessions.some( + (s) => s.id === currentSessionId, + ); + if (isNewSession || isNotInAccumulated) { + const summary = convertSessionDetailToSummary(currentSessionData); + recentlyCreatedSessionsRef.current.set(currentSessionId, summary); + } + } + }, [currentSessionId, currentSessionData, accumulatedSessions]); + + useEffect(() => { + for (const sessionId of recentlyCreatedSessionsRef.current.keys()) { + if (accumulatedSessions.some((s) => s.id === sessionId)) { + recentlyCreatedSessionsRef.current.delete(sessionId); + } + } + }, [accumulatedSessions]); + + useEffect(() => { + const unsubscribe = onStreamComplete(() => { + queryClient.invalidateQueries({ + queryKey: getGetV2ListSessionsQueryKey(), + }); + }); + return unsubscribe; + }, [onStreamComplete, queryClient]); + + const sessions = useMemo( + () => + mergeCurrentSessionIntoList( + accumulatedSessions, + currentSessionId, + currentSessionData, + recentlyCreatedSessionsRef.current, + ), + [accumulatedSessions, currentSessionId, currentSessionData], + ); + + const visibleSessions = useMemo( + () => filterVisibleSessions(sessions), + [sessions], + ); + + const isLoading = isSessionsLoading && accumulatedSessions.length === 0; + + return { + sessions: visibleSessions, + isLoading, + isSessionsFetching, + hasNextPage, + fetchNextPage, + resetPagination, + recentlyCreatedSessionsRef, + }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/copilot-page-store.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/copilot-page-store.ts new file mode 100644 index 0000000000..9fc97a14e3 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/copilot-page-store.ts @@ -0,0 +1,56 @@ +"use client"; + +import { create } from "zustand"; + +interface CopilotStoreState { + isStreaming: boolean; + isSwitchingSession: boolean; + isCreatingSession: boolean; + isInterruptModalOpen: boolean; + pendingAction: (() => void) | null; +} + +interface CopilotStoreActions { + setIsStreaming: (isStreaming: boolean) => void; + setIsSwitchingSession: (isSwitchingSession: boolean) => void; + setIsCreatingSession: (isCreating: boolean) => void; + openInterruptModal: (onConfirm: () => void) => void; + confirmInterrupt: () => void; + cancelInterrupt: () => void; +} + +type CopilotStore = CopilotStoreState & CopilotStoreActions; + +export const useCopilotStore = create((set, get) => ({ + isStreaming: false, + isSwitchingSession: false, + isCreatingSession: false, + isInterruptModalOpen: false, + pendingAction: null, + + setIsStreaming(isStreaming) { + set({ isStreaming }); + }, + + setIsSwitchingSession(isSwitchingSession) { + set({ isSwitchingSession }); + }, + + setIsCreatingSession(isCreatingSession) { + set({ isCreatingSession }); + }, + + openInterruptModal(onConfirm) { + set({ isInterruptModalOpen: true, pendingAction: onConfirm }); + }, + + confirmInterrupt() { + const { pendingAction } = get(); + set({ isInterruptModalOpen: false, pendingAction: null }); + if (pendingAction) pendingAction(); + }, + + cancelInterrupt() { + set({ isInterruptModalOpen: false, pendingAction: null }); + }, +})); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts index a5818f0a9f..692a5741f4 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts @@ -1,28 +1,5 @@ import type { User } from "@supabase/supabase-js"; -export type PageState = - | { type: "welcome" } - | { type: "newChat" } - | { type: "creating"; prompt: string } - | { type: "chat"; sessionId: string; initialPrompt?: string }; - -export function getInitialPromptFromState( - pageState: PageState, - storedInitialPrompt: string | undefined, -) { - if (storedInitialPrompt) return storedInitialPrompt; - if (pageState.type === "creating") return pageState.prompt; - if (pageState.type === "chat") return pageState.initialPrompt; -} - -export function shouldResetToWelcome(pageState: PageState) { - return ( - pageState.type !== "newChat" && - pageState.type !== "creating" && - pageState.type !== "welcome" - ); -} - export function getGreetingName(user?: User | null): string { if (!user) return "there"; const metadata = user.user_metadata as Record | undefined; diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx index 0f40de8f25..89cf72e2ba 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx @@ -1,11 +1,6 @@ import type { ReactNode } from "react"; -import { NewChatProvider } from "./NewChatContext"; import { CopilotShell } from "./components/CopilotShell/CopilotShell"; export default function CopilotLayout({ children }: { children: ReactNode }) { - return ( - - {children} - - ); + return {children}; } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx index 3bbafd087b..104b238895 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx @@ -1,22 +1,25 @@ "use client"; -import { Skeleton } from "@/components/__legacy__/ui/skeleton"; import { Button } from "@/components/atoms/Button/Button"; +import { Skeleton } from "@/components/atoms/Skeleton/Skeleton"; import { Text } from "@/components/atoms/Text/Text"; import { Chat } from "@/components/contextual/Chat/Chat"; import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput"; -import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; +import { useCopilotStore } from "./copilot-page-store"; import { useCopilotPage } from "./useCopilotPage"; export default function CopilotPage() { const { state, handlers } = useCopilotPage(); + const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen); + const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt); + const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt); const { greetingName, quickActions, isLoading, - pageState, - isNewChatModalOpen, + hasSession, + initialPrompt, isReady, } = state; const { @@ -24,24 +27,16 @@ export default function CopilotPage() { startChatWithPrompt, handleSessionNotFound, handleStreamingChange, - handleCancelNewChat, - proceedWithNewChat, - handleNewChatModalOpen, } = handlers; - if (!isReady) { - return null; - } + if (!isReady) return null; - // Show Chat when we have an active session - if (pageState.type === "chat") { + if (hasSession) { return (
@@ -49,31 +44,33 @@ export default function CopilotPage() { title="Interrupt current chat?" styling={{ maxWidth: 300, width: "100%" }} controlled={{ - isOpen: isNewChatModalOpen, - set: handleNewChatModalOpen, + isOpen: isInterruptModalOpen, + set: (open) => { + if (!open) cancelInterrupt(); + }, }} - onClose={handleCancelNewChat} + onClose={cancelInterrupt} >
The current chat response will be interrupted. Are you sure you - want to start a new chat? + want to continue?
@@ -83,34 +80,6 @@ export default function CopilotPage() { ); } - if (pageState.type === "newChat") { - return ( -
-
- - - Loading your chats... - -
-
- ); - } - - // Show loading state while creating session and sending first message - if (pageState.type === "creating") { - return ( -
-
- - - Loading your chats... - -
-
- ); - } - - // Show Welcome screen return (
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts index cb13137432..e4713cd24a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts @@ -1,86 +1,44 @@ -import { postV2CreateSession } from "@/app/api/__generated__/endpoints/chat/chat"; +import { + getGetV2ListSessionsQueryKey, + postV2CreateSession, +} from "@/app/api/__generated__/endpoints/chat/chat"; import { useToast } from "@/components/molecules/Toast/use-toast"; import { getHomepageRoute } from "@/lib/constants"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; +import { useOnboarding } from "@/providers/onboarding/onboarding-provider"; import { Flag, type FlagValues, useGetFlag, } from "@/services/feature-flags/use-get-flag"; +import { SessionKey, sessionStorage } from "@/services/storage/session-storage"; import * as Sentry from "@sentry/nextjs"; +import { useQueryClient } from "@tanstack/react-query"; import { useFlags } from "launchdarkly-react-client-sdk"; import { useRouter } from "next/navigation"; -import { useEffect, useReducer } from "react"; -import { useNewChat } from "./NewChatContext"; -import { getGreetingName, getQuickActions, type PageState } from "./helpers"; -import { useCopilotURLState } from "./useCopilotURLState"; - -type CopilotState = { - pageState: PageState; - isStreaming: boolean; - isNewChatModalOpen: boolean; - initialPrompts: Record; - previousSessionId: string | null; -}; - -type CopilotAction = - | { type: "setPageState"; pageState: PageState } - | { type: "setStreaming"; isStreaming: boolean } - | { type: "setNewChatModalOpen"; isOpen: boolean } - | { type: "setInitialPrompt"; sessionId: string; prompt: string } - | { type: "setPreviousSessionId"; sessionId: string | null }; - -function isSamePageState(next: PageState, current: PageState) { - if (next.type !== current.type) return false; - if (next.type === "creating" && current.type === "creating") { - return next.prompt === current.prompt; - } - if (next.type === "chat" && current.type === "chat") { - return ( - next.sessionId === current.sessionId && - next.initialPrompt === current.initialPrompt - ); - } - return true; -} - -function copilotReducer( - state: CopilotState, - action: CopilotAction, -): CopilotState { - if (action.type === "setPageState") { - if (isSamePageState(action.pageState, state.pageState)) return state; - return { ...state, pageState: action.pageState }; - } - if (action.type === "setStreaming") { - if (action.isStreaming === state.isStreaming) return state; - return { ...state, isStreaming: action.isStreaming }; - } - if (action.type === "setNewChatModalOpen") { - if (action.isOpen === state.isNewChatModalOpen) return state; - return { ...state, isNewChatModalOpen: action.isOpen }; - } - if (action.type === "setInitialPrompt") { - if (state.initialPrompts[action.sessionId] === action.prompt) return state; - return { - ...state, - initialPrompts: { - ...state.initialPrompts, - [action.sessionId]: action.prompt, - }, - }; - } - if (action.type === "setPreviousSessionId") { - if (state.previousSessionId === action.sessionId) return state; - return { ...state, previousSessionId: action.sessionId }; - } - return state; -} +import { useEffect } from "react"; +import { useCopilotStore } from "./copilot-page-store"; +import { getGreetingName, getQuickActions } from "./helpers"; +import { useCopilotSessionId } from "./useCopilotSessionId"; export function useCopilotPage() { const router = useRouter(); + const queryClient = useQueryClient(); const { user, isLoggedIn, isUserLoading } = useSupabase(); const { toast } = useToast(); + const { completeStep } = useOnboarding(); + + const { urlSessionId, setUrlSessionId } = useCopilotSessionId(); + const setIsStreaming = useCopilotStore((s) => s.setIsStreaming); + const isCreating = useCopilotStore((s) => s.isCreatingSession); + const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession); + + // Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus + useEffect(() => { + if (isLoggedIn) { + completeStep("VISIT_COPILOT"); + } + }, [completeStep, isLoggedIn]); const isChatEnabled = useGetFlag(Flag.CHAT); const flags = useFlags(); @@ -91,86 +49,27 @@ export function useCopilotPage() { const isFlagReady = !isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined; - const [state, dispatch] = useReducer(copilotReducer, { - pageState: { type: "welcome" }, - isStreaming: false, - isNewChatModalOpen: false, - initialPrompts: {}, - previousSessionId: null, - }); - - const newChatContext = useNewChat(); const greetingName = getGreetingName(user); const quickActions = getQuickActions(); - function setPageState(pageState: PageState) { - dispatch({ type: "setPageState", pageState }); - } + const hasSession = Boolean(urlSessionId); + const initialPrompt = urlSessionId + ? getInitialPrompt(urlSessionId) + : undefined; - function setInitialPrompt(sessionId: string, prompt: string) { - dispatch({ type: "setInitialPrompt", sessionId, prompt }); - } - - function setPreviousSessionId(sessionId: string | null) { - dispatch({ type: "setPreviousSessionId", sessionId }); - } - - const { setUrlSessionId } = useCopilotURLState({ - pageState: state.pageState, - initialPrompts: state.initialPrompts, - previousSessionId: state.previousSessionId, - setPageState, - setInitialPrompt, - setPreviousSessionId, - }); - - useEffect( - function registerNewChatHandler() { - if (!newChatContext) return; - newChatContext.setOnNewChatClick(handleNewChatClick); - return function cleanup() { - newChatContext.setOnNewChatClick(undefined); - }; - }, - [newChatContext, handleNewChatClick], - ); - - useEffect( - function transitionNewChatToWelcome() { - if (state.pageState.type === "newChat") { - function setWelcomeState() { - dispatch({ type: "setPageState", pageState: { type: "welcome" } }); - } - - const timer = setTimeout(setWelcomeState, 300); - - return function cleanup() { - clearTimeout(timer); - }; - } - }, - [state.pageState.type], - ); - - useEffect( - function ensureAccess() { - if (!isFlagReady) return; - if (isChatEnabled === false) { - router.replace(homepageRoute); - } - }, - [homepageRoute, isChatEnabled, isFlagReady, router], - ); + useEffect(() => { + if (!isFlagReady) return; + if (isChatEnabled === false) { + router.replace(homepageRoute); + } + }, [homepageRoute, isChatEnabled, isFlagReady, router]); async function startChatWithPrompt(prompt: string) { if (!prompt?.trim()) return; - if (state.pageState.type === "creating") return; + if (isCreating) return; const trimmedPrompt = prompt.trim(); - dispatch({ - type: "setPageState", - pageState: { type: "creating", prompt: trimmedPrompt }, - }); + setIsCreating(true); try { const sessionResponse = await postV2CreateSession({ @@ -182,23 +81,19 @@ export function useCopilotPage() { } const sessionId = sessionResponse.data.id; + setInitialPrompt(sessionId, trimmedPrompt); - dispatch({ - type: "setInitialPrompt", - sessionId, - prompt: trimmedPrompt, + await queryClient.invalidateQueries({ + queryKey: getGetV2ListSessionsQueryKey(), }); - await setUrlSessionId(sessionId, { shallow: false }); - dispatch({ - type: "setPageState", - pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt }, - }); + await setUrlSessionId(sessionId, { shallow: true }); } catch (error) { console.error("[CopilotPage] Failed to start chat:", error); toast({ title: "Failed to start chat", variant: "destructive" }); Sentry.captureException(error); - dispatch({ type: "setPageState", pageState: { type: "welcome" } }); + } finally { + setIsCreating(false); } } @@ -211,37 +106,7 @@ export function useCopilotPage() { } function handleStreamingChange(isStreamingValue: boolean) { - dispatch({ type: "setStreaming", isStreaming: isStreamingValue }); - } - - async function proceedWithNewChat() { - dispatch({ type: "setNewChatModalOpen", isOpen: false }); - if (newChatContext?.performNewChat) { - newChatContext.performNewChat(); - return; - } - try { - await setUrlSessionId(null, { shallow: false }); - } catch (error) { - console.error("[CopilotPage] Failed to clear session:", error); - } - router.replace("/copilot"); - } - - function handleCancelNewChat() { - dispatch({ type: "setNewChatModalOpen", isOpen: false }); - } - - function handleNewChatModalOpen(isOpen: boolean) { - dispatch({ type: "setNewChatModalOpen", isOpen }); - } - - function handleNewChatClick() { - if (state.isStreaming) { - dispatch({ type: "setNewChatModalOpen", isOpen: true }); - } else { - proceedWithNewChat(); - } + setIsStreaming(isStreamingValue); } return { @@ -249,8 +114,8 @@ export function useCopilotPage() { greetingName, quickActions, isLoading: isUserLoading, - pageState: state.pageState, - isNewChatModalOpen: state.isNewChatModalOpen, + hasSession, + initialPrompt, isReady: isFlagReady && isChatEnabled !== false && isLoggedIn, }, handlers: { @@ -258,9 +123,32 @@ export function useCopilotPage() { startChatWithPrompt, handleSessionNotFound, handleStreamingChange, - handleCancelNewChat, - proceedWithNewChat, - handleNewChatModalOpen, }, }; } + +function getInitialPrompt(sessionId: string): string | undefined { + try { + const prompts = JSON.parse( + sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}", + ); + return prompts[sessionId]; + } catch { + return undefined; + } +} + +function setInitialPrompt(sessionId: string, prompt: string): void { + try { + const prompts = JSON.parse( + sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}", + ); + prompts[sessionId] = prompt; + sessionStorage.set( + SessionKey.CHAT_INITIAL_PROMPTS, + JSON.stringify(prompts), + ); + } catch { + // Ignore storage errors + } +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotSessionId.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotSessionId.ts new file mode 100644 index 0000000000..87f9b7d3ae --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotSessionId.ts @@ -0,0 +1,10 @@ +import { parseAsString, useQueryState } from "nuqs"; + +export function useCopilotSessionId() { + const [urlSessionId, setUrlSessionId] = useQueryState( + "sessionId", + parseAsString, + ); + + return { urlSessionId, setUrlSessionId }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotURLState.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotURLState.ts deleted file mode 100644 index 5e37e29a15..0000000000 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotURLState.ts +++ /dev/null @@ -1,80 +0,0 @@ -import { parseAsString, useQueryState } from "nuqs"; -import { useLayoutEffect } from "react"; -import { - getInitialPromptFromState, - type PageState, - shouldResetToWelcome, -} from "./helpers"; - -interface UseCopilotUrlStateArgs { - pageState: PageState; - initialPrompts: Record; - previousSessionId: string | null; - setPageState: (pageState: PageState) => void; - setInitialPrompt: (sessionId: string, prompt: string) => void; - setPreviousSessionId: (sessionId: string | null) => void; -} - -export function useCopilotURLState({ - pageState, - initialPrompts, - previousSessionId, - setPageState, - setInitialPrompt, - setPreviousSessionId, -}: UseCopilotUrlStateArgs) { - const [urlSessionId, setUrlSessionId] = useQueryState( - "sessionId", - parseAsString, - ); - - function syncSessionFromUrl() { - if (urlSessionId) { - if (pageState.type === "chat" && pageState.sessionId === urlSessionId) { - setPreviousSessionId(urlSessionId); - return; - } - - const storedInitialPrompt = initialPrompts[urlSessionId]; - const currentInitialPrompt = getInitialPromptFromState( - pageState, - storedInitialPrompt, - ); - - if (currentInitialPrompt) { - setInitialPrompt(urlSessionId, currentInitialPrompt); - } - - setPageState({ - type: "chat", - sessionId: urlSessionId, - initialPrompt: currentInitialPrompt, - }); - setPreviousSessionId(urlSessionId); - return; - } - - const wasInChat = previousSessionId !== null && pageState.type === "chat"; - setPreviousSessionId(null); - if (wasInChat) { - setPageState({ type: "newChat" }); - return; - } - - if (shouldResetToWelcome(pageState)) { - setPageState({ type: "welcome" }); - } - } - - useLayoutEffect(syncSessionFromUrl, [ - urlSessionId, - pageState.type, - previousSessionId, - initialPrompts, - ]); - - return { - urlSessionId, - setUrlSessionId, - }; -} diff --git a/autogpt_platform/frontend/src/app/(platform)/layout.tsx b/autogpt_platform/frontend/src/app/(platform)/layout.tsx index f5e3f3b99b..048110f8b2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/layout.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/layout.tsx @@ -1,10 +1,12 @@ import { Navbar } from "@/components/layout/Navbar/Navbar"; +import { NetworkStatusMonitor } from "@/services/network-status/NetworkStatusMonitor"; import { ReactNode } from "react"; import { AdminImpersonationBanner } from "./admin/components/AdminImpersonationBanner"; export default function PlatformLayout({ children }: { children: ReactNode }) { return (
+
{children}
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal.tsx index d5ba9142ee..aff06d79c5 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal.tsx @@ -14,6 +14,10 @@ import { import { Dialog } from "@/components/molecules/Dialog/Dialog"; import { useEffect, useRef, useState } from "react"; import { ScheduleAgentModal } from "../ScheduleAgentModal/ScheduleAgentModal"; +import { + AIAgentSafetyPopup, + useAIAgentSafetyPopup, +} from "./components/AIAgentSafetyPopup/AIAgentSafetyPopup"; import { ModalHeader } from "./components/ModalHeader/ModalHeader"; import { ModalRunSection } from "./components/ModalRunSection/ModalRunSection"; import { RunActions } from "./components/RunActions/RunActions"; @@ -83,8 +87,18 @@ export function RunAgentModal({ const [isScheduleModalOpen, setIsScheduleModalOpen] = useState(false); const [hasOverflow, setHasOverflow] = useState(false); + const [isSafetyPopupOpen, setIsSafetyPopupOpen] = useState(false); + const [pendingRunAction, setPendingRunAction] = useState<(() => void) | null>( + null, + ); const contentRef = useRef(null); + const { shouldShowPopup, dismissPopup } = useAIAgentSafetyPopup( + agent.id, + agent.has_sensitive_action, + agent.has_human_in_the_loop, + ); + const hasAnySetupFields = Object.keys(agentInputFields || {}).length > 0 || Object.keys(agentCredentialsInputFields || {}).length > 0; @@ -165,6 +179,24 @@ export function RunAgentModal({ onScheduleCreated?.(schedule); } + function handleRunWithSafetyCheck() { + if (shouldShowPopup) { + setPendingRunAction(() => handleRun); + setIsSafetyPopupOpen(true); + } else { + handleRun(); + } + } + + function handleSafetyPopupAcknowledge() { + setIsSafetyPopupOpen(false); + dismissPopup(); + if (pendingRunAction) { + pendingRunAction(); + setPendingRunAction(null); + } + } + return ( <> + + ); } diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/AIAgentSafetyPopup/AIAgentSafetyPopup.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/AIAgentSafetyPopup/AIAgentSafetyPopup.tsx new file mode 100644 index 0000000000..f2d178b33d --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/components/AIAgentSafetyPopup/AIAgentSafetyPopup.tsx @@ -0,0 +1,108 @@ +"use client"; + +import { Button } from "@/components/atoms/Button/Button"; +import { Text } from "@/components/atoms/Text/Text"; +import { Dialog } from "@/components/molecules/Dialog/Dialog"; +import { Key, storage } from "@/services/storage/local-storage"; +import { ShieldCheckIcon } from "@phosphor-icons/react"; +import { useCallback, useEffect, useState } from "react"; + +interface Props { + agentId: string; + onAcknowledge: () => void; + isOpen: boolean; +} + +export function AIAgentSafetyPopup({ agentId, onAcknowledge, isOpen }: Props) { + function handleAcknowledge() { + // Add this agent to the list of agents for which popup has been shown + const seenAgentsJson = storage.get(Key.AI_AGENT_SAFETY_POPUP_SHOWN); + const seenAgents: string[] = seenAgentsJson + ? JSON.parse(seenAgentsJson) + : []; + + if (!seenAgents.includes(agentId)) { + seenAgents.push(agentId); + storage.set(Key.AI_AGENT_SAFETY_POPUP_SHOWN, JSON.stringify(seenAgents)); + } + + onAcknowledge(); + } + + if (!isOpen) return null; + + return ( + {} }} + styling={{ maxWidth: "480px" }} + > + +
+
+ +
+ + + Safety Checks Enabled + + + + AI-generated agents may take actions that affect your data or + external systems. + + + + AutoGPT includes safety checks so you'll always have the + opportunity to review and approve sensitive actions before they + happen. + + + +
+
+
+ ); +} + +export function useAIAgentSafetyPopup( + agentId: string, + hasSensitiveAction: boolean, + hasHumanInTheLoop: boolean, +) { + const [shouldShowPopup, setShouldShowPopup] = useState(false); + const [hasChecked, setHasChecked] = useState(false); + + useEffect(() => { + if (hasChecked) return; + + const seenAgentsJson = storage.get(Key.AI_AGENT_SAFETY_POPUP_SHOWN); + const seenAgents: string[] = seenAgentsJson + ? JSON.parse(seenAgentsJson) + : []; + const hasSeenPopupForThisAgent = seenAgents.includes(agentId); + const isRelevantAgent = hasSensitiveAction || hasHumanInTheLoop; + + setShouldShowPopup(!hasSeenPopupForThisAgent && isRelevantAgent); + setHasChecked(true); + }, [agentId, hasSensitiveAction, hasHumanInTheLoop, hasChecked]); + + const dismissPopup = useCallback(() => { + setShouldShowPopup(false); + }, []); + + return { + shouldShowPopup, + dismissPopup, + }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedRunView/components/SafeModeToggle.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedRunView/components/SafeModeToggle.tsx index dc0258c768..0fafa67414 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedRunView/components/SafeModeToggle.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/selected-views/SelectedRunView/components/SafeModeToggle.tsx @@ -69,7 +69,6 @@ export function SafeModeToggle({ graph, className }: Props) { const { currentHITLSafeMode, showHITLToggle, - isHITLStateUndetermined, handleHITLToggle, currentSensitiveActionSafeMode, showSensitiveActionToggle, @@ -78,20 +77,13 @@ export function SafeModeToggle({ graph, className }: Props) { shouldShowToggle, } = useAgentSafeMode(graph); - if (!shouldShowToggle || isHITLStateUndetermined) { - return null; - } - - const showHITL = showHITLToggle && !isHITLStateUndetermined; - const showSensitive = showSensitiveActionToggle; - - if (!showHITL && !showSensitive) { + if (!shouldShowToggle) { return null; } return (
- {showHITL && ( + {showHITLToggle && ( )} - {showSensitive && ( + {showSensitiveActionToggle && ( { + const token = await getServerAuthToken(); + + const headers: Record = {}; + if (token && token !== "no-token-found") { + headers["Authorization"] = `Bearer ${token}`; + } + + const response = await fetch(backendUrl, { + method: "GET", + headers, + redirect: "follow", // Follow redirects to signed URLs + }); + + if (!response.ok) { + return NextResponse.json( + { error: `Failed to download file: ${response.statusText}` }, + { status: response.status }, + ); + } + + // Get the content type from the backend response + const contentType = + response.headers.get("Content-Type") || "application/octet-stream"; + const contentDisposition = response.headers.get("Content-Disposition"); + + // Stream the response body + const responseHeaders: Record = { + "Content-Type": contentType, + }; + + if (contentDisposition) { + responseHeaders["Content-Disposition"] = contentDisposition; + } + + // Return the binary content + const arrayBuffer = await response.arrayBuffer(); + return new NextResponse(arrayBuffer, { + status: 200, + headers: responseHeaders, + }); +} + async function handleJsonRequest( req: NextRequest, method: string, @@ -180,6 +244,11 @@ async function handler( }; try { + // Handle workspace file downloads separately (binary response) + if (method === "GET" && isWorkspaceDownloadRequest(path)) { + return await handleWorkspaceDownload(req, backendUrl); + } + if (method === "GET" || method === "DELETE") { responseBody = await handleGetDeleteRequest(method, backendUrl, req); } else if (contentType?.includes("application/json")) { diff --git a/autogpt_platform/frontend/src/app/api/transcribe/route.ts b/autogpt_platform/frontend/src/app/api/transcribe/route.ts new file mode 100644 index 0000000000..10c182cdfa --- /dev/null +++ b/autogpt_platform/frontend/src/app/api/transcribe/route.ts @@ -0,0 +1,77 @@ +import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers"; +import { NextRequest, NextResponse } from "next/server"; + +const WHISPER_API_URL = "https://api.openai.com/v1/audio/transcriptions"; +const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB - Whisper's limit + +function getExtensionFromMimeType(mimeType: string): string { + const subtype = mimeType.split("/")[1]?.split(";")[0]; + return subtype || "webm"; +} + +export async function POST(request: NextRequest) { + const token = await getServerAuthToken(); + + if (!token || token === "no-token-found") { + return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + } + + const apiKey = process.env.OPENAI_API_KEY; + + if (!apiKey) { + return NextResponse.json( + { error: "OpenAI API key not configured" }, + { status: 401 }, + ); + } + + try { + const formData = await request.formData(); + const audioFile = formData.get("audio"); + + if (!audioFile || !(audioFile instanceof Blob)) { + return NextResponse.json( + { error: "No audio file provided" }, + { status: 400 }, + ); + } + + if (audioFile.size > MAX_FILE_SIZE) { + return NextResponse.json( + { error: "File too large. Maximum size is 25MB." }, + { status: 413 }, + ); + } + + const ext = getExtensionFromMimeType(audioFile.type); + const whisperFormData = new FormData(); + whisperFormData.append("file", audioFile, `recording.${ext}`); + whisperFormData.append("model", "whisper-1"); + + const response = await fetch(WHISPER_API_URL, { + method: "POST", + headers: { + Authorization: `Bearer ${apiKey}`, + }, + body: whisperFormData, + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + console.error("Whisper API error:", errorData); + return NextResponse.json( + { error: errorData.error?.message || "Transcription failed" }, + { status: response.status }, + ); + } + + const result = await response.json(); + return NextResponse.json({ text: result.text }); + } catch (error) { + console.error("Transcription error:", error); + return NextResponse.json( + { error: "Failed to process audio" }, + { status: 500 }, + ); + } +} diff --git a/autogpt_platform/frontend/src/app/providers.tsx b/autogpt_platform/frontend/src/app/providers.tsx index 8ea199abc8..267814e7c2 100644 --- a/autogpt_platform/frontend/src/app/providers.tsx +++ b/autogpt_platform/frontend/src/app/providers.tsx @@ -6,28 +6,40 @@ import { BackendAPIProvider } from "@/lib/autogpt-server-api/context"; import { getQueryClient } from "@/lib/react-query/queryClient"; import CredentialsProvider from "@/providers/agent-credentials/credentials-provider"; import OnboardingProvider from "@/providers/onboarding/onboarding-provider"; +import { + PostHogPageViewTracker, + PostHogProvider, + PostHogUserTracker, +} from "@/providers/posthog/posthog-provider"; import { LaunchDarklyProvider } from "@/services/feature-flags/feature-flag-provider"; import { QueryClientProvider } from "@tanstack/react-query"; import { ThemeProvider, ThemeProviderProps } from "next-themes"; import { NuqsAdapter } from "nuqs/adapters/next/app"; +import { Suspense } from "react"; export function Providers({ children, ...props }: ThemeProviderProps) { const queryClient = getQueryClient(); return ( - - - - - - - {children} - - - - - + + + + + + + + + + + + {children} + + + + + + ); diff --git a/autogpt_platform/frontend/src/components/atoms/Skeleton/Skeleton.tsx b/autogpt_platform/frontend/src/components/atoms/Skeleton/Skeleton.tsx new file mode 100644 index 0000000000..4789e281ce --- /dev/null +++ b/autogpt_platform/frontend/src/components/atoms/Skeleton/Skeleton.tsx @@ -0,0 +1,14 @@ +import { cn } from "@/lib/utils"; + +interface Props extends React.HTMLAttributes { + className?: string; +} + +export function Skeleton({ className, ...props }: Props) { + return ( +
+ ); +} diff --git a/autogpt_platform/frontend/src/components/atoms/Skeleton/skeleton.stories.tsx b/autogpt_platform/frontend/src/components/atoms/Skeleton/skeleton.stories.tsx index 04d87a6e0e..69bb7c3440 100644 --- a/autogpt_platform/frontend/src/components/atoms/Skeleton/skeleton.stories.tsx +++ b/autogpt_platform/frontend/src/components/atoms/Skeleton/skeleton.stories.tsx @@ -1,4 +1,4 @@ -import { Skeleton } from "@/components/__legacy__/ui/skeleton"; +import { Skeleton } from "./Skeleton"; import type { Meta, StoryObj } from "@storybook/nextjs"; const meta: Meta = { diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx index ba7584765d..ada8c26231 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx @@ -1,16 +1,17 @@ "use client"; +import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId"; +import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store"; +import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { Text } from "@/components/atoms/Text/Text"; import { cn } from "@/lib/utils"; import { useEffect, useRef } from "react"; import { ChatContainer } from "./components/ChatContainer/ChatContainer"; import { ChatErrorState } from "./components/ChatErrorState/ChatErrorState"; -import { ChatLoader } from "./components/ChatLoader/ChatLoader"; import { useChat } from "./useChat"; export interface ChatProps { className?: string; - urlSessionId?: string | null; initialPrompt?: string; onSessionNotFound?: () => void; onStreamingChange?: (isStreaming: boolean) => void; @@ -18,12 +19,13 @@ export interface ChatProps { export function Chat({ className, - urlSessionId, initialPrompt, onSessionNotFound, onStreamingChange, }: ChatProps) { + const { urlSessionId } = useCopilotSessionId(); const hasHandledNotFoundRef = useRef(false); + const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession); const { messages, isLoading, @@ -33,49 +35,59 @@ export function Chat({ sessionId, createSession, showLoader, + startPollingForOperation, } = useChat({ urlSessionId }); - useEffect( - function handleMissingSession() { - if (!onSessionNotFound) return; - if (!urlSessionId) return; - if (!isSessionNotFound || isLoading || isCreating) return; - if (hasHandledNotFoundRef.current) return; - hasHandledNotFoundRef.current = true; - onSessionNotFound(); - }, - [onSessionNotFound, urlSessionId, isSessionNotFound, isLoading, isCreating], - ); + useEffect(() => { + if (!onSessionNotFound) return; + if (!urlSessionId) return; + if (!isSessionNotFound || isLoading || isCreating) return; + if (hasHandledNotFoundRef.current) return; + hasHandledNotFoundRef.current = true; + onSessionNotFound(); + }, [ + onSessionNotFound, + urlSessionId, + isSessionNotFound, + isLoading, + isCreating, + ]); + + const shouldShowLoader = + (showLoader && (isLoading || isCreating)) || isSwitchingSession; return (
{/* Main Content */}
{/* Loading State */} - {showLoader && (isLoading || isCreating) && ( + {shouldShowLoader && (
-
- +
+ - Loading your chats... + {isSwitchingSession + ? "Switching chat..." + : "Loading your chat..."}
)} {/* Error State */} - {error && !isLoading && ( + {error && !isLoading && !isSwitchingSession && ( )} {/* Session Content */} - {sessionId && !isLoading && !error && ( + {sessionId && !isLoading && !error && !isSwitchingSession && ( )}
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts new file mode 100644 index 0000000000..8229630e5d --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts @@ -0,0 +1,289 @@ +"use client"; + +import { create } from "zustand"; +import type { + ActiveStream, + StreamChunk, + StreamCompleteCallback, + StreamResult, + StreamStatus, +} from "./chat-types"; +import { executeStream } from "./stream-executor"; + +const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes + +interface ChatStoreState { + activeStreams: Map; + completedStreams: Map; + activeSessions: Set; + streamCompleteCallbacks: Set; +} + +interface ChatStoreActions { + startStream: ( + sessionId: string, + message: string, + isUserMessage: boolean, + context?: { url: string; content: string }, + onChunk?: (chunk: StreamChunk) => void, + ) => Promise; + stopStream: (sessionId: string) => void; + subscribeToStream: ( + sessionId: string, + onChunk: (chunk: StreamChunk) => void, + skipReplay?: boolean, + ) => () => void; + getStreamStatus: (sessionId: string) => StreamStatus; + getCompletedStream: (sessionId: string) => StreamResult | undefined; + clearCompletedStream: (sessionId: string) => void; + isStreaming: (sessionId: string) => boolean; + registerActiveSession: (sessionId: string) => void; + unregisterActiveSession: (sessionId: string) => void; + isSessionActive: (sessionId: string) => boolean; + onStreamComplete: (callback: StreamCompleteCallback) => () => void; +} + +type ChatStore = ChatStoreState & ChatStoreActions; + +function notifyStreamComplete( + callbacks: Set, + sessionId: string, +) { + for (const callback of callbacks) { + try { + callback(sessionId); + } catch (err) { + console.warn("[ChatStore] Stream complete callback error:", err); + } + } +} + +function cleanupExpiredStreams( + completedStreams: Map, +): Map { + const now = Date.now(); + const cleaned = new Map(completedStreams); + for (const [sessionId, result] of cleaned) { + if (now - result.completedAt > COMPLETED_STREAM_TTL) { + cleaned.delete(sessionId); + } + } + return cleaned; +} + +export const useChatStore = create((set, get) => ({ + activeStreams: new Map(), + completedStreams: new Map(), + activeSessions: new Set(), + streamCompleteCallbacks: new Set(), + + startStream: async function startStream( + sessionId, + message, + isUserMessage, + context, + onChunk, + ) { + const state = get(); + const newActiveStreams = new Map(state.activeStreams); + let newCompletedStreams = new Map(state.completedStreams); + const callbacks = state.streamCompleteCallbacks; + + const existingStream = newActiveStreams.get(sessionId); + if (existingStream) { + existingStream.abortController.abort(); + const normalizedStatus = + existingStream.status === "streaming" + ? "completed" + : existingStream.status; + const result: StreamResult = { + sessionId, + status: normalizedStatus, + chunks: existingStream.chunks, + completedAt: Date.now(), + error: existingStream.error, + }; + newCompletedStreams.set(sessionId, result); + newActiveStreams.delete(sessionId); + newCompletedStreams = cleanupExpiredStreams(newCompletedStreams); + if (normalizedStatus === "completed" || normalizedStatus === "error") { + notifyStreamComplete(callbacks, sessionId); + } + } + + const abortController = new AbortController(); + const initialCallbacks = new Set<(chunk: StreamChunk) => void>(); + if (onChunk) initialCallbacks.add(onChunk); + + const stream: ActiveStream = { + sessionId, + abortController, + status: "streaming", + startedAt: Date.now(), + chunks: [], + onChunkCallbacks: initialCallbacks, + }; + + newActiveStreams.set(sessionId, stream); + set({ + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + }); + + try { + await executeStream(stream, message, isUserMessage, context); + } finally { + if (onChunk) stream.onChunkCallbacks.delete(onChunk); + if (stream.status !== "streaming") { + const currentState = get(); + const finalActiveStreams = new Map(currentState.activeStreams); + let finalCompletedStreams = new Map(currentState.completedStreams); + + const storedStream = finalActiveStreams.get(sessionId); + if (storedStream === stream) { + const result: StreamResult = { + sessionId, + status: stream.status, + chunks: stream.chunks, + completedAt: Date.now(), + error: stream.error, + }; + finalCompletedStreams.set(sessionId, result); + finalActiveStreams.delete(sessionId); + finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams); + set({ + activeStreams: finalActiveStreams, + completedStreams: finalCompletedStreams, + }); + if (stream.status === "completed" || stream.status === "error") { + notifyStreamComplete( + currentState.streamCompleteCallbacks, + sessionId, + ); + } + } + } + } + }, + + stopStream: function stopStream(sessionId) { + const state = get(); + const stream = state.activeStreams.get(sessionId); + if (!stream) return; + + stream.abortController.abort(); + stream.status = "completed"; + + const newActiveStreams = new Map(state.activeStreams); + let newCompletedStreams = new Map(state.completedStreams); + + const result: StreamResult = { + sessionId, + status: stream.status, + chunks: stream.chunks, + completedAt: Date.now(), + error: stream.error, + }; + newCompletedStreams.set(sessionId, result); + newActiveStreams.delete(sessionId); + newCompletedStreams = cleanupExpiredStreams(newCompletedStreams); + + set({ + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + }); + + notifyStreamComplete(state.streamCompleteCallbacks, sessionId); + }, + + subscribeToStream: function subscribeToStream( + sessionId, + onChunk, + skipReplay = false, + ) { + const state = get(); + const stream = state.activeStreams.get(sessionId); + + if (stream) { + if (!skipReplay) { + for (const chunk of stream.chunks) { + onChunk(chunk); + } + } + + stream.onChunkCallbacks.add(onChunk); + + return function unsubscribe() { + stream.onChunkCallbacks.delete(onChunk); + }; + } + + return function noop() {}; + }, + + getStreamStatus: function getStreamStatus(sessionId) { + const { activeStreams, completedStreams } = get(); + + const active = activeStreams.get(sessionId); + if (active) return active.status; + + const completed = completedStreams.get(sessionId); + if (completed) return completed.status; + + return "idle"; + }, + + getCompletedStream: function getCompletedStream(sessionId) { + return get().completedStreams.get(sessionId); + }, + + clearCompletedStream: function clearCompletedStream(sessionId) { + const state = get(); + if (!state.completedStreams.has(sessionId)) return; + + const newCompletedStreams = new Map(state.completedStreams); + newCompletedStreams.delete(sessionId); + set({ completedStreams: newCompletedStreams }); + }, + + isStreaming: function isStreaming(sessionId) { + const stream = get().activeStreams.get(sessionId); + return stream?.status === "streaming"; + }, + + registerActiveSession: function registerActiveSession(sessionId) { + const state = get(); + if (state.activeSessions.has(sessionId)) return; + + const newActiveSessions = new Set(state.activeSessions); + newActiveSessions.add(sessionId); + set({ activeSessions: newActiveSessions }); + }, + + unregisterActiveSession: function unregisterActiveSession(sessionId) { + const state = get(); + if (!state.activeSessions.has(sessionId)) return; + + const newActiveSessions = new Set(state.activeSessions); + newActiveSessions.delete(sessionId); + set({ activeSessions: newActiveSessions }); + }, + + isSessionActive: function isSessionActive(sessionId) { + return get().activeSessions.has(sessionId); + }, + + onStreamComplete: function onStreamComplete(callback) { + const state = get(); + const newCallbacks = new Set(state.streamCompleteCallbacks); + newCallbacks.add(callback); + set({ streamCompleteCallbacks: newCallbacks }); + + return function unsubscribe() { + const currentState = get(); + const cleanedCallbacks = new Set(currentState.streamCompleteCallbacks); + cleanedCallbacks.delete(callback); + set({ streamCompleteCallbacks: cleanedCallbacks }); + }; + }, +})); diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts new file mode 100644 index 0000000000..8c8aa7b704 --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts @@ -0,0 +1,94 @@ +import type { ToolArguments, ToolResult } from "@/types/chat"; + +export type StreamStatus = "idle" | "streaming" | "completed" | "error"; + +export interface StreamChunk { + type: + | "text_chunk" + | "text_ended" + | "tool_call" + | "tool_call_start" + | "tool_response" + | "login_needed" + | "need_login" + | "credentials_needed" + | "error" + | "usage" + | "stream_end"; + timestamp?: string; + content?: string; + message?: string; + code?: string; + details?: Record; + tool_id?: string; + tool_name?: string; + arguments?: ToolArguments; + result?: ToolResult; + success?: boolean; + idx?: number; + session_id?: string; + agent_info?: { + graph_id: string; + name: string; + trigger_type: string; + }; + provider?: string; + provider_name?: string; + credential_type?: string; + scopes?: string[]; + title?: string; + [key: string]: unknown; +} + +export type VercelStreamChunk = + | { type: "start"; messageId: string } + | { type: "finish" } + | { type: "text-start"; id: string } + | { type: "text-delta"; id: string; delta: string } + | { type: "text-end"; id: string } + | { type: "tool-input-start"; toolCallId: string; toolName: string } + | { + type: "tool-input-available"; + toolCallId: string; + toolName: string; + input: Record; + } + | { + type: "tool-output-available"; + toolCallId: string; + toolName?: string; + output: unknown; + success?: boolean; + } + | { + type: "usage"; + promptTokens: number; + completionTokens: number; + totalTokens: number; + } + | { + type: "error"; + errorText: string; + code?: string; + details?: Record; + }; + +export interface ActiveStream { + sessionId: string; + abortController: AbortController; + status: StreamStatus; + startedAt: number; + chunks: StreamChunk[]; + error?: Error; + onChunkCallbacks: Set<(chunk: StreamChunk) => void>; +} + +export interface StreamResult { + sessionId: string; + status: StreamStatus; + chunks: StreamChunk[]; + completedAt: number; + error?: Error; +} + +export type StreamCompleteCallback = (sessionId: string) => void; diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx index 17748f8dbc..dec221338a 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx @@ -4,6 +4,7 @@ import { Text } from "@/components/atoms/Text/Text"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { cn } from "@/lib/utils"; +import { GlobeHemisphereEastIcon } from "@phosphor-icons/react"; import { useEffect } from "react"; import { ChatInput } from "../ChatInput/ChatInput"; import { MessageList } from "../MessageList/MessageList"; @@ -15,6 +16,7 @@ export interface ChatContainerProps { initialPrompt?: string; className?: string; onStreamingChange?: (isStreaming: boolean) => void; + onOperationStarted?: () => void; } export function ChatContainer({ @@ -23,6 +25,7 @@ export function ChatContainer({ initialPrompt, className, onStreamingChange, + onOperationStarted, }: ChatContainerProps) { const { messages, @@ -37,6 +40,7 @@ export function ChatContainer({ sessionId, initialMessages, initialPrompt, + onOperationStarted, }); useEffect(() => { @@ -55,24 +59,37 @@ export function ChatContainer({ )} > + + + Service unavailable + +
+ } controlled={{ isOpen: isRegionBlockedModalOpen, set: handleRegionModalOpenChange, }} onClose={handleRegionModalClose} + styling={{ maxWidth: 550, width: "100%", minWidth: "auto" }} > -
+
- This model is not available in your region. Please connect via VPN - and try again. + The Autogpt AI model is not available in your region or your + connection is blocking it. Please try again with a different + connection. -
+
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts index 791cf046d5..82e9b05e88 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts @@ -1,5 +1,5 @@ import { toast } from "sonner"; -import { StreamChunk } from "../../useChatStream"; +import type { StreamChunk } from "../../chat-types"; import type { HandlerDependencies } from "./handlers"; import { handleError, diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts index 96198a0386..f3cac01f96 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts @@ -22,6 +22,7 @@ export interface HandlerDependencies { setIsStreamingInitiated: Dispatch>; setIsRegionBlockedModalOpen: Dispatch>; sessionId: string; + onOperationStarted?: () => void; } export function isRegionBlockedError(chunk: StreamChunk): boolean { @@ -48,6 +49,15 @@ export function handleTextEnded( const completedText = deps.streamingChunksRef.current.join(""); if (completedText.trim()) { deps.setMessages((prev) => { + // Check if this exact message already exists to prevent duplicates + const exists = prev.some( + (msg) => + msg.type === "message" && + msg.role === "assistant" && + msg.content === completedText, + ); + if (exists) return prev; + const assistantMessage: ChatMessageData = { type: "message", role: "assistant", @@ -154,6 +164,11 @@ export function handleToolResponse( } return; } + // Trigger polling when operation_started is received + if (responseMessage.type === "operation_started") { + deps.onOperationStarted?.(); + } + deps.setMessages((prev) => { const toolCallIndex = prev.findIndex( (msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id, @@ -203,13 +218,24 @@ export function handleStreamEnd( ]); } if (completedContent.trim()) { - const assistantMessage: ChatMessageData = { - type: "message", - role: "assistant", - content: completedContent, - timestamp: new Date(), - }; - deps.setMessages((prev) => [...prev, assistantMessage]); + deps.setMessages((prev) => { + // Check if this exact message already exists to prevent duplicates + const exists = prev.some( + (msg) => + msg.type === "message" && + msg.role === "assistant" && + msg.content === completedContent, + ); + if (exists) return prev; + + const assistantMessage: ChatMessageData = { + type: "message", + role: "assistant", + content: completedContent, + timestamp: new Date(), + }; + return [...prev, assistantMessage]; + }); } deps.setStreamingChunks([]); deps.streamingChunksRef.current = []; diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts index 9d51003a93..e744c9bc34 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts @@ -1,7 +1,118 @@ +import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; import { SessionKey, sessionStorage } from "@/services/storage/session-storage"; import type { ToolResult } from "@/types/chat"; import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +export function processInitialMessages( + initialMessages: SessionDetailResponse["messages"], +): ChatMessageData[] { + const processedMessages: ChatMessageData[] = []; + const toolCallMap = new Map(); + + for (const msg of initialMessages) { + if (!isValidMessage(msg)) { + console.warn("Invalid message structure from backend:", msg); + continue; + } + + let content = String(msg.content || ""); + const role = String(msg.role || "assistant").toLowerCase(); + const toolCalls = msg.tool_calls; + const timestamp = msg.timestamp + ? new Date(msg.timestamp as string) + : undefined; + + if (role === "user") { + content = removePageContext(content); + if (!content.trim()) continue; + processedMessages.push({ + type: "message", + role: "user", + content, + timestamp, + }); + continue; + } + + if (role === "assistant") { + content = content + .replace(/[\s\S]*?<\/thinking>/gi, "") + .replace(/[\s\S]*?<\/internal_reasoning>/gi, "") + .trim(); + + if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) { + for (const toolCall of toolCalls) { + const toolName = toolCall.function.name; + const toolId = toolCall.id; + toolCallMap.set(toolId, toolName); + + try { + const args = JSON.parse(toolCall.function.arguments || "{}"); + processedMessages.push({ + type: "tool_call", + toolId, + toolName, + arguments: args, + timestamp, + }); + } catch (err) { + console.warn("Failed to parse tool call arguments:", err); + processedMessages.push({ + type: "tool_call", + toolId, + toolName, + arguments: {}, + timestamp, + }); + } + } + if (content.trim()) { + processedMessages.push({ + type: "message", + role: "assistant", + content, + timestamp, + }); + } + } else if (content.trim()) { + processedMessages.push({ + type: "message", + role: "assistant", + content, + timestamp, + }); + } + continue; + } + + if (role === "tool") { + const toolCallId = (msg.tool_call_id as string) || ""; + const toolName = toolCallMap.get(toolCallId) || "unknown"; + const toolResponse = parseToolResponse( + content, + toolCallId, + toolName, + timestamp, + ); + if (toolResponse) { + processedMessages.push(toolResponse); + } + continue; + } + + if (content.trim()) { + processedMessages.push({ + type: "message", + role: role as "user" | "assistant" | "system", + content, + timestamp, + }); + } + } + + return processedMessages; +} + export function hasSentInitialPrompt(sessionId: string): boolean { try { const sent = JSON.parse( @@ -193,6 +304,7 @@ export function parseToolResponse( if (isAgentArray(agentsData)) { return { type: "agent_carousel", + toolId, toolName: "agent_carousel", agents: agentsData, totalCount: parsedResult.total_count as number | undefined, @@ -205,6 +317,7 @@ export function parseToolResponse( if (responseType === "execution_started") { return { type: "execution_started", + toolId, toolName: "execution_started", executionId: (parsedResult.execution_id as string) || "", agentName: (parsedResult.graph_name as string) || undefined, @@ -213,6 +326,58 @@ export function parseToolResponse( timestamp: timestamp || new Date(), }; } + if (responseType === "clarification_needed") { + return { + type: "clarification_needed", + toolName, + questions: + (parsedResult.questions as Array<{ + question: string; + keyword: string; + example?: string; + }>) || [], + message: + (parsedResult.message as string) || + "I need more information to proceed.", + sessionId: (parsedResult.session_id as string) || "", + timestamp: timestamp || new Date(), + }; + } + if (responseType === "operation_started") { + return { + type: "operation_started", + toolName: (parsedResult.tool_name as string) || toolName, + toolId, + operationId: (parsedResult.operation_id as string) || "", + message: + (parsedResult.message as string) || + "Operation started. You can close this tab.", + timestamp: timestamp || new Date(), + }; + } + if (responseType === "operation_pending") { + return { + type: "operation_pending", + toolName: (parsedResult.tool_name as string) || toolName, + toolId, + operationId: (parsedResult.operation_id as string) || "", + message: + (parsedResult.message as string) || + "Operation in progress. Please wait...", + timestamp: timestamp || new Date(), + }; + } + if (responseType === "operation_in_progress") { + return { + type: "operation_in_progress", + toolName: (parsedResult.tool_name as string) || toolName, + toolCallId: (parsedResult.tool_call_id as string) || toolId, + message: + (parsedResult.message as string) || + "Operation already in progress. Please wait...", + timestamp: timestamp || new Date(), + }; + } if (responseType === "need_login") { return { type: "login_needed", diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts index 42dd04670d..46f384d055 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts @@ -1,5 +1,6 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; -import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { useEffect, useMemo, useRef, useState } from "react"; +import { useChatStore } from "../../chat-store"; import { toast } from "sonner"; import { useChatStream } from "../../useChatStream"; import { usePageContext } from "../../usePageContext"; @@ -9,23 +10,44 @@ import { createUserMessage, filterAuthMessages, hasSentInitialPrompt, - isToolCallArray, - isValidMessage, markInitialPromptSent, - parseToolResponse, - removePageContext, + processInitialMessages, } from "./helpers"; +// Helper to generate deduplication key for a message +function getMessageKey(msg: ChatMessageData): string { + if (msg.type === "message") { + // Don't include timestamp - dedupe by role + content only + // This handles the case where local and server timestamps differ + // Server messages are authoritative, so duplicates from local state are filtered + return `msg:${msg.role}:${msg.content}`; + } else if (msg.type === "tool_call") { + return `toolcall:${msg.toolId}`; + } else if (msg.type === "tool_response") { + return `toolresponse:${(msg as any).toolId}`; + } else if ( + msg.type === "operation_started" || + msg.type === "operation_pending" || + msg.type === "operation_in_progress" + ) { + return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`; + } else { + return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`; + } +} + interface Args { sessionId: string | null; initialMessages: SessionDetailResponse["messages"]; initialPrompt?: string; + onOperationStarted?: () => void; } export function useChatContainer({ sessionId, initialMessages, initialPrompt, + onOperationStarted, }: Args) { const [messages, setMessages] = useState([]); const [streamingChunks, setStreamingChunks] = useState([]); @@ -41,11 +63,18 @@ export function useChatContainer({ sendMessage: sendStreamMessage, stopStreaming, } = useChatStream(); + const activeStreams = useChatStore((s) => s.activeStreams); + const subscribeToStream = useChatStore((s) => s.subscribeToStream); const isStreaming = isStreamingInitiated || hasTextChunks; - useEffect(() => { - if (sessionId !== previousSessionIdRef.current) { - stopStreaming(previousSessionIdRef.current ?? undefined, true); + useEffect( + function handleSessionChange() { + if (sessionId === previousSessionIdRef.current) return; + + const prevSession = previousSessionIdRef.current; + if (prevSession) { + stopStreaming(prevSession); + } previousSessionIdRef.current = sessionId; setMessages([]); setStreamingChunks([]); @@ -53,138 +82,11 @@ export function useChatContainer({ setHasTextChunks(false); setIsStreamingInitiated(false); hasResponseRef.current = false; - } - }, [sessionId, stopStreaming]); - const allMessages = useMemo(() => { - const processedInitialMessages: ChatMessageData[] = []; - const toolCallMap = new Map(); + if (!sessionId) return; - for (const msg of initialMessages) { - if (!isValidMessage(msg)) { - console.warn("Invalid message structure from backend:", msg); - continue; - } - - let content = String(msg.content || ""); - const role = String(msg.role || "assistant").toLowerCase(); - const toolCalls = msg.tool_calls; - const timestamp = msg.timestamp - ? new Date(msg.timestamp as string) - : undefined; - - if (role === "user") { - content = removePageContext(content); - if (!content.trim()) continue; - processedInitialMessages.push({ - type: "message", - role: "user", - content, - timestamp, - }); - continue; - } - - if (role === "assistant") { - content = content - .replace(/[\s\S]*?<\/thinking>/gi, "") - .trim(); - - if (toolCalls && isToolCallArray(toolCalls) && toolCalls.length > 0) { - for (const toolCall of toolCalls) { - const toolName = toolCall.function.name; - const toolId = toolCall.id; - toolCallMap.set(toolId, toolName); - - try { - const args = JSON.parse(toolCall.function.arguments || "{}"); - processedInitialMessages.push({ - type: "tool_call", - toolId, - toolName, - arguments: args, - timestamp, - }); - } catch (err) { - console.warn("Failed to parse tool call arguments:", err); - processedInitialMessages.push({ - type: "tool_call", - toolId, - toolName, - arguments: {}, - timestamp, - }); - } - } - if (content.trim()) { - processedInitialMessages.push({ - type: "message", - role: "assistant", - content, - timestamp, - }); - } - } else if (content.trim()) { - processedInitialMessages.push({ - type: "message", - role: "assistant", - content, - timestamp, - }); - } - continue; - } - - if (role === "tool") { - const toolCallId = (msg.tool_call_id as string) || ""; - const toolName = toolCallMap.get(toolCallId) || "unknown"; - const toolResponse = parseToolResponse( - content, - toolCallId, - toolName, - timestamp, - ); - if (toolResponse) { - processedInitialMessages.push(toolResponse); - } - continue; - } - - if (content.trim()) { - processedInitialMessages.push({ - type: "message", - role: role as "user" | "assistant" | "system", - content, - timestamp, - }); - } - } - - return [...processedInitialMessages, ...messages]; - }, [initialMessages, messages]); - - const sendMessage = useCallback( - async function sendMessage( - content: string, - isUserMessage: boolean = true, - context?: { url: string; content: string }, - ) { - if (!sessionId) { - console.error("[useChatContainer] Cannot send message: no session ID"); - return; - } - setIsRegionBlockedModalOpen(false); - if (isUserMessage) { - const userMessage = createUserMessage(content); - setMessages((prev) => [...filterAuthMessages(prev), userMessage]); - } else { - setMessages((prev) => filterAuthMessages(prev)); - } - setStreamingChunks([]); - streamingChunksRef.current = []; - setHasTextChunks(false); - setIsStreamingInitiated(true); - hasResponseRef.current = false; + const activeStream = activeStreams.get(sessionId); + if (!activeStream || activeStream.status !== "streaming") return; const dispatcher = createStreamEventDispatcher({ setHasTextChunks, @@ -195,44 +97,170 @@ export function useChatContainer({ setIsRegionBlockedModalOpen, sessionId, setIsStreamingInitiated, + onOperationStarted, }); - try { - await sendStreamMessage( - sessionId, - content, - dispatcher, - isUserMessage, - context, - ); - } catch (err) { - console.error("[useChatContainer] Failed to send message:", err); - setIsStreamingInitiated(false); - - // Don't show error toast for AbortError (expected during cleanup) - if (err instanceof Error && err.name === "AbortError") return; - - const errorMessage = - err instanceof Error ? err.message : "Failed to send message"; - toast.error("Failed to send message", { - description: errorMessage, - }); - } + setIsStreamingInitiated(true); + const skipReplay = initialMessages.length > 0; + return subscribeToStream(sessionId, dispatcher, skipReplay); }, - [sessionId, sendStreamMessage], + [ + sessionId, + stopStreaming, + activeStreams, + subscribeToStream, + onOperationStarted, + ], ); - const handleStopStreaming = useCallback(() => { + // Collect toolIds from completed tool results in initialMessages + // Used to filter out operation messages when their results arrive + const completedToolIds = useMemo(() => { + const processedInitial = processInitialMessages(initialMessages); + const ids = new Set(); + for (const msg of processedInitial) { + if ( + msg.type === "tool_response" || + msg.type === "agent_carousel" || + msg.type === "execution_started" + ) { + const toolId = (msg as any).toolId; + if (toolId) { + ids.add(toolId); + } + } + } + return ids; + }, [initialMessages]); + + // Clean up local operation messages when their completed results arrive from polling + // This effect runs when completedToolIds changes (i.e., when polling brings new results) + useEffect( + function cleanupCompletedOperations() { + if (completedToolIds.size === 0) return; + + setMessages((prev) => { + const filtered = prev.filter((msg) => { + if ( + msg.type === "operation_started" || + msg.type === "operation_pending" || + msg.type === "operation_in_progress" + ) { + const toolId = (msg as any).toolId || (msg as any).toolCallId; + if (toolId && completedToolIds.has(toolId)) { + return false; // Remove - operation completed + } + } + return true; + }); + // Only update state if something was actually filtered + return filtered.length === prev.length ? prev : filtered; + }); + }, + [completedToolIds], + ); + + // Combine initial messages from backend with local streaming messages, + // Server messages maintain correct order; only append truly new local messages + const allMessages = useMemo(() => { + const processedInitial = processInitialMessages(initialMessages); + + // Build a set of keys from server messages for deduplication + const serverKeys = new Set(); + for (const msg of processedInitial) { + serverKeys.add(getMessageKey(msg)); + } + + // Filter local messages: remove duplicates and completed operation messages + const newLocalMessages = messages.filter((msg) => { + // Remove operation messages for completed tools + if ( + msg.type === "operation_started" || + msg.type === "operation_pending" || + msg.type === "operation_in_progress" + ) { + const toolId = (msg as any).toolId || (msg as any).toolCallId; + if (toolId && completedToolIds.has(toolId)) { + return false; + } + } + // Remove messages that already exist in server data + const key = getMessageKey(msg); + return !serverKeys.has(key); + }); + + // Server messages first (correct order), then new local messages + return [...processedInitial, ...newLocalMessages]; + }, [initialMessages, messages, completedToolIds]); + + async function sendMessage( + content: string, + isUserMessage: boolean = true, + context?: { url: string; content: string }, + ) { + if (!sessionId) { + console.error("[useChatContainer] Cannot send message: no session ID"); + return; + } + setIsRegionBlockedModalOpen(false); + if (isUserMessage) { + const userMessage = createUserMessage(content); + setMessages((prev) => [...filterAuthMessages(prev), userMessage]); + } else { + setMessages((prev) => filterAuthMessages(prev)); + } + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + setIsStreamingInitiated(true); + hasResponseRef.current = false; + + const dispatcher = createStreamEventDispatcher({ + setHasTextChunks, + setStreamingChunks, + streamingChunksRef, + hasResponseRef, + setMessages, + setIsRegionBlockedModalOpen, + sessionId, + setIsStreamingInitiated, + onOperationStarted, + }); + + try { + await sendStreamMessage( + sessionId, + content, + dispatcher, + isUserMessage, + context, + ); + } catch (err) { + console.error("[useChatContainer] Failed to send message:", err); + setIsStreamingInitiated(false); + + if (err instanceof Error && err.name === "AbortError") return; + + const errorMessage = + err instanceof Error ? err.message : "Failed to send message"; + toast.error("Failed to send message", { + description: errorMessage, + }); + } + } + + function handleStopStreaming() { stopStreaming(); setStreamingChunks([]); streamingChunksRef.current = []; setHasTextChunks(false); setIsStreamingInitiated(false); - }, [stopStreaming]); + } const { capturePageContext } = usePageContext(); + const sendMessageRef = useRef(sendMessage); + sendMessageRef.current = sendMessage; - // Send initial prompt if provided (for new sessions from homepage) useEffect( function handleInitialPrompt() { if (!initialPrompt || !sessionId) return; @@ -241,15 +269,9 @@ export function useChatContainer({ markInitialPromptSent(sessionId); const context = capturePageContext(); - sendMessage(initialPrompt, true, context); + sendMessageRef.current(initialPrompt, true, context); }, - [ - initialPrompt, - sessionId, - initialMessages.length, - sendMessage, - capturePageContext, - ], + [initialPrompt, sessionId, initialMessages.length, capturePageContext], ); async function sendMessageWithContext( diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx index 8cdecf0bf4..521f6f6320 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx @@ -1,7 +1,14 @@ import { Button } from "@/components/atoms/Button/Button"; import { cn } from "@/lib/utils"; -import { ArrowUpIcon, StopIcon } from "@phosphor-icons/react"; +import { + ArrowUpIcon, + CircleNotchIcon, + MicrophoneIcon, + StopIcon, +} from "@phosphor-icons/react"; +import { RecordingIndicator } from "./components/RecordingIndicator"; import { useChatInput } from "./useChatInput"; +import { useVoiceRecording } from "./useVoiceRecording"; export interface Props { onSend: (message: string) => void; @@ -21,22 +28,36 @@ export function ChatInput({ className, }: Props) { const inputId = "chat-input"; - const { value, setValue, handleKeyDown, handleSend, hasMultipleLines } = - useChatInput({ - onSend, - disabled: disabled || isStreaming, - maxRows: 4, - inputId, - }); + const { + value, + setValue, + handleKeyDown: baseHandleKeyDown, + handleSubmit, + handleChange, + hasMultipleLines, + } = useChatInput({ + onSend, + disabled: disabled || isStreaming, + maxRows: 4, + inputId, + }); - function handleSubmit(e: React.FormEvent) { - e.preventDefault(); - handleSend(); - } - - function handleChange(e: React.ChangeEvent) { - setValue(e.target.value); - } + const { + isRecording, + isTranscribing, + elapsedTime, + toggleRecording, + handleKeyDown, + showMicButton, + isInputDisabled, + audioStream, + } = useVoiceRecording({ + setValue, + disabled: disabled || isStreaming, + isStreaming, + value, + baseHandleKeyDown, + }); return (
@@ -44,8 +65,11 @@ export function ChatInput({
@@ -55,48 +79,94 @@ export function ChatInput({ value={value} onChange={handleChange} onKeyDown={handleKeyDown} - placeholder={placeholder} - disabled={disabled || isStreaming} + placeholder={ + isTranscribing + ? "Transcribing..." + : isRecording + ? "" + : placeholder + } + disabled={isInputDisabled} rows={1} className={cn( "w-full resize-none overflow-y-auto border-0 bg-transparent text-[1rem] leading-6 text-black", "placeholder:text-zinc-400", "focus:outline-none focus:ring-0", "disabled:text-zinc-500", - hasMultipleLines ? "pb-6 pl-4 pr-4 pt-2" : "pb-4 pl-4 pr-14 pt-4", + hasMultipleLines + ? "pb-6 pl-4 pr-4 pt-2" + : showMicButton + ? "pb-4 pl-14 pr-14 pt-4" + : "pb-4 pl-4 pr-14 pt-4", )} /> + {isRecording && !value && ( +
+ +
+ )}
- Press Enter to send, Shift+Enter for new line + Press Enter to send, Shift+Enter for new line, Space to record voice - {isStreaming ? ( - - ) : ( - + {showMicButton && ( +
+ +
)} + +
+ {isStreaming ? ( + + ) : ( + + )} +
); diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/components/AudioWaveform.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/components/AudioWaveform.tsx new file mode 100644 index 0000000000..10cbb3fc9f --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/components/AudioWaveform.tsx @@ -0,0 +1,142 @@ +"use client"; + +import { useEffect, useRef, useState } from "react"; + +interface Props { + stream: MediaStream | null; + barCount?: number; + barWidth?: number; + barGap?: number; + barColor?: string; + minBarHeight?: number; + maxBarHeight?: number; +} + +export function AudioWaveform({ + stream, + barCount = 24, + barWidth = 3, + barGap = 2, + barColor = "#ef4444", // red-500 + minBarHeight = 4, + maxBarHeight = 32, +}: Props) { + const [bars, setBars] = useState(() => + Array(barCount).fill(minBarHeight), + ); + const analyserRef = useRef(null); + const audioContextRef = useRef(null); + const sourceRef = useRef(null); + const animationRef = useRef(null); + + useEffect(() => { + if (!stream) { + setBars(Array(barCount).fill(minBarHeight)); + return; + } + + // Create audio context and analyser + const audioContext = new AudioContext(); + const analyser = audioContext.createAnalyser(); + analyser.fftSize = 512; + analyser.smoothingTimeConstant = 0.8; + + // Connect the stream to the analyser + const source = audioContext.createMediaStreamSource(stream); + source.connect(analyser); + + audioContextRef.current = audioContext; + analyserRef.current = analyser; + sourceRef.current = source; + + const timeData = new Uint8Array(analyser.frequencyBinCount); + + const updateBars = () => { + if (!analyserRef.current) return; + + analyserRef.current.getByteTimeDomainData(timeData); + + // Distribute time-domain data across bars + // This shows waveform amplitude, making all bars respond to audio + const newBars: number[] = []; + const samplesPerBar = timeData.length / barCount; + + for (let i = 0; i < barCount; i++) { + // Sample waveform data for this bar + let maxAmplitude = 0; + const startIdx = Math.floor(i * samplesPerBar); + const endIdx = Math.floor((i + 1) * samplesPerBar); + + for (let j = startIdx; j < endIdx && j < timeData.length; j++) { + // Convert to amplitude (distance from center 128) + const amplitude = Math.abs(timeData[j] - 128); + maxAmplitude = Math.max(maxAmplitude, amplitude); + } + + // Map amplitude (0-128) to bar height + const normalized = (maxAmplitude / 128) * 255; + const height = + minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight); + newBars.push(height); + } + + setBars(newBars); + animationRef.current = requestAnimationFrame(updateBars); + }; + + updateBars(); + + return () => { + if (animationRef.current) { + cancelAnimationFrame(animationRef.current); + } + if (sourceRef.current) { + sourceRef.current.disconnect(); + } + if (audioContextRef.current) { + audioContextRef.current.close(); + } + analyserRef.current = null; + audioContextRef.current = null; + sourceRef.current = null; + }; + }, [stream, barCount, minBarHeight, maxBarHeight]); + + const totalWidth = barCount * barWidth + (barCount - 1) * barGap; + + return ( +
+ {bars.map((height, i) => { + const barHeight = Math.max(minBarHeight, height); + return ( +
+
+
+ ); + })} +
+ ); +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/components/RecordingIndicator.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/components/RecordingIndicator.tsx new file mode 100644 index 0000000000..0be0d069bb --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/components/RecordingIndicator.tsx @@ -0,0 +1,26 @@ +import { formatElapsedTime } from "../helpers"; +import { AudioWaveform } from "./AudioWaveform"; + +type Props = { + elapsedTime: number; + audioStream: MediaStream | null; +}; + +export function RecordingIndicator({ elapsedTime, audioStream }: Props) { + return ( +
+ + + {formatElapsedTime(elapsedTime)} + +
+ ); +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/helpers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/helpers.ts new file mode 100644 index 0000000000..26bae8c9d9 --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/helpers.ts @@ -0,0 +1,6 @@ +export function formatElapsedTime(ms: number): string { + const seconds = Math.floor(ms / 1000); + const minutes = Math.floor(seconds / 60); + const remainingSeconds = seconds % 60; + return `${minutes}:${remainingSeconds.toString().padStart(2, "0")}`; +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/useChatInput.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/useChatInput.ts index 93d764b026..a053e6080f 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/useChatInput.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/useChatInput.ts @@ -1,6 +1,12 @@ -import { KeyboardEvent, useCallback, useEffect, useState } from "react"; +import { + ChangeEvent, + FormEvent, + KeyboardEvent, + useEffect, + useState, +} from "react"; -interface UseChatInputArgs { +interface Args { onSend: (message: string) => void; disabled?: boolean; maxRows?: number; @@ -12,10 +18,27 @@ export function useChatInput({ disabled = false, maxRows = 5, inputId = "chat-input", -}: UseChatInputArgs) { +}: Args) { const [value, setValue] = useState(""); const [hasMultipleLines, setHasMultipleLines] = useState(false); + useEffect( + function focusOnMount() { + const textarea = document.getElementById(inputId) as HTMLTextAreaElement; + if (textarea) textarea.focus(); + }, + [inputId], + ); + + useEffect( + function focusWhenEnabled() { + if (disabled) return; + const textarea = document.getElementById(inputId) as HTMLTextAreaElement; + if (textarea) textarea.focus(); + }, + [disabled, inputId], + ); + useEffect(() => { const textarea = document.getElementById(inputId) as HTMLTextAreaElement; const wrapper = document.getElementById( @@ -77,7 +100,7 @@ export function useChatInput({ } }, [value, maxRows, inputId]); - const handleSend = useCallback(() => { + const handleSend = () => { if (disabled || !value.trim()) return; onSend(value.trim()); setValue(""); @@ -93,23 +116,31 @@ export function useChatInput({ wrapper.style.height = ""; wrapper.style.maxHeight = ""; } - }, [value, onSend, disabled, inputId]); + }; - const handleKeyDown = useCallback( - (event: KeyboardEvent) => { - if (event.key === "Enter" && !event.shiftKey) { - event.preventDefault(); - handleSend(); - } - }, - [handleSend], - ); + function handleKeyDown(event: KeyboardEvent) { + if (event.key === "Enter" && !event.shiftKey) { + event.preventDefault(); + handleSend(); + } + } + + function handleSubmit(e: FormEvent) { + e.preventDefault(); + handleSend(); + } + + function handleChange(e: ChangeEvent) { + setValue(e.target.value); + } return { value, setValue, handleKeyDown, handleSend, + handleSubmit, + handleChange, hasMultipleLines, }; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/useVoiceRecording.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/useVoiceRecording.ts new file mode 100644 index 0000000000..13b625e69c --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/useVoiceRecording.ts @@ -0,0 +1,240 @@ +import { useToast } from "@/components/molecules/Toast/use-toast"; +import React, { + KeyboardEvent, + useCallback, + useEffect, + useRef, + useState, +} from "react"; + +const MAX_RECORDING_DURATION = 2 * 60 * 1000; // 2 minutes in ms + +interface Args { + setValue: React.Dispatch>; + disabled?: boolean; + isStreaming?: boolean; + value: string; + baseHandleKeyDown: (event: KeyboardEvent) => void; +} + +export function useVoiceRecording({ + setValue, + disabled = false, + isStreaming = false, + value, + baseHandleKeyDown, +}: Args) { + const [isRecording, setIsRecording] = useState(false); + const [isTranscribing, setIsTranscribing] = useState(false); + const [error, setError] = useState(null); + const [elapsedTime, setElapsedTime] = useState(0); + + const mediaRecorderRef = useRef(null); + const chunksRef = useRef([]); + const timerRef = useRef(null); + const startTimeRef = useRef(0); + const streamRef = useRef(null); + const isRecordingRef = useRef(false); + + const isSupported = + typeof window !== "undefined" && + !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia); + + const clearTimer = useCallback(() => { + if (timerRef.current) { + clearInterval(timerRef.current); + timerRef.current = null; + } + }, []); + + const cleanup = useCallback(() => { + clearTimer(); + if (streamRef.current) { + streamRef.current.getTracks().forEach((track) => track.stop()); + streamRef.current = null; + } + mediaRecorderRef.current = null; + chunksRef.current = []; + setElapsedTime(0); + }, [clearTimer]); + + const handleTranscription = useCallback( + (text: string) => { + setValue((prev) => { + const trimmedPrev = prev.trim(); + if (trimmedPrev) { + return `${trimmedPrev} ${text}`; + } + return text; + }); + }, + [setValue], + ); + + const transcribeAudio = useCallback( + async (audioBlob: Blob) => { + setIsTranscribing(true); + setError(null); + + try { + const formData = new FormData(); + formData.append("audio", audioBlob); + + const response = await fetch("/api/transcribe", { + method: "POST", + body: formData, + }); + + if (!response.ok) { + const data = await response.json().catch(() => ({})); + throw new Error(data.error || "Transcription failed"); + } + + const data = await response.json(); + if (data.text) { + handleTranscription(data.text); + } + } catch (err) { + const message = + err instanceof Error ? err.message : "Transcription failed"; + setError(message); + console.error("Transcription error:", err); + } finally { + setIsTranscribing(false); + } + }, + [handleTranscription], + ); + + const stopRecording = useCallback(() => { + if (mediaRecorderRef.current && isRecordingRef.current) { + mediaRecorderRef.current.stop(); + isRecordingRef.current = false; + setIsRecording(false); + clearTimer(); + } + }, [clearTimer]); + + const startRecording = useCallback(async () => { + if (disabled || isRecordingRef.current || isTranscribing) return; + + setError(null); + chunksRef.current = []; + + try { + const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + streamRef.current = stream; + + const mediaRecorder = new MediaRecorder(stream, { + mimeType: MediaRecorder.isTypeSupported("audio/webm") + ? "audio/webm" + : "audio/mp4", + }); + + mediaRecorderRef.current = mediaRecorder; + + mediaRecorder.ondataavailable = (event) => { + if (event.data.size > 0) { + chunksRef.current.push(event.data); + } + }; + + mediaRecorder.onstop = async () => { + const audioBlob = new Blob(chunksRef.current, { + type: mediaRecorder.mimeType, + }); + + // Cleanup stream + if (streamRef.current) { + streamRef.current.getTracks().forEach((track) => track.stop()); + streamRef.current = null; + } + + if (audioBlob.size > 0) { + await transcribeAudio(audioBlob); + } + }; + + mediaRecorder.start(1000); // Collect data every second + isRecordingRef.current = true; + setIsRecording(true); + startTimeRef.current = Date.now(); + + // Start elapsed time timer + timerRef.current = setInterval(() => { + const elapsed = Date.now() - startTimeRef.current; + setElapsedTime(elapsed); + + // Auto-stop at max duration + if (elapsed >= MAX_RECORDING_DURATION) { + stopRecording(); + } + }, 100); + } catch (err) { + console.error("Failed to start recording:", err); + if (err instanceof DOMException && err.name === "NotAllowedError") { + setError("Microphone permission denied"); + } else { + setError("Failed to access microphone"); + } + cleanup(); + } + }, [disabled, isTranscribing, stopRecording, transcribeAudio, cleanup]); + + const toggleRecording = useCallback(() => { + if (isRecording) { + stopRecording(); + } else { + startRecording(); + } + }, [isRecording, startRecording, stopRecording]); + + const { toast } = useToast(); + + useEffect(() => { + if (error) { + toast({ + title: "Voice recording failed", + description: error, + variant: "destructive", + }); + } + }, [error, toast]); + + const handleKeyDown = useCallback( + (event: KeyboardEvent) => { + if (event.key === " " && !value.trim() && !isTranscribing) { + event.preventDefault(); + toggleRecording(); + return; + } + baseHandleKeyDown(event); + }, + [value, isTranscribing, toggleRecording, baseHandleKeyDown], + ); + + const showMicButton = isSupported && !isStreaming; + const isInputDisabled = disabled || isStreaming || isTranscribing; + + // Cleanup on unmount + useEffect(() => { + return () => { + cleanup(); + }; + }, [cleanup]); + + return { + isRecording, + isTranscribing, + error, + elapsedTime, + startRecording, + stopRecording, + toggleRecording, + isSupported, + handleKeyDown, + showMicButton, + isInputDisabled, + audioStream: streamRef.current, + }; +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatMessage/ChatMessage.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatMessage/ChatMessage.tsx index a2827ce611..c922d0da76 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatMessage/ChatMessage.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatMessage/ChatMessage.tsx @@ -14,7 +14,9 @@ import { AgentCarouselMessage } from "../AgentCarouselMessage/AgentCarouselMessa import { AIChatBubble } from "../AIChatBubble/AIChatBubble"; import { AuthPromptWidget } from "../AuthPromptWidget/AuthPromptWidget"; import { ChatCredentialsSetup } from "../ChatCredentialsSetup/ChatCredentialsSetup"; +import { ClarificationQuestionsWidget } from "../ClarificationQuestionsWidget/ClarificationQuestionsWidget"; import { ExecutionStartedMessage } from "../ExecutionStartedMessage/ExecutionStartedMessage"; +import { PendingOperationWidget } from "../PendingOperationWidget/PendingOperationWidget"; import { MarkdownContent } from "../MarkdownContent/MarkdownContent"; import { NoResultsMessage } from "../NoResultsMessage/NoResultsMessage"; import { ToolCallMessage } from "../ToolCallMessage/ToolCallMessage"; @@ -69,6 +71,10 @@ export function ChatMessage({ isToolResponse, isLoginNeeded, isCredentialsNeeded, + isClarificationNeeded, + isOperationStarted, + isOperationPending, + isOperationInProgress, } = useChatMessage(message); const displayContent = getDisplayContent(message, isUser); @@ -96,6 +102,18 @@ export function ChatMessage({ } } + function handleClarificationAnswers(answers: Record) { + if (onSendMessage) { + const contextMessage = Object.entries(answers) + .map(([keyword, answer]) => `${keyword}: ${answer}`) + .join("\n"); + + onSendMessage( + `I have the answers to your questions:\n\n${contextMessage}\n\nPlease proceed with creating the agent.`, + ); + } + } + const handleCopy = useCallback( async function handleCopy() { if (message.type !== "message") return; @@ -112,10 +130,6 @@ export function ChatMessage({ [displayContent, message], ); - function isLongResponse(content: string): boolean { - return content.split("\n").length > 5; - } - const handleTryAgain = useCallback(() => { if (message.type !== "message" || !onSendMessage) return; onSendMessage(message.content, message.role === "user"); @@ -141,6 +155,17 @@ export function ChatMessage({ ); } + if (isClarificationNeeded && message.type === "clarification_needed") { + return ( + + ); + } + // Render login needed messages if (isLoginNeeded && message.type === "login_needed") { // If user is already logged in, show success message instead of auth prompt @@ -269,6 +294,42 @@ export function ChatMessage({ ); } + // Render operation_started messages (long-running background operations) + if (isOperationStarted && message.type === "operation_started") { + return ( + + ); + } + + // Render operation_pending messages (operations in progress when refreshing) + if (isOperationPending && message.type === "operation_pending") { + return ( + + ); + } + + // Render operation_in_progress messages (duplicate request while operation running) + if (isOperationInProgress && message.type === "operation_in_progress") { + return ( + + ); + } + // Render tool response messages (but skip agent_output if it's being rendered inside assistant message) if (isToolResponse && message.type === "tool_response") { return ( @@ -333,7 +394,7 @@ export function ChatMessage({ )} - {!isUser && isFinalMessage && isLongResponse(displayContent) && ( + {!isUser && isFinalMessage && !isStreaming && ( + {onCancel && ( + + )} +
+ +
+
+
+ ); +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/MarkdownContent/MarkdownContent.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/MarkdownContent/MarkdownContent.tsx index 51a0794090..3dd5eca692 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/MarkdownContent/MarkdownContent.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/MarkdownContent/MarkdownContent.tsx @@ -1,6 +1,8 @@ "use client"; +import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace"; import { cn } from "@/lib/utils"; +import { EyeSlash } from "@phosphor-icons/react"; import React from "react"; import ReactMarkdown from "react-markdown"; import remarkGfm from "remark-gfm"; @@ -29,12 +31,88 @@ interface InputProps extends React.InputHTMLAttributes { type?: string; } +/** + * Converts a workspace:// URL to a proxy URL that routes through Next.js to the backend. + * workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download + * + * Uses the generated API URL helper and routes through the Next.js proxy + * which handles authentication and proper backend routing. + */ +/** + * URL transformer for ReactMarkdown. + * Converts workspace:// URLs to proxy URLs that route through Next.js to the backend. + * workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download + * + * This is needed because ReactMarkdown sanitizes URLs and only allows + * http, https, mailto, and tel protocols by default. + */ +function resolveWorkspaceUrl(src: string): string { + if (src.startsWith("workspace://")) { + const fileId = src.replace("workspace://", ""); + // Use the generated API URL helper to get the correct path + const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId); + // Route through the Next.js proxy (same pattern as customMutator for client-side) + return `/api/proxy${apiPath}`; + } + return src; +} + +/** + * Check if the image URL is a workspace file (AI cannot see these yet). + * After URL transformation, workspace files have URLs like /api/proxy/api/workspace/files/... + */ +function isWorkspaceImage(src: string | undefined): boolean { + return src?.includes("/workspace/files/") ?? false; +} + +/** + * Custom image component that shows an indicator when the AI cannot see the image. + * Note: src is already transformed by urlTransform, so workspace:// is now /api/workspace/... + */ +function MarkdownImage(props: Record) { + const src = props.src as string | undefined; + const alt = props.alt as string | undefined; + + const aiCannotSee = isWorkspaceImage(src); + + // If no src, show a placeholder + if (!src) { + return ( + + [Image: {alt || "missing src"}] + + ); + } + + return ( + + {/* eslint-disable-next-line @next/next/no-img-element */} + {alt + {aiCannotSee && ( + + + AI cannot see this image + + )} + + ); +} + export function MarkdownContent({ content, className }: MarkdownContentProps) { return (
{ const isInline = !className?.includes("language-"); @@ -206,6 +284,9 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) { {children} ), + img: ({ src, alt, ...props }) => ( + + ), }} > {content} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/MessageList/components/LastToolResponse/LastToolResponse.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/MessageList/components/LastToolResponse/LastToolResponse.tsx index 3e6bf91ad2..15b10e5715 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/MessageList/components/LastToolResponse/LastToolResponse.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/MessageList/components/LastToolResponse/LastToolResponse.tsx @@ -1,7 +1,5 @@ -import { AIChatBubble } from "../../../AIChatBubble/AIChatBubble"; import type { ChatMessageData } from "../../../ChatMessage/useChatMessage"; -import { MarkdownContent } from "../../../MarkdownContent/MarkdownContent"; -import { formatToolResponse } from "../../../ToolResponseMessage/helpers"; +import { ToolResponseMessage } from "../../../ToolResponseMessage/ToolResponseMessage"; import { shouldSkipAgentOutput } from "../../helpers"; export interface LastToolResponseProps { @@ -15,16 +13,15 @@ export function LastToolResponse({ }: LastToolResponseProps) { if (message.type !== "tool_response") return null; - // Skip if this is an agent_output that should be rendered inside assistant message if (shouldSkipAgentOutput(message, prevMessage)) return null; - const formattedText = formatToolResponse(message.result, message.toolName); - return (
- - - +
); } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/PendingOperationWidget/PendingOperationWidget.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/PendingOperationWidget/PendingOperationWidget.tsx new file mode 100644 index 0000000000..6cfea7f327 --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/PendingOperationWidget/PendingOperationWidget.tsx @@ -0,0 +1,109 @@ +"use client"; + +import { Card } from "@/components/atoms/Card/Card"; +import { Text } from "@/components/atoms/Text/Text"; +import { cn } from "@/lib/utils"; +import { CircleNotch, CheckCircle, XCircle } from "@phosphor-icons/react"; + +type OperationStatus = + | "pending" + | "started" + | "in_progress" + | "completed" + | "error"; + +interface Props { + status: OperationStatus; + message: string; + toolName?: string; + className?: string; +} + +function getOperationTitle(toolName?: string): string { + if (!toolName) return "Operation"; + // Convert tool name to human-readable format + // e.g., "create_agent" -> "Creating Agent", "edit_agent" -> "Editing Agent" + if (toolName === "create_agent") return "Creating Agent"; + if (toolName === "edit_agent") return "Editing Agent"; + // Default: capitalize and format tool name + return toolName + .split("_") + .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) + .join(" "); +} + +export function PendingOperationWidget({ + status, + message, + toolName, + className, +}: Props) { + const isPending = + status === "pending" || status === "started" || status === "in_progress"; + const isCompleted = status === "completed"; + const isError = status === "error"; + + const operationTitle = getOperationTitle(toolName); + + return ( +
+
+
+
+ {isPending && ( + + )} + {isCompleted && ( + + )} + {isError && ( + + )} +
+
+ +
+ +
+ + {isPending && operationTitle} + {isCompleted && `${operationTitle} Complete`} + {isError && `${operationTitle} Failed`} + + + {message} + +
+ + {isPending && ( + + Check your library in a few minutes. + + )} + + {toolName && ( + + Tool: {toolName} + + )} +
+
+
+
+ ); +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx index 1ba10dd248..27da02beb8 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ToolResponseMessage/ToolResponseMessage.tsx @@ -1,7 +1,14 @@ +import { Text } from "@/components/atoms/Text/Text"; +import { cn } from "@/lib/utils"; import type { ToolResult } from "@/types/chat"; +import { WarningCircleIcon } from "@phosphor-icons/react"; import { AIChatBubble } from "../AIChatBubble/AIChatBubble"; import { MarkdownContent } from "../MarkdownContent/MarkdownContent"; -import { formatToolResponse } from "./helpers"; +import { + formatToolResponse, + getErrorMessage, + isErrorResponse, +} from "./helpers"; export interface ToolResponseMessageProps { toolId?: string; @@ -18,6 +25,24 @@ export function ToolResponseMessage({ success: _success, className, }: ToolResponseMessageProps) { + if (isErrorResponse(result)) { + const errorMessage = getErrorMessage(result); + return ( + +
+ + + {errorMessage} + +
+
+ ); + } + const formattedText = formatToolResponse(result, toolName); return ( diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ToolResponseMessage/helpers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ToolResponseMessage/helpers.ts index cf2bca95f7..e886e1a28c 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ToolResponseMessage/helpers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ToolResponseMessage/helpers.ts @@ -1,3 +1,123 @@ +function stripInternalReasoning(content: string): string { + return content + .replace(/[\s\S]*?<\/internal_reasoning>/gi, "") + .replace(/[\s\S]*?<\/thinking>/gi, "") + .replace(/\n{3,}/g, "\n\n") + .trim(); +} + +export function isErrorResponse(result: unknown): boolean { + if (typeof result === "string") { + const lower = result.toLowerCase(); + return ( + lower.startsWith("error:") || + lower.includes("not found") || + lower.includes("does not exist") || + lower.includes("failed to") || + lower.includes("unable to") + ); + } + if (typeof result === "object" && result !== null) { + const response = result as Record; + return response.type === "error" || response.error !== undefined; + } + return false; +} + +export function getErrorMessage(result: unknown): string { + if (typeof result === "string") { + return stripInternalReasoning(result.replace(/^error:\s*/i, "")); + } + if (typeof result === "object" && result !== null) { + const response = result as Record; + if (response.error) return stripInternalReasoning(String(response.error)); + if (response.message) + return stripInternalReasoning(String(response.message)); + } + return "An error occurred"; +} + +/** + * Check if a value is a workspace file reference. + */ +function isWorkspaceRef(value: unknown): value is string { + return typeof value === "string" && value.startsWith("workspace://"); +} + +/** + * Check if a workspace reference appears to be an image based on common patterns. + * Since workspace refs don't have extensions, we check the context or assume image + * for certain block types. + * + * TODO: Replace keyword matching with MIME type encoded in workspace ref. + * e.g., workspace://abc123#image/png or workspace://abc123#video/mp4 + * This would let frontend render correctly without fragile keyword matching. + */ +function isLikelyImageRef(value: string, outputKey?: string): boolean { + if (!isWorkspaceRef(value)) return false; + + // Check output key name for video-related hints (these are NOT images) + const videoKeywords = ["video", "mp4", "mov", "avi", "webm", "movie", "clip"]; + if (outputKey) { + const lowerKey = outputKey.toLowerCase(); + if (videoKeywords.some((kw) => lowerKey.includes(kw))) { + return false; + } + } + + // Check output key name for image-related hints + const imageKeywords = [ + "image", + "img", + "photo", + "picture", + "thumbnail", + "avatar", + "icon", + "screenshot", + ]; + if (outputKey) { + const lowerKey = outputKey.toLowerCase(); + if (imageKeywords.some((kw) => lowerKey.includes(kw))) { + return true; + } + } + + // Default to treating workspace refs as potential images + // since that's the most common case for generated content + return true; +} + +/** + * Format a single output value, converting workspace refs to markdown images. + */ +function formatOutputValue(value: unknown, outputKey?: string): string { + if (isWorkspaceRef(value) && isLikelyImageRef(value, outputKey)) { + // Format as markdown image + return `![${outputKey || "Generated image"}](${value})`; + } + + if (typeof value === "string") { + // Check for data URIs (images) + if (value.startsWith("data:image/")) { + return `![${outputKey || "Generated image"}](${value})`; + } + return value; + } + + if (Array.isArray(value)) { + return value + .map((item, idx) => formatOutputValue(item, `${outputKey}_${idx}`)) + .join("\n\n"); + } + + if (typeof value === "object" && value !== null) { + return JSON.stringify(value, null, 2); + } + + return String(value); +} + function getToolCompletionPhrase(toolName: string): string { const toolCompletionPhrases: Record = { add_understanding: "Updated your business information", @@ -28,10 +148,10 @@ export function formatToolResponse(result: unknown, toolName: string): string { const parsed = JSON.parse(trimmed); return formatToolResponse(parsed, toolName); } catch { - return trimmed; + return stripInternalReasoning(trimmed); } } - return result; + return stripInternalReasoning(result); } if (typeof result !== "object" || result === null) { @@ -88,10 +208,26 @@ export function formatToolResponse(result: unknown, toolName: string): string { case "block_output": const blockName = (response.block_name as string) || "Block"; - const outputs = response.outputs as Record | undefined; + const outputs = response.outputs as Record | undefined; if (outputs && Object.keys(outputs).length > 0) { - const outputKeys = Object.keys(outputs); - return `${blockName} executed successfully. Outputs: ${outputKeys.join(", ")}`; + const formattedOutputs: string[] = []; + + for (const [key, values] of Object.entries(outputs)) { + if (!Array.isArray(values) || values.length === 0) continue; + + // Format each value in the output array + for (const value of values) { + const formatted = formatOutputValue(value, key); + if (formatted) { + formattedOutputs.push(formatted); + } + } + } + + if (formattedOutputs.length > 0) { + return `${blockName} executed successfully.\n\n${formattedOutputs.join("\n\n")}`; + } + return `${blockName} executed successfully.`; } return `${blockName} executed successfully.`; diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/UserChatBubble/UserChatBubble.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/UserChatBubble/UserChatBubble.tsx index 46459ff894..39a6cb36ad 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/UserChatBubble/UserChatBubble.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/UserChatBubble/UserChatBubble.tsx @@ -10,7 +10,7 @@ export function UserChatBubble({ children, className }: UserChatBubbleProps) { return (
{ + const { sessionId, abortController } = stream; + + try { + const url = `/api/chat/sessions/${sessionId}/stream`; + const body = JSON.stringify({ + message, + is_user_message: isUserMessage, + context: context || null, + }); + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body, + signal: abortController.signal, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(errorText || `HTTP ${response.status}`); + } + + if (!response.body) { + throw new Error("Response body is null"); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + while (true) { + const { done, value } = await reader.read(); + + if (done) { + notifySubscribers(stream, { type: "stream_end" }); + stream.status = "completed"; + return; + } + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + const data = parseSSELine(line); + if (data !== null) { + if (data === "[DONE]") { + notifySubscribers(stream, { type: "stream_end" }); + stream.status = "completed"; + return; + } + + try { + const rawChunk = JSON.parse(data) as + | StreamChunk + | VercelStreamChunk; + const chunk = normalizeStreamChunk(rawChunk); + if (!chunk) continue; + + notifySubscribers(stream, chunk); + + if (chunk.type === "stream_end") { + stream.status = "completed"; + return; + } + + if (chunk.type === "error") { + stream.status = "error"; + stream.error = new Error( + chunk.message || chunk.content || "Stream error", + ); + return; + } + } catch (err) { + console.warn("[StreamExecutor] Failed to parse SSE chunk:", err); + } + } + } + } + } catch (err) { + if (err instanceof Error && err.name === "AbortError") { + notifySubscribers(stream, { type: "stream_end" }); + stream.status = "completed"; + return; + } + + if (retryCount < MAX_RETRIES) { + const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount); + console.log( + `[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`, + ); + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + return executeStream( + stream, + message, + isUserMessage, + context, + retryCount + 1, + ); + } + + stream.status = "error"; + stream.error = err instanceof Error ? err : new Error("Stream failed"); + notifySubscribers(stream, { + type: "error", + message: stream.error.message, + }); + } +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts b/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts new file mode 100644 index 0000000000..4100926e79 --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts @@ -0,0 +1,84 @@ +import type { ToolArguments, ToolResult } from "@/types/chat"; +import type { StreamChunk, VercelStreamChunk } from "./chat-types"; + +const LEGACY_STREAM_TYPES = new Set([ + "text_chunk", + "text_ended", + "tool_call", + "tool_call_start", + "tool_response", + "login_needed", + "need_login", + "credentials_needed", + "error", + "usage", + "stream_end", +]); + +export function isLegacyStreamChunk( + chunk: StreamChunk | VercelStreamChunk, +): chunk is StreamChunk { + return LEGACY_STREAM_TYPES.has(chunk.type as StreamChunk["type"]); +} + +export function normalizeStreamChunk( + chunk: StreamChunk | VercelStreamChunk, +): StreamChunk | null { + if (isLegacyStreamChunk(chunk)) return chunk; + + switch (chunk.type) { + case "text-delta": + return { type: "text_chunk", content: chunk.delta }; + case "text-end": + return { type: "text_ended" }; + case "tool-input-available": + return { + type: "tool_call_start", + tool_id: chunk.toolCallId, + tool_name: chunk.toolName, + arguments: chunk.input as ToolArguments, + }; + case "tool-output-available": + return { + type: "tool_response", + tool_id: chunk.toolCallId, + tool_name: chunk.toolName, + result: chunk.output as ToolResult, + success: chunk.success ?? true, + }; + case "usage": + return { + type: "usage", + promptTokens: chunk.promptTokens, + completionTokens: chunk.completionTokens, + totalTokens: chunk.totalTokens, + }; + case "error": + return { + type: "error", + message: chunk.errorText, + code: chunk.code, + details: chunk.details, + }; + case "finish": + return { type: "stream_end" }; + case "start": + case "text-start": + return null; + case "tool-input-start": + return { + type: "tool_call_start", + tool_id: chunk.toolCallId, + tool_name: chunk.toolName, + arguments: {}, + }; + } +} + +export const MAX_RETRIES = 3; +export const INITIAL_RETRY_DELAY = 1000; + +export function parseSSELine(line: string): string | null { + if (line.startsWith("data: ")) return line.slice(6); + return null; +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/useChat.ts b/autogpt_platform/frontend/src/components/contextual/Chat/useChat.ts index cf629a287c..124301abc4 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/useChat.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/useChat.ts @@ -2,7 +2,6 @@ import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useEffect, useRef, useState } from "react"; -import { toast } from "sonner"; import { useChatSession } from "./useChatSession"; import { useChatStream } from "./useChatStream"; @@ -27,6 +26,7 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) { claimSession, clearSession: clearSessionBase, loadSession, + startPollingForOperation, } = useChatSession({ urlSessionId, autoCreate: false, @@ -67,38 +67,16 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) { ], ); - useEffect(() => { - if (isLoading || isCreating) { - const timer = setTimeout(() => { - setShowLoader(true); - }, 300); - return () => clearTimeout(timer); - } else { + useEffect( + function showLoaderWithDelay() { + if (isLoading || isCreating) { + const timer = setTimeout(() => setShowLoader(true), 300); + return () => clearTimeout(timer); + } setShowLoader(false); - } - }, [isLoading, isCreating]); - - useEffect(function monitorNetworkStatus() { - function handleOnline() { - toast.success("Connection restored", { - description: "You're back online", - }); - } - - function handleOffline() { - toast.error("You're offline", { - description: "Check your internet connection", - }); - } - - window.addEventListener("online", handleOnline); - window.addEventListener("offline", handleOffline); - - return () => { - window.removeEventListener("online", handleOnline); - window.removeEventListener("offline", handleOffline); - }; - }, []); + }, + [isLoading, isCreating], + ); function clearSession() { clearSessionBase(); @@ -117,5 +95,6 @@ export function useChat({ urlSessionId }: UseChatArgs = {}) { loadSession, sessionId: sessionIdFromHook, showLoader, + startPollingForOperation, }; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/useChatDrawer.ts b/autogpt_platform/frontend/src/components/contextual/Chat/useChatDrawer.ts deleted file mode 100644 index 62e1a5a569..0000000000 --- a/autogpt_platform/frontend/src/components/contextual/Chat/useChatDrawer.ts +++ /dev/null @@ -1,17 +0,0 @@ -"use client"; - -import { create } from "zustand"; - -interface ChatDrawerState { - isOpen: boolean; - open: () => void; - close: () => void; - toggle: () => void; -} - -export const useChatDrawer = create((set) => ({ - isOpen: false, - open: () => set({ isOpen: true }), - close: () => set({ isOpen: false }), - toggle: () => set((state) => ({ isOpen: !state.isOpen })), -})); diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/useChatSession.ts b/autogpt_platform/frontend/src/components/contextual/Chat/useChatSession.ts index 553e348f79..936a49936c 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/useChatSession.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/useChatSession.ts @@ -1,6 +1,7 @@ import { getGetV2GetSessionQueryKey, getGetV2GetSessionQueryOptions, + getGetV2ListSessionsQueryKey, postV2CreateSession, useGetV2GetSession, usePatchV2SessionAssignUser, @@ -58,6 +59,7 @@ export function useChatSession({ query: { enabled: !!sessionId, select: okData, + staleTime: 0, retry: shouldRetrySessionLoad, retryDelay: getSessionRetryDelay, }, @@ -101,6 +103,125 @@ export function useChatSession({ } }, [createError, loadError]); + // Track if we should be polling (set by external callers when they receive operation_started via SSE) + const [forcePolling, setForcePolling] = useState(false); + // Track if we've seen server acknowledge the pending operation (to avoid clearing forcePolling prematurely) + const hasSeenServerPendingRef = useRef(false); + + // Check if there are any pending operations in the messages + // Must check all operation types: operation_pending, operation_started, operation_in_progress + const hasPendingOperationsFromServer = useMemo(() => { + if (!messages || messages.length === 0) return false; + const pendingTypes = new Set([ + "operation_pending", + "operation_in_progress", + "operation_started", + ]); + return messages.some((msg) => { + if (msg.role !== "tool" || !msg.content) return false; + try { + const content = + typeof msg.content === "string" + ? JSON.parse(msg.content) + : msg.content; + return pendingTypes.has(content?.type); + } catch { + return false; + } + }); + }, [messages]); + + // Track when server has acknowledged the pending operation + useEffect(() => { + if (hasPendingOperationsFromServer) { + hasSeenServerPendingRef.current = true; + } + }, [hasPendingOperationsFromServer]); + + // Combined: poll if server has pending ops OR if we received operation_started via SSE + const hasPendingOperations = hasPendingOperationsFromServer || forcePolling; + + // Clear forcePolling only after server has acknowledged AND completed the operation + useEffect(() => { + if ( + forcePolling && + !hasPendingOperationsFromServer && + hasSeenServerPendingRef.current + ) { + // Server acknowledged the operation and it's now complete + setForcePolling(false); + hasSeenServerPendingRef.current = false; + } + }, [forcePolling, hasPendingOperationsFromServer]); + + // Function to trigger polling (called when operation_started is received via SSE) + function startPollingForOperation() { + setForcePolling(true); + hasSeenServerPendingRef.current = false; // Reset for new operation + } + + // Refresh sessions list when a pending operation completes + // (hasPendingOperations transitions from true to false) + const prevHasPendingOperationsRef = useRef(hasPendingOperations); + useEffect( + function refreshSessionsListOnOperationComplete() { + const wasHasPending = prevHasPendingOperationsRef.current; + prevHasPendingOperationsRef.current = hasPendingOperations; + + // Only invalidate when transitioning from pending to not pending + if (wasHasPending && !hasPendingOperations && sessionId) { + queryClient.invalidateQueries({ + queryKey: getGetV2ListSessionsQueryKey(), + }); + } + }, + [hasPendingOperations, sessionId, queryClient], + ); + + // Poll for updates when there are pending operations + // Backoff: 2s, 4s, 6s, 8s, 10s, ... up to 30s max + const pollAttemptRef = useRef(0); + const hasPendingOperationsRef = useRef(hasPendingOperations); + hasPendingOperationsRef.current = hasPendingOperations; + + useEffect( + function pollForPendingOperations() { + if (!sessionId || !hasPendingOperations) { + pollAttemptRef.current = 0; + return; + } + + let cancelled = false; + let timeoutId: ReturnType | null = null; + + function schedule() { + // 2s, 4s, 6s, 8s, 10s, ... 30s (max) + const delay = Math.min((pollAttemptRef.current + 1) * 2000, 30000); + timeoutId = setTimeout(async () => { + if (cancelled) return; + pollAttemptRef.current += 1; + try { + await refetch(); + } catch (err) { + console.error("[useChatSession] Poll failed:", err); + } finally { + if (!cancelled && hasPendingOperationsRef.current) { + schedule(); + } + } + }, delay); + } + + schedule(); + + return () => { + cancelled = true; + if (timeoutId) clearTimeout(timeoutId); + }; + }, + [sessionId, hasPendingOperations, refetch], + ); + async function createSession() { try { setError(null); @@ -227,11 +348,13 @@ export function useChatSession({ isCreating, error, isSessionNotFound: isNotFoundError(loadError), + hasPendingOperations, createSession, loadSession, refreshSession, claimSession, clearSession, + startPollingForOperation, }; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/useChatStream.ts b/autogpt_platform/frontend/src/components/contextual/Chat/useChatStream.ts index 903c19cd30..5a9f637457 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/useChatStream.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/useChatStream.ts @@ -1,543 +1,110 @@ -import type { ToolArguments, ToolResult } from "@/types/chat"; -import { useCallback, useEffect, useRef, useState } from "react"; +"use client"; + +import { useEffect, useRef, useState } from "react"; import { toast } from "sonner"; +import { useChatStore } from "./chat-store"; +import type { StreamChunk } from "./chat-types"; -const MAX_RETRIES = 3; -const INITIAL_RETRY_DELAY = 1000; - -export interface StreamChunk { - type: - | "text_chunk" - | "text_ended" - | "tool_call" - | "tool_call_start" - | "tool_response" - | "login_needed" - | "need_login" - | "credentials_needed" - | "error" - | "usage" - | "stream_end"; - timestamp?: string; - content?: string; - message?: string; - code?: string; - details?: Record; - tool_id?: string; - tool_name?: string; - arguments?: ToolArguments; - result?: ToolResult; - success?: boolean; - idx?: number; - session_id?: string; - agent_info?: { - graph_id: string; - name: string; - trigger_type: string; - }; - provider?: string; - provider_name?: string; - credential_type?: string; - scopes?: string[]; - title?: string; - [key: string]: unknown; -} - -type VercelStreamChunk = - | { type: "start"; messageId: string } - | { type: "finish" } - | { type: "text-start"; id: string } - | { type: "text-delta"; id: string; delta: string } - | { type: "text-end"; id: string } - | { type: "tool-input-start"; toolCallId: string; toolName: string } - | { - type: "tool-input-available"; - toolCallId: string; - toolName: string; - input: ToolArguments; - } - | { - type: "tool-output-available"; - toolCallId: string; - toolName?: string; - output: ToolResult; - success?: boolean; - } - | { - type: "usage"; - promptTokens: number; - completionTokens: number; - totalTokens: number; - } - | { - type: "error"; - errorText: string; - code?: string; - details?: Record; - }; - -const LEGACY_STREAM_TYPES = new Set([ - "text_chunk", - "text_ended", - "tool_call", - "tool_call_start", - "tool_response", - "login_needed", - "need_login", - "credentials_needed", - "error", - "usage", - "stream_end", -]); - -function isLegacyStreamChunk( - chunk: StreamChunk | VercelStreamChunk, -): chunk is StreamChunk { - return LEGACY_STREAM_TYPES.has(chunk.type as StreamChunk["type"]); -} - -function normalizeStreamChunk( - chunk: StreamChunk | VercelStreamChunk, -): StreamChunk | null { - if (isLegacyStreamChunk(chunk)) { - return chunk; - } - switch (chunk.type) { - case "text-delta": - return { type: "text_chunk", content: chunk.delta }; - case "text-end": - return { type: "text_ended" }; - case "tool-input-available": - return { - type: "tool_call_start", - tool_id: chunk.toolCallId, - tool_name: chunk.toolName, - arguments: chunk.input, - }; - case "tool-output-available": - return { - type: "tool_response", - tool_id: chunk.toolCallId, - tool_name: chunk.toolName, - result: chunk.output, - success: chunk.success ?? true, - }; - case "usage": - return { - type: "usage", - promptTokens: chunk.promptTokens, - completionTokens: chunk.completionTokens, - totalTokens: chunk.totalTokens, - }; - case "error": - return { - type: "error", - message: chunk.errorText, - code: chunk.code, - details: chunk.details, - }; - case "finish": - return { type: "stream_end" }; - case "start": - case "text-start": - return null; - case "tool-input-start": - const toolInputStart = chunk as Extract< - VercelStreamChunk, - { type: "tool-input-start" } - >; - return { - type: "tool_call_start", - tool_id: toolInputStart.toolCallId, - tool_name: toolInputStart.toolName, - arguments: {}, - }; - } -} +export type { StreamChunk } from "./chat-types"; export function useChatStream() { const [isStreaming, setIsStreaming] = useState(false); const [error, setError] = useState(null); - const retryCountRef = useRef(0); - const retryTimeoutRef = useRef(null); - const abortControllerRef = useRef(null); const currentSessionIdRef = useRef(null); - const requestStartTimeRef = useRef(null); - - const stopStreaming = useCallback( - (sessionId?: string, force: boolean = false) => { - console.log("[useChatStream] stopStreaming called", { - hasAbortController: !!abortControllerRef.current, - isAborted: abortControllerRef.current?.signal.aborted, - currentSessionId: currentSessionIdRef.current, - requestedSessionId: sessionId, - requestStartTime: requestStartTimeRef.current, - timeSinceStart: requestStartTimeRef.current - ? Date.now() - requestStartTimeRef.current - : null, - force, - stack: new Error().stack, - }); - - if ( - sessionId && - currentSessionIdRef.current && - currentSessionIdRef.current !== sessionId - ) { - console.log( - "[useChatStream] Session changed, aborting previous stream", - { - oldSessionId: currentSessionIdRef.current, - newSessionId: sessionId, - }, - ); - } - - const controller = abortControllerRef.current; - if (controller) { - const timeSinceStart = requestStartTimeRef.current - ? Date.now() - requestStartTimeRef.current - : null; - - if (!force && timeSinceStart !== null && timeSinceStart < 100) { - console.log( - "[useChatStream] Request just started (<100ms), skipping abort to prevent race condition", - { - timeSinceStart, - }, - ); - return; - } - - try { - const signal = controller.signal; - - if ( - signal && - typeof signal.aborted === "boolean" && - !signal.aborted - ) { - console.log("[useChatStream] Aborting stream"); - controller.abort(); - } else { - console.log( - "[useChatStream] Stream already aborted or signal invalid", - ); - } - } catch (error) { - if (error instanceof Error && error.name === "AbortError") { - console.log( - "[useChatStream] AbortError caught (expected during cleanup)", - ); - } else { - console.warn("[useChatStream] Error aborting stream:", error); - } - } finally { - abortControllerRef.current = null; - requestStartTimeRef.current = null; - } - } - if (retryTimeoutRef.current) { - clearTimeout(retryTimeoutRef.current); - retryTimeoutRef.current = null; - } - setIsStreaming(false); - }, - [], + const onChunkCallbackRef = useRef<((chunk: StreamChunk) => void) | null>( + null, ); + const stopStream = useChatStore((s) => s.stopStream); + const unregisterActiveSession = useChatStore( + (s) => s.unregisterActiveSession, + ); + const isSessionActive = useChatStore((s) => s.isSessionActive); + const onStreamComplete = useChatStore((s) => s.onStreamComplete); + const getCompletedStream = useChatStore((s) => s.getCompletedStream); + const registerActiveSession = useChatStore((s) => s.registerActiveSession); + const startStream = useChatStore((s) => s.startStream); + const getStreamStatus = useChatStore((s) => s.getStreamStatus); + + function stopStreaming(sessionId?: string) { + const targetSession = sessionId || currentSessionIdRef.current; + if (targetSession) { + stopStream(targetSession); + unregisterActiveSession(targetSession); + } + setIsStreaming(false); + } + useEffect(() => { - console.log("[useChatStream] Component mounted"); - return () => { - const sessionIdAtUnmount = currentSessionIdRef.current; - console.log( - "[useChatStream] Component unmounting, calling stopStreaming", - { - sessionIdAtUnmount, - }, - ); - stopStreaming(undefined, false); + return function cleanup() { + const sessionId = currentSessionIdRef.current; + if (sessionId && !isSessionActive(sessionId)) { + stopStream(sessionId); + } currentSessionIdRef.current = null; + onChunkCallbackRef.current = null; }; - }, [stopStreaming]); + }, []); - const sendMessage = useCallback( - async ( - sessionId: string, - message: string, - onChunk: (chunk: StreamChunk) => void, - isUserMessage: boolean = true, - context?: { url: string; content: string }, - isRetry: boolean = false, - ) => { - console.log("[useChatStream] sendMessage called", { - sessionId, - message: message.substring(0, 50), - isUserMessage, - isRetry, - stack: new Error().stack, - }); + useEffect(() => { + const unsubscribe = onStreamComplete( + function handleStreamComplete(completedSessionId) { + if (completedSessionId !== currentSessionIdRef.current) return; - const previousSessionId = currentSessionIdRef.current; - stopStreaming(sessionId, true); - currentSessionIdRef.current = sessionId; - - const abortController = new AbortController(); - abortControllerRef.current = abortController; - requestStartTimeRef.current = Date.now(); - console.log("[useChatStream] Created new AbortController", { - sessionId, - previousSessionId, - requestStartTime: requestStartTimeRef.current, - }); - - if (abortController.signal.aborted) { - console.warn( - "[useChatStream] AbortController was aborted before request started", - ); - requestStartTimeRef.current = null; - return Promise.reject(new Error("Request aborted")); - } - - if (!isRetry) { - retryCountRef.current = 0; - } - setIsStreaming(true); - setError(null); - - try { - const url = `/api/chat/sessions/${sessionId}/stream`; - const body = JSON.stringify({ - message, - is_user_message: isUserMessage, - context: context || null, - }); - - const response = await fetch(url, { - method: "POST", - headers: { - "Content-Type": "application/json", - Accept: "text/event-stream", - }, - body, - signal: abortController.signal, - }); - - console.info("[useChatStream] Stream response", { - sessionId, - status: response.status, - ok: response.ok, - contentType: response.headers.get("content-type"), - }); - - if (!response.ok) { - const errorText = await response.text(); - console.warn("[useChatStream] Stream response error", { - sessionId, - status: response.status, - errorText, - }); - throw new Error(errorText || `HTTP ${response.status}`); - } - - if (!response.body) { - console.warn("[useChatStream] Response body is null", { sessionId }); - throw new Error("Response body is null"); - } - - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ""; - let receivedChunkCount = 0; - let firstChunkAt: number | null = null; - let loggedLineCount = 0; - - return new Promise((resolve, reject) => { - let didDispatchStreamEnd = false; - - function dispatchStreamEnd() { - if (didDispatchStreamEnd) return; - didDispatchStreamEnd = true; - onChunk({ type: "stream_end" }); - } - - const cleanup = () => { - reader.cancel().catch(() => { - // Ignore cancel errors - }); - }; - - async function readStream() { - try { - while (true) { - const { done, value } = await reader.read(); - - if (done) { - cleanup(); - console.info("[useChatStream] Stream closed", { - sessionId, - receivedChunkCount, - timeSinceStart: requestStartTimeRef.current - ? Date.now() - requestStartTimeRef.current - : null, - }); - dispatchStreamEnd(); - retryCountRef.current = 0; - stopStreaming(); - resolve(); - return; - } - - buffer += decoder.decode(value, { stream: true }); - const lines = buffer.split("\n"); - buffer = lines.pop() || ""; - - for (const line of lines) { - if (line.startsWith("data: ")) { - const data = line.slice(6); - if (loggedLineCount < 3) { - console.info("[useChatStream] Raw stream line", { - sessionId, - data: - data.length > 300 ? `${data.slice(0, 300)}...` : data, - }); - loggedLineCount += 1; - } - if (data === "[DONE]") { - cleanup(); - console.info("[useChatStream] Stream done marker", { - sessionId, - receivedChunkCount, - timeSinceStart: requestStartTimeRef.current - ? Date.now() - requestStartTimeRef.current - : null, - }); - dispatchStreamEnd(); - retryCountRef.current = 0; - stopStreaming(); - resolve(); - return; - } - - try { - const rawChunk = JSON.parse(data) as - | StreamChunk - | VercelStreamChunk; - const chunk = normalizeStreamChunk(rawChunk); - if (!chunk) { - continue; - } - - if (!firstChunkAt) { - firstChunkAt = Date.now(); - console.info("[useChatStream] First stream chunk", { - sessionId, - chunkType: chunk.type, - timeSinceStart: requestStartTimeRef.current - ? firstChunkAt - requestStartTimeRef.current - : null, - }); - } - receivedChunkCount += 1; - - // Call the chunk handler - onChunk(chunk); - - // Handle stream lifecycle - if (chunk.type === "stream_end") { - didDispatchStreamEnd = true; - cleanup(); - console.info("[useChatStream] Stream end chunk", { - sessionId, - receivedChunkCount, - timeSinceStart: requestStartTimeRef.current - ? Date.now() - requestStartTimeRef.current - : null, - }); - retryCountRef.current = 0; - stopStreaming(); - resolve(); - return; - } else if (chunk.type === "error") { - cleanup(); - reject( - new Error( - chunk.message || chunk.content || "Stream error", - ), - ); - return; - } - } catch (err) { - // Skip invalid JSON lines - console.warn("Failed to parse SSE chunk:", err, data); - } - } - } - } - } catch (err) { - if (err instanceof Error && err.name === "AbortError") { - cleanup(); - dispatchStreamEnd(); - stopStreaming(); - resolve(); - return; - } - - const streamError = - err instanceof Error ? err : new Error("Failed to read stream"); - - if (retryCountRef.current < MAX_RETRIES) { - retryCountRef.current += 1; - const retryDelay = - INITIAL_RETRY_DELAY * Math.pow(2, retryCountRef.current - 1); - - toast.info("Connection interrupted", { - description: `Retrying in ${retryDelay / 1000} seconds...`, - }); - - retryTimeoutRef.current = setTimeout(() => { - sendMessage( - sessionId, - message, - onChunk, - isUserMessage, - context, - true, - ).catch((_err) => { - // Retry failed - }); - }, retryDelay); - } else { - setError(streamError); - toast.error("Connection Failed", { - description: - "Unable to connect to chat service. Please try again.", - }); - cleanup(); - dispatchStreamEnd(); - retryCountRef.current = 0; - stopStreaming(); - reject(streamError); - } - } - } - - readStream(); - }); - } catch (err) { - if (err instanceof Error && err.name === "AbortError") { - setIsStreaming(false); - return Promise.resolve(); - } - const streamError = - err instanceof Error ? err : new Error("Failed to start stream"); - setError(streamError); setIsStreaming(false); - throw streamError; + const completed = getCompletedStream(completedSessionId); + if (completed?.error) { + setError(completed.error); + } + unregisterActiveSession(completedSessionId); + }, + ); + + return unsubscribe; + }, []); + + async function sendMessage( + sessionId: string, + message: string, + onChunk: (chunk: StreamChunk) => void, + isUserMessage: boolean = true, + context?: { url: string; content: string }, + ) { + const previousSessionId = currentSessionIdRef.current; + if (previousSessionId && previousSessionId !== sessionId) { + stopStreaming(previousSessionId); + } + + currentSessionIdRef.current = sessionId; + onChunkCallbackRef.current = onChunk; + setIsStreaming(true); + setError(null); + + registerActiveSession(sessionId); + + try { + await startStream(sessionId, message, isUserMessage, context, onChunk); + + const status = getStreamStatus(sessionId); + if (status === "error") { + const completed = getCompletedStream(sessionId); + if (completed?.error) { + setError(completed.error); + toast.error("Connection Failed", { + description: "Unable to connect to chat service. Please try again.", + }); + throw completed.error; + } } - }, - [stopStreaming], - ); + } catch (err) { + const streamError = + err instanceof Error ? err : new Error("Failed to start stream"); + setError(streamError); + throw streamError; + } finally { + setIsStreaming(false); + } + } return { isStreaming, diff --git a/autogpt_platform/frontend/src/components/layout/Navbar/components/Wallet/Wallet.tsx b/autogpt_platform/frontend/src/components/layout/Navbar/components/Wallet/Wallet.tsx index 0a3c7de6c8..4a25c84f92 100644 --- a/autogpt_platform/frontend/src/components/layout/Navbar/components/Wallet/Wallet.tsx +++ b/autogpt_platform/frontend/src/components/layout/Navbar/components/Wallet/Wallet.tsx @@ -255,13 +255,18 @@ export function Wallet() { (notification: WebSocketNotification) => { if ( notification.type !== "onboarding" || - notification.event !== "step_completed" || - !walletRef.current + notification.event !== "step_completed" ) { return; } - // Only trigger confetti for tasks that are in groups + // Always refresh credits when any onboarding step completes + fetchCredits(); + + // Only trigger confetti for tasks that are in displayed groups + if (!walletRef.current) { + return; + } const taskIds = groups .flatMap((group) => group.tasks) .map((task) => task.id); @@ -274,7 +279,6 @@ export function Wallet() { return; } - fetchCredits(); party.confetti(walletRef.current, { count: 30, spread: 120, @@ -284,7 +288,7 @@ export function Wallet() { modules: [fadeOut], }); }, - [fetchCredits, fadeOut], + [fetchCredits, fadeOut, groups], ); // WebSocket setup for onboarding notifications diff --git a/autogpt_platform/frontend/src/components/organisms/FloatingReviewsPanel/FloatingReviewsPanel.tsx b/autogpt_platform/frontend/src/components/organisms/FloatingReviewsPanel/FloatingReviewsPanel.tsx index 2b04c0ed9a..4805508054 100644 --- a/autogpt_platform/frontend/src/components/organisms/FloatingReviewsPanel/FloatingReviewsPanel.tsx +++ b/autogpt_platform/frontend/src/components/organisms/FloatingReviewsPanel/FloatingReviewsPanel.tsx @@ -31,6 +31,29 @@ export function FloatingReviewsPanel({ query: { enabled: !!(graphId && executionId), select: okData, + // Poll while execution is in progress to detect status changes + refetchInterval: (q) => { + // Note: refetchInterval callback receives raw data before select transform + const rawData = q.state.data as + | { status: number; data?: { status?: string } } + | undefined; + if (rawData?.status !== 200) return false; + + const status = rawData?.data?.status; + if (!status) return false; + + // Poll every 2 seconds while running or in review + if ( + status === AgentExecutionStatus.RUNNING || + status === AgentExecutionStatus.QUEUED || + status === AgentExecutionStatus.INCOMPLETE || + status === AgentExecutionStatus.REVIEW + ) { + return 2000; + } + return false; + }, + refetchIntervalInBackground: true, }, }, ); @@ -40,28 +63,47 @@ export function FloatingReviewsPanel({ useShallow((state) => state.graphExecutionStatus), ); + // Determine if we should poll for pending reviews + const isInReviewStatus = + executionDetails?.status === AgentExecutionStatus.REVIEW || + graphExecutionStatus === AgentExecutionStatus.REVIEW; + const { pendingReviews, isLoading, refetch } = usePendingReviewsForExecution( executionId || "", + { + enabled: !!executionId, + // Poll every 2 seconds when in REVIEW status to catch new reviews + refetchInterval: isInReviewStatus ? 2000 : false, + }, ); + // Refetch pending reviews when execution status changes useEffect(() => { - if (executionId) { + if (executionId && executionDetails?.status) { refetch(); } }, [executionDetails?.status, executionId, refetch]); - // Refetch when graph execution status changes to REVIEW - useEffect(() => { - if (graphExecutionStatus === AgentExecutionStatus.REVIEW && executionId) { - refetch(); - } - }, [graphExecutionStatus, executionId, refetch]); + // Hide panel if: + // 1. No execution ID + // 2. No pending reviews and not in REVIEW status + // 3. Execution is RUNNING or QUEUED (hasn't paused for review yet) + if (!executionId) { + return null; + } if ( - !executionId || - (!isLoading && - pendingReviews.length === 0 && - executionDetails?.status !== AgentExecutionStatus.REVIEW) + !isLoading && + pendingReviews.length === 0 && + executionDetails?.status !== AgentExecutionStatus.REVIEW + ) { + return null; + } + + // Don't show panel while execution is still running/queued (not paused for review) + if ( + executionDetails?.status === AgentExecutionStatus.RUNNING || + executionDetails?.status === AgentExecutionStatus.QUEUED ) { return null; } diff --git a/autogpt_platform/frontend/src/components/organisms/PendingReviewCard/PendingReviewCard.tsx b/autogpt_platform/frontend/src/components/organisms/PendingReviewCard/PendingReviewCard.tsx index 3ac636060c..bd456ce771 100644 --- a/autogpt_platform/frontend/src/components/organisms/PendingReviewCard/PendingReviewCard.tsx +++ b/autogpt_platform/frontend/src/components/organisms/PendingReviewCard/PendingReviewCard.tsx @@ -1,10 +1,8 @@ import { PendingHumanReviewModel } from "@/app/api/__generated__/models/pendingHumanReviewModel"; import { Text } from "@/components/atoms/Text/Text"; -import { Button } from "@/components/atoms/Button/Button"; import { Input } from "@/components/atoms/Input/Input"; import { Switch } from "@/components/atoms/Switch/Switch"; -import { TrashIcon, EyeSlashIcon } from "@phosphor-icons/react"; -import { useState } from "react"; +import { useEffect, useState } from "react"; interface StructuredReviewPayload { data: unknown; @@ -40,37 +38,49 @@ function extractReviewData(payload: unknown): { interface PendingReviewCardProps { review: PendingHumanReviewModel; onReviewDataChange: (nodeExecId: string, data: string) => void; - reviewMessage?: string; - onReviewMessageChange?: (nodeExecId: string, message: string) => void; - isDisabled?: boolean; - onToggleDisabled?: (nodeExecId: string) => void; + autoApproveFuture?: boolean; + onAutoApproveFutureChange?: (nodeExecId: string, enabled: boolean) => void; + externalDataValue?: string; + showAutoApprove?: boolean; + nodeId?: string; } export function PendingReviewCard({ review, onReviewDataChange, - reviewMessage = "", - onReviewMessageChange, - isDisabled = false, - onToggleDisabled, + autoApproveFuture = false, + onAutoApproveFutureChange, + externalDataValue, + showAutoApprove = true, + nodeId, }: PendingReviewCardProps) { const extractedData = extractReviewData(review.payload); const isDataEditable = review.editable; - const instructions = extractedData.instructions || review.instructions; + + let instructions = review.instructions; + + const isHITLBlock = instructions && !instructions.includes("Block"); + + if (instructions && !isHITLBlock) { + instructions = undefined; + } + const [currentData, setCurrentData] = useState(extractedData.data); + useEffect(() => { + if (externalDataValue !== undefined) { + try { + const parsedData = JSON.parse(externalDataValue); + setCurrentData(parsedData); + } catch {} + } + }, [externalDataValue]); + const handleDataChange = (newValue: unknown) => { setCurrentData(newValue); onReviewDataChange(review.node_exec_id, JSON.stringify(newValue, null, 2)); }; - const handleMessageChange = (newMessage: string) => { - onReviewMessageChange?.(review.node_exec_id, newMessage); - }; - - // Show simplified view when no toggle functionality is provided (Screenshot 1 mode) - const showSimplified = !onToggleDisabled; - const renderDataInput = () => { const data = currentData; @@ -137,97 +147,59 @@ export function PendingReviewCard({ } }; - // Helper function to get proper field label - const getFieldLabel = (instructions?: string) => { - if (instructions) - return instructions.charAt(0).toUpperCase() + instructions.slice(1); - return "Data to Review"; + const getShortenedNodeId = (id: string) => { + if (id.length <= 8) return id; + return `${id.slice(0, 4)}...${id.slice(-4)}`; }; - // Use the existing HITL review interface return (
- {!showSimplified && ( -
-
- {isDisabled && ( - - This item will be rejected - - )} + {nodeId && ( + + Node #{getShortenedNodeId(nodeId)} + + )} + +
+ {instructions && ( + + {instructions} + + )} + + {isDataEditable && !autoApproveFuture ? ( + renderDataInput() + ) : ( +
+ + {JSON.stringify(currentData, null, 2)} +
- -
- )} + )} +
- {/* Show instructions as field label */} - {instructions && ( -
- - {getFieldLabel(instructions)} - - {isDataEditable && !isDisabled ? ( - renderDataInput() - ) : ( -
- - {JSON.stringify(currentData, null, 2)} - -
+ {/* Auto-approve toggle for this review */} + {showAutoApprove && onAutoApproveFutureChange && ( +
+
+ + onAutoApproveFutureChange(review.node_exec_id, enabled) + } + /> + + Auto-approve future executions of this block + +
+ {autoApproveFuture && ( + + Original data will be used for this and all future reviews from + this block. + )}
)} - - {/* If no instructions, show data directly */} - {!instructions && ( -
- - Data to Review - {!isDataEditable && ( - - (Read-only) - - )} - - {isDataEditable && !isDisabled ? ( - renderDataInput() - ) : ( -
- - {JSON.stringify(currentData, null, 2)} - -
- )} -
- )} - - {!showSimplified && isDisabled && ( -
- - Rejection Reason (Optional): - - handleMessageChange(e.target.value)} - placeholder="Add any notes about why you're rejecting this..." - /> -
- )}
); } diff --git a/autogpt_platform/frontend/src/components/organisms/PendingReviewsList/PendingReviewsList.tsx b/autogpt_platform/frontend/src/components/organisms/PendingReviewsList/PendingReviewsList.tsx index 3253b0ee6d..5adb3919b6 100644 --- a/autogpt_platform/frontend/src/components/organisms/PendingReviewsList/PendingReviewsList.tsx +++ b/autogpt_platform/frontend/src/components/organisms/PendingReviewsList/PendingReviewsList.tsx @@ -1,10 +1,16 @@ -import { useState } from "react"; +import { useMemo, useState } from "react"; import { PendingHumanReviewModel } from "@/app/api/__generated__/models/pendingHumanReviewModel"; import { PendingReviewCard } from "@/components/organisms/PendingReviewCard/PendingReviewCard"; import { Text } from "@/components/atoms/Text/Text"; import { Button } from "@/components/atoms/Button/Button"; +import { Switch } from "@/components/atoms/Switch/Switch"; import { useToast } from "@/components/molecules/Toast/use-toast"; -import { ClockIcon, WarningIcon } from "@phosphor-icons/react"; +import { + ClockIcon, + WarningIcon, + CaretDownIcon, + CaretRightIcon, +} from "@phosphor-icons/react"; import { usePostV2ProcessReviewAction } from "@/app/api/__generated__/endpoints/executions/executions"; interface PendingReviewsListProps { @@ -32,16 +38,34 @@ export function PendingReviewsList({ }, ); - const [reviewMessageMap, setReviewMessageMap] = useState< - Record - >({}); - const [pendingAction, setPendingAction] = useState< "approve" | "reject" | null >(null); + const [autoApproveFutureMap, setAutoApproveFutureMap] = useState< + Record + >({}); + + const [collapsedGroups, setCollapsedGroups] = useState< + Record + >({}); + const { toast } = useToast(); + const groupedReviews = useMemo(() => { + return reviews.reduce( + (acc, review) => { + const nodeId = review.node_id || "unknown"; + if (!acc[nodeId]) { + acc[nodeId] = []; + } + acc[nodeId].push(review); + return acc; + }, + {} as Record, + ); + }, [reviews]); + const reviewActionMutation = usePostV2ProcessReviewAction({ mutation: { onSuccess: (res) => { @@ -88,8 +112,33 @@ export function PendingReviewsList({ setReviewDataMap((prev) => ({ ...prev, [nodeExecId]: data })); } - function handleReviewMessageChange(nodeExecId: string, message: string) { - setReviewMessageMap((prev) => ({ ...prev, [nodeExecId]: message })); + function handleAutoApproveFutureToggle(nodeId: string, enabled: boolean) { + setAutoApproveFutureMap((prev) => ({ + ...prev, + [nodeId]: enabled, + })); + + if (enabled) { + const nodeReviews = groupedReviews[nodeId] || []; + setReviewDataMap((prev) => { + const updated = { ...prev }; + nodeReviews.forEach((review) => { + updated[review.node_exec_id] = JSON.stringify( + review.payload, + null, + 2, + ); + }); + return updated; + }); + } + } + + function toggleGroupCollapse(nodeId: string) { + setCollapsedGroups((prev) => ({ + ...prev, + [nodeId]: !prev[nodeId], + })); } function processReviews(approved: boolean) { @@ -107,22 +156,25 @@ export function PendingReviewsList({ for (const review of reviews) { const reviewData = reviewDataMap[review.node_exec_id]; - const reviewMessage = reviewMessageMap[review.node_exec_id]; + const autoApproveThisNode = autoApproveFutureMap[review.node_id || ""]; - let parsedData: any = review.payload; // Default to original payload + let parsedData: any = undefined; - // Parse edited data if available and editable - if (review.editable && reviewData) { - try { - parsedData = JSON.parse(reviewData); - } catch (error) { - toast({ - title: "Invalid JSON", - description: `Please fix the JSON format in review for node ${review.node_exec_id}: ${error instanceof Error ? error.message : "Invalid syntax"}`, - variant: "destructive", - }); - setPendingAction(null); - return; + if (!autoApproveThisNode) { + if (review.editable && reviewData) { + try { + parsedData = JSON.parse(reviewData); + } catch (error) { + toast({ + title: "Invalid JSON", + description: `Please fix the JSON format in review for node ${review.node_exec_id}: ${error instanceof Error ? error.message : "Invalid syntax"}`, + variant: "destructive", + }); + setPendingAction(null); + return; + } + } else { + parsedData = review.payload; } } @@ -130,7 +182,7 @@ export function PendingReviewsList({ node_exec_id: review.node_exec_id, approved, reviewed_data: parsedData, - message: reviewMessage || undefined, + auto_approve_future: autoApproveThisNode && approved, }); } @@ -158,7 +210,6 @@ export function PendingReviewsList({ return (
- {/* Warning Box Header */}
- {reviews.map((review) => ( - - ))} + {Object.entries(groupedReviews).map(([nodeId, nodeReviews]) => { + const isCollapsed = collapsedGroups[nodeId] ?? nodeReviews.length > 1; + const reviewCount = nodeReviews.length; + + const firstReview = nodeReviews[0]; + const blockName = firstReview?.instructions; + const reviewTitle = `Review required for ${blockName}`; + + const getShortenedNodeId = (id: string) => { + if (id.length <= 8) return id; + return `${id.slice(0, 4)}...${id.slice(-4)}`; + }; + + return ( +
+ + + {!isCollapsed && ( +
+ {nodeReviews.map((review) => ( + + ))} + +
+ + handleAutoApproveFutureToggle(nodeId, enabled) + } + /> + + Auto-approve future executions of this node + +
+
+ )} +
+ ); + })}
-
- - Note: Changes you make here apply only to this task - - -
+
+
+ + + You can turn auto-approval on or off using the toggle above for each + node. +
); diff --git a/autogpt_platform/frontend/src/hooks/usePendingReviews.ts b/autogpt_platform/frontend/src/hooks/usePendingReviews.ts index 8257814fcf..b9d7d711a1 100644 --- a/autogpt_platform/frontend/src/hooks/usePendingReviews.ts +++ b/autogpt_platform/frontend/src/hooks/usePendingReviews.ts @@ -15,8 +15,22 @@ export function usePendingReviews() { }; } -export function usePendingReviewsForExecution(graphExecId: string) { - const query = useGetV2GetPendingReviewsForExecution(graphExecId); +interface UsePendingReviewsForExecutionOptions { + enabled?: boolean; + refetchInterval?: number | false; +} + +export function usePendingReviewsForExecution( + graphExecId: string, + options?: UsePendingReviewsForExecutionOptions, +) { + const query = useGetV2GetPendingReviewsForExecution(graphExecId, { + query: { + enabled: options?.enabled ?? !!graphExecId, + refetchInterval: options?.refetchInterval, + refetchIntervalInBackground: !!options?.refetchInterval, + }, + }); return { pendingReviews: okData(query.data) || [], diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts index 82c03bc9f1..74855f5e28 100644 --- a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts @@ -516,7 +516,7 @@ export type GraphValidationErrorResponse = { /* *** LIBRARY *** */ -/* Mirror of backend/server/v2/library/model.py:LibraryAgent */ +/* Mirror of backend/api/features/library/model.py:LibraryAgent */ export type LibraryAgent = { id: LibraryAgentID; graph_id: GraphID; @@ -616,7 +616,7 @@ export enum LibraryAgentSortEnum { /* *** CREDENTIALS *** */ -/* Mirror of backend/server/integrations/router.py:CredentialsMetaResponse */ +/* Mirror of backend/api/features/integrations/router.py:CredentialsMetaResponse */ export type CredentialsMetaResponse = { id: string; provider: CredentialsProviderName; @@ -628,13 +628,13 @@ export type CredentialsMetaResponse = { is_system?: boolean; }; -/* Mirror of backend/server/integrations/router.py:CredentialsDeletionResponse */ +/* Mirror of backend/api/features/integrations/router.py:CredentialsDeletionResponse */ export type CredentialsDeleteResponse = { deleted: true; revoked: boolean | null; }; -/* Mirror of backend/server/integrations/router.py:CredentialsDeletionNeedsConfirmationResponse */ +/* Mirror of backend/api/features/integrations/router.py:CredentialsDeletionNeedsConfirmationResponse */ export type CredentialsDeleteNeedConfirmationResponse = { deleted: false; need_confirmation: true; @@ -888,7 +888,7 @@ export type Schedule = { export type ScheduleID = Brand; -/* Mirror of backend/server/routers/v1.py:ScheduleCreationRequest */ +/* Mirror of backend/api/features/v1.py:ScheduleCreationRequest */ export type ScheduleCreatable = { graph_id: GraphID; graph_version: number; @@ -1003,6 +1003,7 @@ export type OnboardingStep = | "AGENT_INPUT" | "CONGRATS" // First Wins + | "VISIT_COPILOT" | "GET_RESULTS" | "MARKETPLACE_VISIT" | "MARKETPLACE_ADD_AGENT" diff --git a/autogpt_platform/frontend/src/providers/posthog/posthog-provider.tsx b/autogpt_platform/frontend/src/providers/posthog/posthog-provider.tsx new file mode 100644 index 0000000000..674f6c55eb --- /dev/null +++ b/autogpt_platform/frontend/src/providers/posthog/posthog-provider.tsx @@ -0,0 +1,72 @@ +"use client"; + +import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; +import { environment } from "@/services/environment"; +import { PostHogProvider as PHProvider } from "@posthog/react"; +import { usePathname, useSearchParams } from "next/navigation"; +import posthog from "posthog-js"; +import { ReactNode, useEffect, useRef } from "react"; + +export function PostHogProvider({ children }: { children: ReactNode }) { + const isPostHogEnabled = environment.isPostHogEnabled(); + const postHogCredentials = environment.getPostHogCredentials(); + + useEffect(() => { + if (postHogCredentials.key) { + posthog.init(postHogCredentials.key, { + api_host: postHogCredentials.host, + defaults: "2025-11-30", + capture_pageview: false, + capture_pageleave: true, + autocapture: true, + }); + } + }, []); + + if (!isPostHogEnabled) return <>{children}; + + return {children}; +} + +export function PostHogUserTracker() { + const { user, isUserLoading } = useSupabase(); + const previousUserIdRef = useRef(null); + const isPostHogEnabled = environment.isPostHogEnabled(); + + useEffect(() => { + if (isUserLoading || !isPostHogEnabled) return; + + if (user) { + if (previousUserIdRef.current !== user.id) { + posthog.identify(user.id, { + email: user.email, + ...(user.user_metadata?.name && { name: user.user_metadata.name }), + }); + previousUserIdRef.current = user.id; + } + } else if (previousUserIdRef.current !== null) { + posthog.reset(); + previousUserIdRef.current = null; + } + }, [user, isUserLoading, isPostHogEnabled]); + + return null; +} + +export function PostHogPageViewTracker() { + const pathname = usePathname(); + const searchParams = useSearchParams(); + const isPostHogEnabled = environment.isPostHogEnabled(); + + useEffect(() => { + if (pathname && isPostHogEnabled) { + let url = window.origin + pathname; + if (searchParams && searchParams.toString()) { + url = url + `?${searchParams.toString()}`; + } + posthog.capture("$pageview", { $current_url: url }); + } + }, [pathname, searchParams, isPostHogEnabled]); + + return null; +} diff --git a/autogpt_platform/frontend/src/services/environment/index.ts b/autogpt_platform/frontend/src/services/environment/index.ts index cdd5b421b5..f19bc417e3 100644 --- a/autogpt_platform/frontend/src/services/environment/index.ts +++ b/autogpt_platform/frontend/src/services/environment/index.ts @@ -76,6 +76,13 @@ function getPreviewStealingDev() { return branch; } +function getPostHogCredentials() { + return { + key: process.env.NEXT_PUBLIC_POSTHOG_KEY, + host: process.env.NEXT_PUBLIC_POSTHOG_HOST, + }; +} + function isProductionBuild() { return process.env.NODE_ENV === "production"; } @@ -116,6 +123,13 @@ function areFeatureFlagsEnabled() { return process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "enabled"; } +function isPostHogEnabled() { + const inCloud = isCloud(); + const key = process.env.NEXT_PUBLIC_POSTHOG_KEY; + const host = process.env.NEXT_PUBLIC_POSTHOG_HOST; + return inCloud && key && host; +} + export const environment = { // Generic getEnvironmentStr, @@ -128,6 +142,7 @@ export const environment = { getSupabaseUrl, getSupabaseAnonKey, getPreviewStealingDev, + getPostHogCredentials, // Assertions isServerSide, isClientSide, @@ -138,5 +153,6 @@ export const environment = { isCloud, isLocal, isVercelPreview, + isPostHogEnabled, areFeatureFlagsEnabled, }; diff --git a/autogpt_platform/frontend/src/services/network-status/NetworkStatusMonitor.tsx b/autogpt_platform/frontend/src/services/network-status/NetworkStatusMonitor.tsx new file mode 100644 index 0000000000..7552bbf78c --- /dev/null +++ b/autogpt_platform/frontend/src/services/network-status/NetworkStatusMonitor.tsx @@ -0,0 +1,8 @@ +"use client"; + +import { useNetworkStatus } from "./useNetworkStatus"; + +export function NetworkStatusMonitor() { + useNetworkStatus(); + return null; +} diff --git a/autogpt_platform/frontend/src/services/network-status/useNetworkStatus.ts b/autogpt_platform/frontend/src/services/network-status/useNetworkStatus.ts new file mode 100644 index 0000000000..472a6e0e90 --- /dev/null +++ b/autogpt_platform/frontend/src/services/network-status/useNetworkStatus.ts @@ -0,0 +1,28 @@ +"use client"; + +import { useEffect } from "react"; +import { toast } from "sonner"; + +export function useNetworkStatus() { + useEffect(function monitorNetworkStatus() { + function handleOnline() { + toast.success("Connection restored", { + description: "You're back online", + }); + } + + function handleOffline() { + toast.error("You're offline", { + description: "Check your internet connection", + }); + } + + window.addEventListener("online", handleOnline); + window.addEventListener("offline", handleOffline); + + return function cleanup() { + window.removeEventListener("online", handleOnline); + window.removeEventListener("offline", handleOffline); + }; + }, []); +} diff --git a/autogpt_platform/frontend/src/services/storage/local-storage.ts b/autogpt_platform/frontend/src/services/storage/local-storage.ts index 494ddc3ccc..a1aa63741a 100644 --- a/autogpt_platform/frontend/src/services/storage/local-storage.ts +++ b/autogpt_platform/frontend/src/services/storage/local-storage.ts @@ -10,6 +10,7 @@ export enum Key { LIBRARY_AGENTS_CACHE = "library-agents-cache", CHAT_SESSION_ID = "chat_session_id", COOKIE_CONSENT = "autogpt_cookie_consent", + AI_AGENT_SAFETY_POPUP_SHOWN = "ai-agent-safety-popup-shown", } function get(key: Key) { diff --git a/autogpt_platform/frontend/src/services/storage/session-storage.ts b/autogpt_platform/frontend/src/services/storage/session-storage.ts index 8404da571c..1be82c98fb 100644 --- a/autogpt_platform/frontend/src/services/storage/session-storage.ts +++ b/autogpt_platform/frontend/src/services/storage/session-storage.ts @@ -3,6 +3,7 @@ import { environment } from "../environment"; export enum SessionKey { CHAT_SENT_INITIAL_PROMPTS = "chat_sent_initial_prompts", + CHAT_INITIAL_PROMPTS = "chat_initial_prompts", } function get(key: SessionKey) { diff --git a/autogpt_platform/frontend/src/tests/pages/login.page.ts b/autogpt_platform/frontend/src/tests/pages/login.page.ts index 9082cc6219..adcb8d908b 100644 --- a/autogpt_platform/frontend/src/tests/pages/login.page.ts +++ b/autogpt_platform/frontend/src/tests/pages/login.page.ts @@ -37,9 +37,13 @@ export class LoginPage { this.page.on("load", (page) => console.log(`ℹ️ Now at URL: ${page.url()}`)); // Start waiting for navigation before clicking + // Wait for redirect to marketplace, onboarding, library, or copilot (new landing pages) const leaveLoginPage = this.page .waitForURL( - (url) => /^\/(marketplace|onboarding(\/.*)?)?$/.test(url.pathname), + (url: URL) => + /^\/(marketplace|onboarding(\/.*)?|library|copilot)?$/.test( + url.pathname, + ), { timeout: 10_000 }, ) .catch((reason) => { diff --git a/autogpt_platform/frontend/src/tests/utils/signup.ts b/autogpt_platform/frontend/src/tests/utils/signup.ts index 7c8fdbe01b..192a9129b9 100644 --- a/autogpt_platform/frontend/src/tests/utils/signup.ts +++ b/autogpt_platform/frontend/src/tests/utils/signup.ts @@ -36,14 +36,16 @@ export async function signupTestUser( const signupButton = getButton("Sign up"); await signupButton.click(); - // Wait for successful signup - could redirect to onboarding or marketplace + // Wait for successful signup - could redirect to various pages depending on onboarding state try { - // Wait for either onboarding or marketplace redirect - await Promise.race([ - page.waitForURL(/\/onboarding/, { timeout: 15000 }), - page.waitForURL(/\/marketplace/, { timeout: 15000 }), - ]); + // Wait for redirect to onboarding, marketplace, copilot, or library + // Use a single waitForURL with a callback to avoid Promise.race race conditions + await page.waitForURL( + (url: URL) => + /\/(onboarding|marketplace|copilot|library)/.test(url.pathname), + { timeout: 15000 }, + ); } catch (error) { console.error( "❌ Timeout waiting for redirect, current URL:", @@ -54,14 +56,19 @@ export async function signupTestUser( const currentUrl = page.url(); - // Handle onboarding or marketplace redirect + // Handle onboarding redirect if needed if (currentUrl.includes("/onboarding") && ignoreOnboarding) { await page.goto("http://localhost:3000/marketplace"); await page.waitForLoadState("domcontentloaded", { timeout: 10000 }); } - // Verify we're on the expected final page - if (ignoreOnboarding || currentUrl.includes("/marketplace")) { + // Verify we're on an expected final page and user is authenticated + if (currentUrl.includes("/copilot") || currentUrl.includes("/library")) { + // For copilot/library landing pages, just verify user is authenticated + await page + .getByTestId("profile-popout-menu-trigger") + .waitFor({ state: "visible", timeout: 10000 }); + } else if (ignoreOnboarding || currentUrl.includes("/marketplace")) { // Verify we're on marketplace await page .getByText( diff --git a/docs/integrations/README.md b/docs/integrations/README.md index 023e4cbb45..7c0d0f474a 100644 --- a/docs/integrations/README.md +++ b/docs/integrations/README.md @@ -53,7 +53,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim | [Block Installation](block-integrations/basic.md#block-installation) | Given a code string, this block allows the verification and installation of a block code into the system | | [Concatenate Lists](block-integrations/basic.md#concatenate-lists) | Concatenates multiple lists into a single list | | [Dictionary Is Empty](block-integrations/basic.md#dictionary-is-empty) | Checks if a dictionary is empty | -| [File Store](block-integrations/basic.md#file-store) | Stores the input file in the temporary directory | +| [File Store](block-integrations/basic.md#file-store) | Downloads and stores a file from a URL, data URI, or local path | | [Find In Dictionary](block-integrations/basic.md#find-in-dictionary) | A block that looks up a value in a dictionary, list, or object by key or index and returns the corresponding value | | [Find In List](block-integrations/basic.md#find-in-list) | Finds the index of the value in the list | | [Get All Memories](block-integrations/basic.md#get-all-memories) | Retrieve all memories from Mem0 with optional conversation filtering | diff --git a/docs/integrations/block-integrations/basic.md b/docs/integrations/block-integrations/basic.md index f92d19002f..5a73fd5a03 100644 --- a/docs/integrations/block-integrations/basic.md +++ b/docs/integrations/block-integrations/basic.md @@ -709,7 +709,7 @@ This is useful for conditional logic where you need to verify if data was return ## File Store ### What it is -Stores the input file in the temporary directory. +Downloads and stores a file from a URL, data URI, or local path. Use this to fetch images, documents, or other files for processing. In CoPilot: saves to workspace (use list_workspace_files to see it). In graphs: outputs a data URI to pass to other blocks. ### How it works @@ -722,15 +722,15 @@ The block outputs a file path that other blocks can use to access the stored fil | Input | Description | Type | Required | |-------|-------------|------|----------| -| file_in | The file to store in the temporary directory, it can be a URL, data URI, or local path. | str (file) | Yes | -| base_64 | Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks). | bool | No | +| file_in | The file to download and store. Can be a URL (https://...), data URI, or local path. | str (file) | Yes | +| base_64 | Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks). | bool | No | ### Outputs | Output | Description | Type | |--------|-------------|------| | error | Error message if the operation failed | str | -| file_out | The relative path to the stored file in the temporary directory. | str (file) | +| file_out | Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks. | str (file) | ### Possible use case diff --git a/docs/integrations/block-integrations/multimedia.md b/docs/integrations/block-integrations/multimedia.md new file mode 100644 index 0000000000..6b8f261346 --- /dev/null +++ b/docs/integrations/block-integrations/multimedia.md @@ -0,0 +1,117 @@ +# Multimedia + +Blocks for processing and manipulating video and audio files. + + +## Add Audio To Video + +### What it is +Block to attach an audio file to a video file using moviepy. + +### How it works + +This block combines a video file with an audio file using the moviepy library. The audio track is attached to the video, optionally with volume adjustment via the volume parameter (1.0 = original volume). + +Input files can be URLs, data URIs, or local paths. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| video_in | Video input (URL, data URI, or local path). | str (file) | Yes | +| audio_in | Audio input (URL, data URI, or local path). | str (file) | Yes | +| volume | Volume scale for the newly attached audio track (1.0 = original). | float | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the operation failed | str | +| video_out | Final video (with attached audio), as a path or data URI. | str (file) | + +### Possible use case + +**Add Voiceover**: Combine generated voiceover audio with video content for narrated videos. + +**Background Music**: Add music tracks to silent videos or replace existing audio. + +**Audio Replacement**: Swap the audio track of a video for localization or accessibility. + + +--- + +## Loop Video + +### What it is +Block to loop a video to a given duration or number of repeats. + +### How it works + +This block extends a video by repeating it to reach a target duration or number of loops. Set duration to specify the total length in seconds, or use n_loops to repeat the video a specific number of times. + +The looped video is seamlessly concatenated. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| video_in | The input video (can be a URL, data URI, or local path). | str (file) | Yes | +| duration | Target duration (in seconds) to loop the video to. If omitted, defaults to no looping. | float | No | +| n_loops | Number of times to repeat the video. If omitted, defaults to 1 (no repeat). | int | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the operation failed | str | +| video_out | Looped video returned either as a relative path or a data URI. | str | + +### Possible use case + +**Background Videos**: Loop short clips to match the duration of longer audio or content. + +**GIF-Like Content**: Create seamlessly looping video content for social media. + +**Filler Content**: Extend short video clips to meet minimum duration requirements. + + +--- + +## Media Duration + +### What it is +Block to get the duration of a media file. + +### How it works + +This block analyzes a media file and returns its duration in seconds. Set is_video to true for video files or false for audio files to ensure proper parsing. + +The input can be a URL, data URI, or local file path. The duration is returned as a float for precise timing calculations. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| media_in | Media input (URL, data URI, or local path). | str (file) | Yes | +| is_video | Whether the media is a video (True) or audio (False). | bool | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the operation failed | str | +| duration | Duration of the media file (in seconds). | float | + +### Possible use case + +**Video Processing Prep**: Get video duration before deciding how to loop, trim, or synchronize it. + +**Audio Matching**: Determine audio length to generate matching-length video content. + +**Content Validation**: Verify that uploaded media meets duration requirements. + + +--- diff --git a/docs/platform/block-sdk-guide.md b/docs/platform/block-sdk-guide.md index 5b3eda5184..42fd883251 100644 --- a/docs/platform/block-sdk-guide.md +++ b/docs/platform/block-sdk-guide.md @@ -277,6 +277,50 @@ async def run( token = credentials.api_key.get_secret_value() ``` +### Handling Files + +When your block works with files (images, videos, documents), use `store_media_file()`: + +```python +from backend.data.execution import ExecutionContext +from backend.util.file import store_media_file +from backend.util.type import MediaFileType + +async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + **kwargs, +): + # PROCESSING: Need local file path for tools like ffmpeg, MoviePy, PIL + local_path = await store_media_file( + file=input_data.video, + execution_context=execution_context, + return_format="for_local_processing", + ) + + # EXTERNAL API: Need base64 content for APIs like Replicate, OpenAI + image_b64 = await store_media_file( + file=input_data.image, + execution_context=execution_context, + return_format="for_external_api", + ) + + # OUTPUT: Return to user/next block (auto-adapts to context) + result = await store_media_file( + file=generated_url, + execution_context=execution_context, + return_format="for_block_output", # workspace:// in CoPilot, data URI in graphs + ) + yield "image_url", result +``` + +**Return format options:** +- `"for_local_processing"` - Local file path for processing tools +- `"for_external_api"` - Data URI for external APIs needing base64 +- `"for_block_output"` - **Always use for outputs** - automatically picks best format + ## Testing Your Block ```bash diff --git a/docs/platform/contributing/oauth-integration-flow.md b/docs/platform/contributing/oauth-integration-flow.md index dbc7a54be5..f6c3f7fd17 100644 --- a/docs/platform/contributing/oauth-integration-flow.md +++ b/docs/platform/contributing/oauth-integration-flow.md @@ -25,7 +25,7 @@ This document focuses on the **API Integration OAuth flow** used for connecting ### 2. Backend API Trust Boundary - **Location**: Server-side FastAPI application - **Components**: - - Integration router (`/backend/backend/server/integrations/router.py`) + - Integration router (`/backend/backend/api/features/integrations/router.py`) - OAuth handlers (`/backend/backend/integrations/oauth/`) - Credentials store (`/backend/backend/integrations/credentials_store.py`) - **Trust Level**: Trusted - server-controlled environment diff --git a/docs/platform/new_blocks.md b/docs/platform/new_blocks.md index d9d329ff51..114ff8d9a4 100644 --- a/docs/platform/new_blocks.md +++ b/docs/platform/new_blocks.md @@ -111,6 +111,71 @@ Follow these steps to create and test a new block: - `graph_exec_id`: The ID of the execution of the agent. This changes every time the agent has a new "run" - `node_exec_id`: The ID of the execution of the node. This changes every time the node is executed - `node_id`: The ID of the node that is being executed. It changes every version of the graph, but not every time the node is executed. + - `execution_context`: An `ExecutionContext` object containing user_id, graph_exec_id, workspace_id, and session_id. Required for file handling. + +### Handling Files in Blocks + +When your block needs to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. This function handles downloading, validation, virus scanning, and storage. + +**Import:** +```python +from backend.data.execution import ExecutionContext +from backend.util.file import store_media_file +from backend.util.type import MediaFileType +``` + +**The `return_format` parameter determines what you get back:** + +| Format | Use When | Returns | +|--------|----------|---------| +| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) | +| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) | +| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs | + +**Examples:** + +```python +async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + **kwargs, +) -> BlockOutput: + # PROCESSING: Need to work with file locally (ffmpeg, MoviePy, PIL) + local_path = await store_media_file( + file=input_data.video, + execution_context=execution_context, + return_format="for_local_processing", + ) + # local_path = "video.mp4" - use with Path, ffmpeg, subprocess, etc. + full_path = get_exec_file_path(execution_context.graph_exec_id, local_path) + + # EXTERNAL API: Need to send content to an API like Replicate + image_b64 = await store_media_file( + file=input_data.image, + execution_context=execution_context, + return_format="for_external_api", + ) + # image_b64 = "data:image/png;base64,iVBORw0..." - send to external API + + # OUTPUT: Returning result from block to user/next block + result_url = await store_media_file( + file=generated_image_url, + execution_context=execution_context, + return_format="for_block_output", + ) + yield "image_url", result_url + # In CoPilot: result_url = "workspace://abc123" (persistent, context-efficient) + # In graphs: result_url = "data:image/png;base64,..." (for next block/display) +``` + +**Key points:** + +- `for_block_output` is the **only** format that auto-adapts to execution context +- Always use `for_block_output` for block outputs unless you have a specific reason not to +- Never manually check for `workspace_id` - let `for_block_output` handle the logic +- The function handles URLs, data URIs, `workspace://` references, and local paths as input ### Field Types diff --git a/docs/platform/ollama.md b/docs/platform/ollama.md index 392bfabfe8..ecab9b8ae1 100644 --- a/docs/platform/ollama.md +++ b/docs/platform/ollama.md @@ -246,7 +246,7 @@ If you encounter any issues, verify that: ```bash ollama pull llama3.2 ``` -- If using a custom model, ensure it's added to the model list in `backend/server/model.py` +- If using a custom model, ensure it's added to the model list in `backend/api/model.py` #### Docker Issues - Ensure Docker daemon is running: