diff --git a/.branchlet.json b/.branchlet.json index cc13ff9f74..d02cd60e20 100644 --- a/.branchlet.json +++ b/.branchlet.json @@ -29,8 +29,7 @@ "postCreateCmd": [ "cd autogpt_platform/autogpt_libs && poetry install", "cd autogpt_platform/backend && poetry install && poetry run prisma generate", - "cd autogpt_platform/frontend && pnpm install", - "cd docs && pip install -r requirements.txt" + "cd autogpt_platform/frontend && pnpm install" ], "terminalCommand": "code .", "deleteBranchWithWorktree": false diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 870e6b4b0a..3c72eaae18 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -160,7 +160,7 @@ pnpm storybook # Start component development server **Backend Entry Points:** -- `backend/backend/server/server.py` - FastAPI application setup +- `backend/backend/api/rest_api.py` - FastAPI application setup - `backend/backend/data/` - Database models and user management - `backend/blocks/` - Agent execution blocks and logic @@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s ### API Development -1. Update routes in `/backend/backend/server/routers/` +1. Update routes in `/backend/backend/api/features/` 2. Add/update Pydantic models in same directory 3. Write tests alongside route files 4. For `data/*.py` changes, validate user ID checks @@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s ### Security Guidelines -**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`): +**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`): - Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private` - Uses allow list approach for cacheable paths (static assets, health checks, public pages) diff --git a/.gitignore b/.gitignore index dfce8ba810..012a0b5227 100644 --- a/.gitignore +++ b/.gitignore @@ -178,4 +178,6 @@ autogpt_platform/backend/settings.py *.ign.* .test-contents .claude/settings.local.json +CLAUDE.local.md /autogpt_platform/backend/logs +.next \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index cd176f8a2d..202c4c6e02 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -16,7 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions. - Format Python code with `poetry run format`. - Format frontend code using `pnpm format`. - ## Frontend guidelines: See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: @@ -33,14 +32,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: 4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only 5. **Testing**: Add Storybook stories for new components, Playwright for E2E 6. **Code conventions**: Function declarations (not arrow functions) for components/handlers + - Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component - Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts) - Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible - Avoid large hooks, abstract logic into `helpers.ts` files when sensible - Use function declarations for components, arrow functions only for callbacks - No barrel files or `index.ts` re-exports -- Do not use `useCallback` or `useMemo` unless strictly needed - Avoid comments at all times unless the code is very complex +- Do not use `useCallback` or `useMemo` unless asked to optimise a given function +- Do not type hook returns, let Typescript infer as much as possible +- Never type with `any`, if not types available use `unknown` ## Testing @@ -49,22 +51,8 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: Always run the relevant linters and tests before committing. Use conventional commit messages for all commits (e.g. `feat(backend): add API`). - Types: - - feat - - fix - - refactor - - ci - - dx (developer experience) - Scopes: - - platform - - platform/library - - platform/marketplace - - backend - - backend/executor - - frontend - - frontend/library - - frontend/marketplace - - blocks +Types: - feat - fix - refactor - ci - dx (developer experience) +Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks ## Pull requests diff --git a/README.md b/README.md index 3572fe318b..349d8818ef 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following ### Updated Setup Instructions: We've moved to a fully maintained and regularly updated documentation site. -👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/) +👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started) This tutorial assumes you have Docker, VSCode, git and npm installed. diff --git a/autogpt_platform/CLAUDE.md b/autogpt_platform/CLAUDE.md index 2c76e7db80..62adbdaefa 100644 --- a/autogpt_platform/CLAUDE.md +++ b/autogpt_platform/CLAUDE.md @@ -6,152 +6,30 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co AutoGPT Platform is a monorepo containing: -- **Backend** (`/backend`): Python FastAPI server with async support -- **Frontend** (`/frontend`): Next.js React application -- **Shared Libraries** (`/autogpt_libs`): Common Python utilities +- **Backend** (`backend`): Python FastAPI server with async support +- **Frontend** (`frontend`): Next.js React application +- **Shared Libraries** (`autogpt_libs`): Common Python utilities -## Essential Commands +## Component Documentation -### Backend Development +- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks +- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns -```bash -# Install dependencies -cd backend && poetry install - -# Run database migrations -poetry run prisma migrate dev - -# Start all services (database, redis, rabbitmq, clamav) -docker compose up -d - -# Run the backend server -poetry run serve - -# Run tests -poetry run test - -# Run specific test -poetry run pytest path/to/test_file.py::test_function_name - -# Run block tests (tests that validate all blocks work correctly) -poetry run pytest backend/blocks/test/test_block.py -xvs - -# Run tests for a specific block (e.g., GetCurrentTimeBlock) -poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs - -# Lint and format -# prefer format if you want to just "fix" it and only get the errors that can't be autofixed -poetry run format # Black + isort -poetry run lint # ruff -``` - -More details can be found in TESTING.md - -#### Creating/Updating Snapshots - -When you first write a test or when the expected output changes: - -```bash -poetry run pytest path/to/test.py --snapshot-update -``` - -⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected. - -### Frontend Development - -```bash -# Install dependencies -cd frontend && pnpm i - -# Generate API client from OpenAPI spec -pnpm generate:api - -# Start development server -pnpm dev - -# Run E2E tests -pnpm test - -# Run Storybook for component development -pnpm storybook - -# Build production -pnpm build - -# Format and lint -pnpm format - -# Type checking -pnpm types -``` - -**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns. - -**Key Frontend Conventions:** - -- Separate render logic from data/behavior in components -- Use generated API hooks from `@/app/api/__generated__/endpoints/` -- Use function declarations (not arrow functions) for components/handlers -- Use design system components from `src/components/` (atoms, molecules, organisms) -- Only use Phosphor Icons -- Never use `src/components/__legacy__/*` or deprecated `BackendAPI` - -## Architecture Overview - -### Backend Architecture - -- **API Layer**: FastAPI with REST and WebSocket endpoints -- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings -- **Queue System**: RabbitMQ for async task processing -- **Execution Engine**: Separate executor service processes agent workflows -- **Authentication**: JWT-based with Supabase integration -- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies - -### Frontend Architecture - -- **Framework**: Next.js 15 App Router (client-first approach) -- **Data Fetching**: Type-safe generated API hooks via Orval + React Query -- **State Management**: React Query for server state, co-located UI state in components/hooks -- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks) -- **Workflow Builder**: Visual graph editor using @xyflow/react -- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling -- **Icons**: Phosphor Icons only -- **Feature Flags**: LaunchDarkly integration -- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions -- **Testing**: Playwright for E2E, Storybook for component development - -### Key Concepts +## Key Concepts 1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend -2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks +2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks 3. **Integrations**: OAuth and API connections stored per user 4. **Store**: Marketplace for sharing agent templates 5. **Virus Scanning**: ClamAV integration for file upload security -### Testing Approach - -- Backend uses pytest with snapshot testing for API responses -- Test files are colocated with source files (`*_test.py`) -- Frontend uses Playwright for E2E tests -- Component testing via Storybook - -### Database Schema - -Key models (defined in `/backend/schema.prisma`): - -- `User`: Authentication and profile data -- `AgentGraph`: Workflow definitions with version control -- `AgentGraphExecution`: Execution history and results -- `AgentNode`: Individual nodes in a workflow -- `StoreListing`: Marketplace listings for sharing agents - ### Environment Configuration #### Configuration Files -- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides) -- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides) -- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides) +- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides) +- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides) +- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides) #### Docker Environment Loading Order @@ -167,83 +45,12 @@ Key models (defined in `/backend/schema.prisma`): - Backend/Frontend services use YAML anchors for consistent configuration - Supabase services (`db/docker/docker-compose.yml`) follow the same pattern -### Common Development Tasks - -**Adding a new block:** - -Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers: - -- Provider configuration with `ProviderBuilder` -- Block schema definition -- Authentication (API keys, OAuth, webhooks) -- Testing and validation -- File organization - -Quick steps: - -1. Create new file in `/backend/backend/blocks/` -2. Configure provider using `ProviderBuilder` in `_config.py` -3. Inherit from `Block` base class -4. Define input/output schemas using `BlockSchema` -5. Implement async `run` method -6. Generate unique block ID using `uuid.uuid4()` -7. Test with `poetry run pytest backend/blocks/test/test_block.py` - -Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively? -ex: do the inputs and outputs tie well together? - -If you get any pushback or hit complex block conditions check the new_blocks guide in the docs. - -**Modifying the API:** - -1. Update route in `/backend/backend/server/routers/` -2. Add/update Pydantic models in same directory -3. Write tests alongside the route file -4. Run `poetry run test` to verify - -### Frontend guidelines: - -See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference: - -1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx` - - Add `usePageName.ts` hook for logic - - Put sub-components in local `components/` folder -2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts` - - Use design system components from `src/components/` (atoms, molecules, organisms) - - Never use `src/components/__legacy__/*` -3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/` - - Regenerate with `pnpm generate:api` - - Pattern: `use{Method}{Version}{OperationName}` -4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only -5. **Testing**: Add Storybook stories for new components, Playwright for E2E -6. **Code conventions**: Function declarations (not arrow functions) for components/handlers -- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component -- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts) -- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible -- Avoid large hooks, abstract logic into `helpers.ts` files when sensible -- Use function declarations for components, arrow functions only for callbacks -- No barrel files or `index.ts` re-exports -- Do not use `useCallback` or `useMemo` unless strictly needed -- Avoid comments at all times unless the code is very complex - -### Security Implementation - -**Cache Protection Middleware:** - -- Located in `/backend/backend/server/middleware/security.py` -- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private` -- Uses an allow list approach - only explicitly permitted paths can be cached -- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation -- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies -- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware -- Applied to both main API server and external API applications - ### Creating Pull Requests -- Create the PR aginst the `dev` branch of the repository. -- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/ -- Use conventional commit messages (see below)/ -- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/ +- Create the PR against the `dev` branch of the repository. +- Ensure the branch name is descriptive (e.g., `feature/add-new-block`) +- Use conventional commit messages (see below) +- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description - Run the github pre-commit hooks to ensure code quality. ### Reviewing/Revising Pull Requests diff --git a/autogpt_platform/backend/CLAUDE.md b/autogpt_platform/backend/CLAUDE.md new file mode 100644 index 0000000000..53d52bb4d3 --- /dev/null +++ b/autogpt_platform/backend/CLAUDE.md @@ -0,0 +1,170 @@ +# CLAUDE.md - Backend + +This file provides guidance to Claude Code when working with the backend. + +## Essential Commands + +To run something with Python package dependencies you MUST use `poetry run ...`. + +```bash +# Install dependencies +poetry install + +# Run database migrations +poetry run prisma migrate dev + +# Start all services (database, redis, rabbitmq, clamav) +docker compose up -d + +# Run the backend as a whole +poetry run app + +# Run tests +poetry run test + +# Run specific test +poetry run pytest path/to/test_file.py::test_function_name + +# Run block tests (tests that validate all blocks work correctly) +poetry run pytest backend/blocks/test/test_block.py -xvs + +# Run tests for a specific block (e.g., GetCurrentTimeBlock) +poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs + +# Lint and format +# prefer format if you want to just "fix" it and only get the errors that can't be autofixed +poetry run format # Black + isort +poetry run lint # ruff +``` + +More details can be found in @TESTING.md + +### Creating/Updating Snapshots + +When you first write a test or when the expected output changes: + +```bash +poetry run pytest path/to/test.py --snapshot-update +``` + +⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected. + +## Architecture + +- **API Layer**: FastAPI with REST and WebSocket endpoints +- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings +- **Queue System**: RabbitMQ for async task processing +- **Execution Engine**: Separate executor service processes agent workflows +- **Authentication**: JWT-based with Supabase integration +- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies + +## Testing Approach + +- Uses pytest with snapshot testing for API responses +- Test files are colocated with source files (`*_test.py`) + +## Database Schema + +Key models (defined in `schema.prisma`): + +- `User`: Authentication and profile data +- `AgentGraph`: Workflow definitions with version control +- `AgentGraphExecution`: Execution history and results +- `AgentNode`: Individual nodes in a workflow +- `StoreListing`: Marketplace listings for sharing agents + +## Environment Configuration + +- **Backend**: `.env.default` (defaults) → `.env` (user overrides) + +## Common Development Tasks + +### Adding a new block + +Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers: + +- Provider configuration with `ProviderBuilder` +- Block schema definition +- Authentication (API keys, OAuth, webhooks) +- Testing and validation +- File organization + +Quick steps: + +1. Create new file in `backend/blocks/` +2. Configure provider using `ProviderBuilder` in `_config.py` +3. Inherit from `Block` base class +4. Define input/output schemas using `BlockSchema` +5. Implement async `run` method +6. Generate unique block ID using `uuid.uuid4()` +7. Test with `poetry run pytest backend/blocks/test/test_block.py` + +Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively? +ex: do the inputs and outputs tie well together? + +If you get any pushback or hit complex block conditions check the new_blocks guide in the docs. + +#### Handling files in blocks with `store_media_file()` + +When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back: + +| Format | Use When | Returns | +|--------|----------|---------| +| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) | +| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) | +| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs | + +**Examples:** + +```python +# INPUT: Need to process file locally with ffmpeg +local_path = await store_media_file( + file=input_data.video, + execution_context=execution_context, + return_format="for_local_processing", +) +# local_path = "video.mp4" - use with Path/ffmpeg/etc + +# INPUT: Need to send to external API like Replicate +image_b64 = await store_media_file( + file=input_data.image, + execution_context=execution_context, + return_format="for_external_api", +) +# image_b64 = "..." - send to API + +# OUTPUT: Returning result from block +result_url = await store_media_file( + file=generated_image_url, + execution_context=execution_context, + return_format="for_block_output", +) +yield "image_url", result_url +# In CoPilot: result_url = "workspace://abc123" +# In graphs: result_url = "data:image/png;base64,..." +``` + +**Key points:** + +- `for_block_output` is the ONLY format that auto-adapts to execution context +- Always use `for_block_output` for block outputs unless you have a specific reason not to +- Never hardcode workspace checks - let `for_block_output` handle it + +### Modifying the API + +1. Update route in `backend/api/features/` +2. Add/update Pydantic models in same directory +3. Write tests alongside the route file +4. Run `poetry run test` to verify + +## Security Implementation + +### Cache Protection Middleware + +- Located in `backend/api/middleware/security.py` +- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private` +- Uses an allow list approach - only explicitly permitted paths can be cached +- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation +- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies +- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware +- Applied to both main API server and external API applications diff --git a/autogpt_platform/backend/TESTING.md b/autogpt_platform/backend/TESTING.md index a3a5db68ef..2e09144485 100644 --- a/autogpt_platform/backend/TESTING.md +++ b/autogpt_platform/backend/TESTING.md @@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as #### Using Global Auth Fixtures -Two global auth fixtures are provided by `backend/server/conftest.py`: +Two global auth fixtures are provided by `backend/api/conftest.py`: - `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id") - `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id") diff --git a/autogpt_platform/backend/backend/api/features/builder/routes.py b/autogpt_platform/backend/backend/api/features/builder/routes.py index 7fe9cab189..15b922178d 100644 --- a/autogpt_platform/backend/backend/api/features/builder/routes.py +++ b/autogpt_platform/backend/backend/api/features/builder/routes.py @@ -17,7 +17,7 @@ router = fastapi.APIRouter( ) -# Taken from backend/server/v2/store/db.py +# Taken from backend/api/features/store/db.py def sanitize_query(query: str | None) -> str | None: if query is None: return query diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py new file mode 100644 index 0000000000..f447d46bd7 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py @@ -0,0 +1,368 @@ +"""Redis Streams consumer for operation completion messages. + +This module provides a consumer (ChatCompletionConsumer) that listens for +completion notifications (OperationCompleteMessage) from external services +(like Agent Generator) and triggers the appropriate stream registry and +chat service updates via process_operation_success/process_operation_failure. + +Why Redis Streams instead of RabbitMQ? +-------------------------------------- +While the project typically uses RabbitMQ for async task queues (e.g., execution +queue), Redis Streams was chosen for chat completion notifications because: + +1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis + Streams (via stream_registry) for message persistence and replay. Using Redis + Streams for completion notifications keeps all chat streaming infrastructure + in one system, simplifying operations and reducing cross-system coordination. + +2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs, + allowing consumers to replay missed messages after reconnection. This aligns + with the SSE reconnection pattern where clients can resume from last_message_id. + +3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic + load balancing across pods with explicit message claiming (XAUTOCLAIM) for + recovering from dead consumers - ideal for the completion callback pattern. + +4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for + stream_registry) provides lower latency than an additional RabbitMQ hop. + +5. **Atomicity with Task State**: Completion processing often needs to update + task metadata stored in Redis. Keeping both in Redis enables simpler + transactional semantics without distributed coordination. + +The consumer uses Redis Streams with consumer groups for reliable message +processing across multiple platform pods, with XAUTOCLAIM for reclaiming +stale pending messages from dead consumers. +""" + +import asyncio +import logging +import os +import uuid +from typing import Any + +import orjson +from prisma import Prisma +from pydantic import BaseModel +from redis.exceptions import ResponseError + +from backend.data.redis_client import get_redis_async + +from . import stream_registry +from .completion_handler import process_operation_failure, process_operation_success +from .config import ChatConfig + +logger = logging.getLogger(__name__) +config = ChatConfig() + + +class OperationCompleteMessage(BaseModel): + """Message format for operation completion notifications.""" + + operation_id: str + task_id: str + success: bool + result: dict | str | None = None + error: str | None = None + + +class ChatCompletionConsumer: + """Consumer for chat operation completion messages from Redis Streams. + + This consumer initializes its own Prisma client in start() to ensure + database operations work correctly within this async context. + + Uses Redis consumer groups to allow multiple platform pods to consume + messages reliably with automatic redelivery on failure. + """ + + def __init__(self): + self._consumer_task: asyncio.Task | None = None + self._running = False + self._prisma: Prisma | None = None + self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}" + + async def start(self) -> None: + """Start the completion consumer.""" + if self._running: + logger.warning("Completion consumer already running") + return + + # Create consumer group if it doesn't exist + try: + redis = await get_redis_async() + await redis.xgroup_create( + config.stream_completion_name, + config.stream_consumer_group, + id="0", + mkstream=True, + ) + logger.info( + f"Created consumer group '{config.stream_consumer_group}' " + f"on stream '{config.stream_completion_name}'" + ) + except ResponseError as e: + if "BUSYGROUP" in str(e): + logger.debug( + f"Consumer group '{config.stream_consumer_group}' already exists" + ) + else: + raise + + self._running = True + self._consumer_task = asyncio.create_task(self._consume_messages()) + logger.info( + f"Chat completion consumer started (consumer: {self._consumer_name})" + ) + + async def _ensure_prisma(self) -> Prisma: + """Lazily initialize Prisma client on first use.""" + if self._prisma is None: + database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432") + self._prisma = Prisma(datasource={"url": database_url}) + await self._prisma.connect() + logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)") + return self._prisma + + async def stop(self) -> None: + """Stop the completion consumer.""" + self._running = False + + if self._consumer_task: + self._consumer_task.cancel() + try: + await self._consumer_task + except asyncio.CancelledError: + pass + self._consumer_task = None + + if self._prisma: + await self._prisma.disconnect() + self._prisma = None + logger.info("[COMPLETION] Consumer Prisma client disconnected") + + logger.info("Chat completion consumer stopped") + + async def _consume_messages(self) -> None: + """Main message consumption loop with retry logic.""" + max_retries = 10 + retry_delay = 5 # seconds + retry_count = 0 + block_timeout = 5000 # milliseconds + + while self._running and retry_count < max_retries: + try: + redis = await get_redis_async() + + # Reset retry count on successful connection + retry_count = 0 + + while self._running: + # First, claim any stale pending messages from dead consumers + # Redis does NOT auto-redeliver pending messages; we must explicitly + # claim them using XAUTOCLAIM + try: + claimed_result = await redis.xautoclaim( + name=config.stream_completion_name, + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + min_idle_time=config.stream_claim_min_idle_ms, + start_id="0-0", + count=10, + ) + # xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids]) + if claimed_result and len(claimed_result) >= 2: + claimed_entries = claimed_result[1] + if claimed_entries: + logger.info( + f"Claimed {len(claimed_entries)} stale pending messages" + ) + for entry_id, data in claimed_entries: + if not self._running: + return + await self._process_entry(redis, entry_id, data) + except Exception as e: + logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}") + + # Read new messages from the stream + messages = await redis.xreadgroup( + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + streams={config.stream_completion_name: ">"}, + block=block_timeout, + count=10, + ) + + if not messages: + continue + + for stream_name, entries in messages: + for entry_id, data in entries: + if not self._running: + return + await self._process_entry(redis, entry_id, data) + + except asyncio.CancelledError: + logger.info("Consumer cancelled") + return + except Exception as e: + retry_count += 1 + logger.error( + f"Consumer error (retry {retry_count}/{max_retries}): {e}", + exc_info=True, + ) + if self._running and retry_count < max_retries: + await asyncio.sleep(retry_delay) + else: + logger.error("Max retries reached, stopping consumer") + return + + async def _process_entry( + self, redis: Any, entry_id: str, data: dict[str, Any] + ) -> None: + """Process a single stream entry and acknowledge it on success. + + Args: + redis: Redis client connection + entry_id: The stream entry ID + data: The entry data dict + """ + try: + # Handle the message + message_data = data.get("data") + if message_data: + await self._handle_message( + message_data.encode() + if isinstance(message_data, str) + else message_data + ) + + # Acknowledge the message after successful processing + await redis.xack( + config.stream_completion_name, + config.stream_consumer_group, + entry_id, + ) + except Exception as e: + logger.error( + f"Error processing completion message {entry_id}: {e}", + exc_info=True, + ) + # Message remains in pending state and will be claimed by + # XAUTOCLAIM after min_idle_time expires + + async def _handle_message(self, body: bytes) -> None: + """Handle a completion message using our own Prisma client.""" + try: + data = orjson.loads(body) + message = OperationCompleteMessage(**data) + except Exception as e: + logger.error(f"Failed to parse completion message: {e}") + return + + logger.info( + f"[COMPLETION] Received completion for operation {message.operation_id} " + f"(task_id={message.task_id}, success={message.success})" + ) + + # Find task in registry + task = await stream_registry.find_task_by_operation_id(message.operation_id) + if task is None: + task = await stream_registry.get_task(message.task_id) + + if task is None: + logger.warning( + f"[COMPLETION] Task not found for operation {message.operation_id} " + f"(task_id={message.task_id})" + ) + return + + logger.info( + f"[COMPLETION] Found task: task_id={task.task_id}, " + f"session_id={task.session_id}, tool_call_id={task.tool_call_id}" + ) + + # Guard against empty task fields + if not task.task_id or not task.session_id or not task.tool_call_id: + logger.error( + f"[COMPLETION] Task has empty critical fields! " + f"task_id={task.task_id!r}, session_id={task.session_id!r}, " + f"tool_call_id={task.tool_call_id!r}" + ) + return + + if message.success: + await self._handle_success(task, message) + else: + await self._handle_failure(task, message) + + async def _handle_success( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle successful operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_success(task, message.result, prisma) + + async def _handle_failure( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle failed operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_failure(task, message.error, prisma) + + +# Module-level consumer instance +_consumer: ChatCompletionConsumer | None = None + + +async def start_completion_consumer() -> None: + """Start the global completion consumer.""" + global _consumer + if _consumer is None: + _consumer = ChatCompletionConsumer() + await _consumer.start() + + +async def stop_completion_consumer() -> None: + """Stop the global completion consumer.""" + global _consumer + if _consumer: + await _consumer.stop() + _consumer = None + + +async def publish_operation_complete( + operation_id: str, + task_id: str, + success: bool, + result: dict | str | None = None, + error: str | None = None, +) -> None: + """Publish an operation completion message to Redis Streams. + + Args: + operation_id: The operation ID that completed. + task_id: The task ID associated with the operation. + success: Whether the operation succeeded. + result: The result data (for success). + error: The error message (for failure). + """ + message = OperationCompleteMessage( + operation_id=operation_id, + task_id=task_id, + success=success, + result=result, + error=error, + ) + + redis = await get_redis_async() + await redis.xadd( + config.stream_completion_name, + {"data": message.model_dump_json()}, + maxlen=config.stream_max_length, + ) + logger.info(f"Published completion for operation {operation_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_handler.py b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py new file mode 100644 index 0000000000..905fa2ddba --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py @@ -0,0 +1,344 @@ +"""Shared completion handling for operation success and failure. + +This module provides common logic for handling operation completion from both: +- The Redis Streams consumer (completion_consumer.py) +- The HTTP webhook endpoint (routes.py) +""" + +import logging +from typing import Any + +import orjson +from prisma import Prisma + +from . import service as chat_service +from . import stream_registry +from .response_model import StreamError, StreamToolOutputAvailable +from .tools.models import ErrorResponse + +logger = logging.getLogger(__name__) + +# Tools that produce agent_json that needs to be saved to library +AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"} + +# Keys that should be stripped from agent_json when returning in error responses +SENSITIVE_KEYS = frozenset( + { + "api_key", + "apikey", + "api_secret", + "password", + "secret", + "credentials", + "credential", + "token", + "access_token", + "refresh_token", + "private_key", + "privatekey", + "auth", + "authorization", + } +) + + +def _sanitize_agent_json(obj: Any) -> Any: + """Recursively sanitize agent_json by removing sensitive keys. + + Args: + obj: The object to sanitize (dict, list, or primitive) + + Returns: + Sanitized copy with sensitive keys removed/redacted + """ + if isinstance(obj, dict): + return { + k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v) + for k, v in obj.items() + } + elif isinstance(obj, list): + return [_sanitize_agent_json(item) for item in obj] + else: + return obj + + +class ToolMessageUpdateError(Exception): + """Raised when updating a tool message in the database fails.""" + + pass + + +async def _update_tool_message( + session_id: str, + tool_call_id: str, + content: str, + prisma_client: Prisma | None, +) -> None: + """Update tool message in database. + + Args: + session_id: The session ID + tool_call_id: The tool call ID to update + content: The new content for the message + prisma_client: Optional Prisma client. If None, uses chat_service. + + Raises: + ToolMessageUpdateError: If the database update fails. The caller should + handle this to avoid marking the task as completed with inconsistent state. + """ + try: + if prisma_client: + # Use provided Prisma client (for consumer with its own connection) + updated_count = await prisma_client.chatmessage.update_many( + where={ + "sessionId": session_id, + "toolCallId": tool_call_id, + }, + data={"content": content}, + ) + # Check if any rows were updated - 0 means message not found + if updated_count == 0: + raise ToolMessageUpdateError( + f"No message found with tool_call_id={tool_call_id} in session {session_id}" + ) + else: + # Use service function (for webhook endpoint) + await chat_service._update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=content, + ) + except ToolMessageUpdateError: + raise + except Exception as e: + logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True) + raise ToolMessageUpdateError( + f"Failed to update tool message for tool_call_id={tool_call_id}: {e}" + ) from e + + +def serialize_result(result: dict | list | str | int | float | bool | None) -> str: + """Serialize result to JSON string with sensible defaults. + + Args: + result: The result to serialize. Can be a dict, list, string, + number, boolean, or None. + + Returns: + JSON string representation of the result. Returns '{"status": "completed"}' + only when result is explicitly None. + """ + if isinstance(result, str): + return result + if result is None: + return '{"status": "completed"}' + return orjson.dumps(result).decode("utf-8") + + +async def _save_agent_from_result( + result: dict[str, Any], + user_id: str | None, + tool_name: str, +) -> dict[str, Any]: + """Save agent to library if result contains agent_json. + + Args: + result: The result dict that may contain agent_json + user_id: The user ID to save the agent for + tool_name: The tool name (create_agent or edit_agent) + + Returns: + Updated result dict with saved agent details, or original result if no agent_json + """ + if not user_id: + logger.warning("[COMPLETION] Cannot save agent: no user_id in task") + return result + + agent_json = result.get("agent_json") + if not agent_json: + logger.warning( + f"[COMPLETION] {tool_name} completed but no agent_json in result" + ) + return result + + try: + from .tools.agent_generator import save_agent_to_library + + is_update = tool_name == "edit_agent" + created_graph, library_agent = await save_agent_to_library( + agent_json, user_id, is_update=is_update + ) + + logger.info( + f"[COMPLETION] Saved agent '{created_graph.name}' to library " + f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})" + ) + + # Return a response similar to AgentSavedResponse + return { + "type": "agent_saved", + "message": f"Agent '{created_graph.name}' has been saved to your library!", + "agent_id": created_graph.id, + "agent_name": created_graph.name, + "library_agent_id": library_agent.id, + "library_agent_link": f"/library/agents/{library_agent.id}", + "agent_page_link": f"/build?flowID={created_graph.id}", + } + except Exception as e: + logger.error( + f"[COMPLETION] Failed to save agent to library: {e}", + exc_info=True, + ) + # Return error but don't fail the whole operation + # Sanitize agent_json to remove sensitive keys before returning + return { + "type": "error", + "message": f"Agent was generated but failed to save: {str(e)}", + "error": str(e), + "agent_json": _sanitize_agent_json(agent_json), + } + + +async def process_operation_success( + task: stream_registry.ActiveTask, + result: dict | str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle successful operation completion. + + Publishes the result to the stream registry, updates the database, + generates LLM continuation, and marks the task as completed. + + Args: + task: The active task that completed + result: The result data from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. + + Raises: + ToolMessageUpdateError: If the database update fails. The task will be + marked as failed instead of completed to avoid inconsistent state. + """ + # For agent generation tools, save the agent to library + if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict): + result = await _save_agent_from_result(result, task.user_id, task.tool_name) + + # Serialize result for output (only substitute default when result is exactly None) + result_output = result if result is not None else {"status": "completed"} + output_str = ( + result_output + if isinstance(result_output, str) + else orjson.dumps(result_output).decode("utf-8") + ) + + # Publish result to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamToolOutputAvailable( + toolCallId=task.tool_call_id, + toolName=task.tool_name, + output=output_str, + success=True, + ), + ) + + # Update pending operation in database + # If this fails, we must not continue to mark the task as completed + result_str = serialize_result(result) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=result_str, + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - mark task as failed to avoid inconsistent state + logger.error( + f"[COMPLETION] DB update failed for task {task.task_id}, " + "marking as failed instead of completed" + ) + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText="Failed to save operation result to database"), + ) + await stream_registry.mark_task_completed(task.task_id, status="failed") + raise + + # Generate LLM continuation with streaming + try: + await chat_service._generate_llm_continuation_with_streaming( + session_id=task.session_id, + user_id=task.user_id, + task_id=task.task_id, + ) + except Exception as e: + logger.error( + f"[COMPLETION] Failed to generate LLM continuation: {e}", + exc_info=True, + ) + + # Mark task as completed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="completed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info( + f"[COMPLETION] Successfully processed completion for task {task.task_id}" + ) + + +async def process_operation_failure( + task: stream_registry.ActiveTask, + error: str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle failed operation completion. + + Publishes the error to the stream registry, updates the database with + the error response, and marks the task as failed. + + Args: + task: The active task that failed + error: The error message from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. + """ + error_msg = error or "Operation failed" + + # Publish error to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText=error_msg), + ) + + # Update pending operation with error + # If this fails, we still continue to mark the task as failed + error_response = ErrorResponse( + message=error_msg, + error=error, + ) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=error_response.model_dump_json(), + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - log but continue with cleanup + logger.error( + f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, " + "continuing with cleanup" + ) + + # Mark task as failed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="failed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}") diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py index dba7934877..2e8dbf5413 100644 --- a/autogpt_platform/backend/backend/api/features/chat/config.py +++ b/autogpt_platform/backend/backend/api/features/chat/config.py @@ -44,6 +44,48 @@ class ChatConfig(BaseSettings): description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)", ) + # Stream registry configuration for SSE reconnection + stream_ttl: int = Field( + default=3600, + description="TTL in seconds for stream data in Redis (1 hour)", + ) + stream_max_length: int = Field( + default=10000, + description="Maximum number of messages to store per stream", + ) + + # Redis Streams configuration for completion consumer + stream_completion_name: str = Field( + default="chat:completions", + description="Redis Stream name for operation completions", + ) + stream_consumer_group: str = Field( + default="chat_consumers", + description="Consumer group name for completion stream", + ) + stream_claim_min_idle_ms: int = Field( + default=60000, + description="Minimum idle time in milliseconds before claiming pending messages from dead consumers", + ) + + # Redis key prefixes for stream registry + task_meta_prefix: str = Field( + default="chat:task:meta:", + description="Prefix for task metadata hash keys", + ) + task_stream_prefix: str = Field( + default="chat:stream:", + description="Prefix for task message stream keys", + ) + task_op_prefix: str = Field( + default="chat:task:op:", + description="Prefix for operation ID to task ID mapping keys", + ) + internal_api_key: str | None = Field( + default=None, + description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)", + ) + # Langfuse Prompt Management Configuration # Note: Langfuse credentials are in Settings().secrets (settings.py) langfuse_prompt_name: str = Field( @@ -82,6 +124,14 @@ class ChatConfig(BaseSettings): v = "https://openrouter.ai/api/v1" return v + @field_validator("internal_api_key", mode="before") + @classmethod + def get_internal_api_key(cls, v): + """Get internal API key from environment if not provided.""" + if v is None: + v = os.getenv("CHAT_INTERNAL_API_KEY") + return v + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git a/autogpt_platform/backend/backend/api/features/chat/response_model.py b/autogpt_platform/backend/backend/api/features/chat/response_model.py index 53a8cf3a1f..f627a42fcc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/response_model.py +++ b/autogpt_platform/backend/backend/api/features/chat/response_model.py @@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse): type: ResponseType = ResponseType.START messageId: str = Field(..., description="Unique message ID") + taskId: str | None = Field( + default=None, + description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream", + ) class StreamFinish(StreamBaseResponse): diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index cab51543b1..3e731d86ac 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,19 +1,23 @@ """Chat API routes for chat session management and streaming via SSE.""" import logging +import uuid as uuid_module from collections.abc import AsyncGenerator from typing import Annotated from autogpt_libs import auth -from fastapi import APIRouter, Depends, Query, Security +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security from fastapi.responses import StreamingResponse from pydantic import BaseModel from backend.util.exceptions import NotFoundError from . import service as chat_service +from . import stream_registry +from .completion_handler import process_operation_failure, process_operation_success from .config import ChatConfig from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions +from .response_model import StreamFinish, StreamHeartbeat, StreamStart config = ChatConfig() @@ -55,6 +59,15 @@ class CreateSessionResponse(BaseModel): user_id: str | None +class ActiveStreamInfo(BaseModel): + """Information about an active stream for reconnection.""" + + task_id: str + last_message_id: str # Redis Stream message ID for resumption + operation_id: str # Operation ID for completion tracking + tool_name: str # Name of the tool being executed + + class SessionDetailResponse(BaseModel): """Response model providing complete details for a chat session, including messages.""" @@ -63,6 +76,7 @@ class SessionDetailResponse(BaseModel): updated_at: str user_id: str | None messages: list[dict] + active_stream: ActiveStreamInfo | None = None # Present if stream is still active class SessionSummaryResponse(BaseModel): @@ -81,6 +95,14 @@ class ListSessionsResponse(BaseModel): total: int +class OperationCompleteRequest(BaseModel): + """Request model for external completion webhook.""" + + success: bool + result: dict | str | None = None + error: str | None = None + + # ========== Routes ========== @@ -166,13 +188,14 @@ async def get_session( Retrieve the details of a specific chat session. Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages. + If there's an active stream for this session, returns the task_id for reconnection. Args: session_id: The unique identifier for the desired chat session. user_id: The optional authenticated user ID, or None for anonymous access. Returns: - SessionDetailResponse: Details for the requested session, or None if not found. + SessionDetailResponse: Details for the requested session, including active_stream info if applicable. """ session = await get_chat_session(session_id, user_id) @@ -180,11 +203,28 @@ async def get_session( raise NotFoundError(f"Session {session_id} not found.") messages = [message.model_dump() for message in session.messages] - logger.info( - f"Returning session {session_id}: " - f"message_count={len(messages)}, " - f"roles={[m.get('role') for m in messages]}" + + # Check if there's an active stream for this session + active_stream_info = None + active_task, last_message_id = await stream_registry.get_active_task_for_session( + session_id, user_id ) + if active_task: + # Filter out the in-progress assistant message from the session response. + # The client will receive the complete assistant response through the SSE + # stream replay instead, preventing duplicate content. + if messages and messages[-1].get("role") == "assistant": + messages = messages[:-1] + + # Use "0-0" as last_message_id to replay the stream from the beginning. + # Since we filtered out the cached assistant message, the client needs + # the full stream to reconstruct the response. + active_stream_info = ActiveStreamInfo( + task_id=active_task.task_id, + last_message_id="0-0", + operation_id=active_task.operation_id, + tool_name=active_task.tool_name, + ) return SessionDetailResponse( id=session.session_id, @@ -192,6 +232,7 @@ async def get_session( updated_at=session.updated_at.isoformat(), user_id=session.user_id or None, messages=messages, + active_stream=active_stream_info, ) @@ -211,49 +252,112 @@ async def stream_chat_post( - Tool call UI elements (if invoked) - Tool execution results + The AI generation runs in a background task that continues even if the client disconnects. + All chunks are written to Redis for reconnection support. If the client disconnects, + they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off. + Args: session_id: The chat session identifier to associate with the streamed messages. request: Request body containing message, is_user_message, and optional context. user_id: Optional authenticated user ID. Returns: - StreamingResponse: SSE-formatted response chunks. + StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event + containing the task_id for reconnection. """ + import asyncio + session = await _validate_and_get_session(session_id, user_id) + # Create a task in the stream registry for reconnection support + task_id = str(uuid_module.uuid4()) + operation_id = str(uuid_module.uuid4()) + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id="chat_stream", # Not a tool call, but needed for the model + tool_name="chat", + operation_id=operation_id, + ) + + # Background task that runs the AI generation independently of SSE connection + async def run_ai_generation(): + try: + # Emit a start event with task_id for reconnection + start_chunk = StreamStart(messageId=task_id, taskId=task_id) + await stream_registry.publish_chunk(task_id, start_chunk) + + async for chunk in chat_service.stream_chat_completion( + session_id, + request.message, + is_user_message=request.is_user_message, + user_id=user_id, + session=session, # Pass pre-fetched session to avoid double-fetch + context=request.context, + ): + # Write to Redis (subscribers will receive via XREAD) + await stream_registry.publish_chunk(task_id, chunk) + + # Mark task as completed + await stream_registry.mark_task_completed(task_id, "completed") + except Exception as e: + logger.error( + f"Error in background AI generation for session {session_id}: {e}" + ) + await stream_registry.mark_task_completed(task_id, "failed") + + # Start the AI generation in a background task + bg_task = asyncio.create_task(run_ai_generation()) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + + # SSE endpoint that subscribes to the task's stream async def event_generator() -> AsyncGenerator[str, None]: - chunk_count = 0 - first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( - session_id, - request.message, - is_user_message=request.is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - context=request.context, - ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 - yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) - # AI SDK protocol termination - yield "data: [DONE]\n\n" + subscriber_queue = None + try: + # Subscribe to the task stream (this replays existing messages + live updates) + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id="0-0", # Get all messages from the beginning + ) + + if subscriber_queue is None: + yield StreamFinish().to_sse() + yield "data: [DONE]\n\n" + return + + # Read from the subscriber queue and yield to SSE + while True: + try: + chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + + except GeneratorExit: + pass # Client disconnected - background task continues + except Exception as e: + logger.error(f"Error in SSE stream for task {task_id}: {e}") + finally: + # Unsubscribe when client disconnects or stream ends to prevent resource leak + if subscriber_queue is not None: + try: + await stream_registry.unsubscribe_from_task( + task_id, subscriber_queue + ) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" return StreamingResponse( event_generator(), @@ -366,6 +470,251 @@ async def session_assign_user( return {"status": "ok"} +# ========== Task Streaming (SSE Reconnection) ========== + + +@router.get( + "/tasks/{task_id}/stream", +) +async def stream_task( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), + last_message_id: str = Query( + default="0-0", + description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", + ), +): + """ + Reconnect to a long-running task's SSE stream. + + When a long-running operation (like agent generation) starts, the client + receives a task_id. If the connection drops, the client can reconnect + using this endpoint to resume receiving updates. + + Args: + task_id: The task ID from the operation_started response. + user_id: Authenticated user ID for ownership validation. + last_message_id: Last Redis Stream message ID received ("0-0" for full replay). + + Returns: + StreamingResponse: SSE-formatted response chunks starting after last_message_id. + + Raises: + HTTPException: 404 if task not found, 410 if task expired, 403 if access denied. + """ + # Check task existence and expiry before subscribing + task, error_code = await stream_registry.get_task_with_expiry_info(task_id) + + if error_code == "TASK_EXPIRED": + raise HTTPException( + status_code=410, + detail={ + "code": "TASK_EXPIRED", + "message": "This operation has expired. Please try again.", + }, + ) + + if error_code == "TASK_NOT_FOUND": + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found.", + }, + ) + + # Validate ownership if task has an owner + if task and task.user_id and user_id != task.user_id: + raise HTTPException( + status_code=403, + detail={ + "code": "ACCESS_DENIED", + "message": "You do not have access to this task.", + }, + ) + + # Get subscriber queue from stream registry + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id=last_message_id, + ) + + if subscriber_queue is None: + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found or access denied.", + }, + ) + + async def event_generator() -> AsyncGenerator[str, None]: + import asyncio + + heartbeat_interval = 15.0 # Send heartbeat every 15 seconds + try: + while True: + try: + # Wait for next chunk with timeout for heartbeats + chunk = await asyncio.wait_for( + subscriber_queue.get(), timeout=heartbeat_interval + ) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + except Exception as e: + logger.error(f"Error in task stream {task_id}: {e}", exc_info=True) + finally: + # Unsubscribe when client disconnects or stream ends + try: + await stream_registry.unsubscribe_from_task(task_id, subscriber_queue) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "x-vercel-ai-ui-message-stream": "v1", + }, + ) + + +@router.get( + "/tasks/{task_id}", +) +async def get_task_status( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), +) -> dict: + """ + Get the status of a long-running task. + + Args: + task_id: The task ID to check. + user_id: Authenticated user ID for ownership validation. + + Returns: + dict: Task status including task_id, status, tool_name, and operation_id. + + Raises: + NotFoundError: If task_id is not found or user doesn't have access. + """ + task = await stream_registry.get_task(task_id) + + if task is None: + raise NotFoundError(f"Task {task_id} not found.") + + # Validate ownership - if task has an owner, requester must match + if task.user_id and user_id != task.user_id: + raise NotFoundError(f"Task {task_id} not found.") + + return { + "task_id": task.task_id, + "session_id": task.session_id, + "status": task.status, + "tool_name": task.tool_name, + "operation_id": task.operation_id, + "created_at": task.created_at.isoformat(), + } + + +# ========== External Completion Webhook ========== + + +@router.post( + "/operations/{operation_id}/complete", + status_code=200, +) +async def complete_operation( + operation_id: str, + request: OperationCompleteRequest, + x_api_key: str | None = Header(default=None), +) -> dict: + """ + External completion webhook for long-running operations. + + Called by Agent Generator (or other services) when an operation completes. + This triggers the stream registry to publish completion and continue LLM generation. + + Args: + operation_id: The operation ID to complete. + request: Completion payload with success status and result/error. + x_api_key: Internal API key for authentication. + + Returns: + dict: Status of the completion. + + Raises: + HTTPException: If API key is invalid or operation not found. + """ + # Validate internal API key - reject if not configured or invalid + if not config.internal_api_key: + logger.error( + "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured" + ) + raise HTTPException( + status_code=503, + detail="Webhook not available: internal API key not configured", + ) + if x_api_key != config.internal_api_key: + raise HTTPException(status_code=401, detail="Invalid API key") + + # Find task by operation_id + task = await stream_registry.find_task_by_operation_id(operation_id) + if task is None: + raise HTTPException( + status_code=404, + detail=f"Operation {operation_id} not found", + ) + + logger.info( + f"Received completion webhook for operation {operation_id} " + f"(task_id={task.task_id}, success={request.success})" + ) + + if request.success: + await process_operation_success(task, request.result) + else: + await process_operation_failure(task, request.error) + + return {"status": "ok", "task_id": task.task_id} + + +# ========== Configuration ========== + + +@router.get("/config/ttl", status_code=200) +async def get_ttl_config() -> dict: + """ + Get the stream TTL configuration. + + Returns the Time-To-Live settings for chat streams, which determines + how long clients can reconnect to an active stream. + + Returns: + dict: TTL configuration with seconds and milliseconds values. + """ + return { + "stream_ttl_seconds": config.stream_ttl, + "stream_ttl_ms": config.stream_ttl * 1000, + } + + # ========== Health Check ========== diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 20216162b5..218575085b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -3,9 +3,13 @@ import logging import time from asyncio import CancelledError from collections.abc import AsyncGenerator -from typing import Any +from typing import TYPE_CHECKING, Any, cast import openai + +if TYPE_CHECKING: + from backend.util.prompt import CompressResult + import orjson from langfuse import get_client from openai import ( @@ -15,7 +19,13 @@ from openai import ( PermissionDeniedError, RateLimitError, ) -from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam +from openai.types.chat import ( + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionStreamOptionsParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolParam, +) from backend.data.redis_client import get_redis_async from backend.data.understanding import ( @@ -26,6 +36,7 @@ from backend.util.exceptions import NotFoundError from backend.util.settings import Settings from . import db as chat_db +from . import stream_registry from .config import ChatConfig from .model import ( ChatMessage, @@ -794,207 +805,58 @@ def _is_region_blocked_error(error: Exception) -> bool: return "not available in your region" in str(error).lower() -async def _summarize_messages( +async def _manage_context_window( messages: list, model: str, api_key: str | None = None, base_url: str | None = None, - timeout: float = 30.0, -) -> str: - """Summarize a list of messages into concise context. +) -> "CompressResult": + """ + Manage context window using the unified compress_context function. - Uses the same model as the chat for higher quality summaries. + This is a thin wrapper that creates an OpenAI client for summarization + and delegates to the shared compression logic in prompt.py. Args: - messages: List of message dicts to summarize - model: Model to use for summarization (same as chat model) - api_key: API key for OpenAI client - base_url: Base URL for OpenAI client - timeout: Request timeout in seconds (default: 30.0) + messages: List of messages in OpenAI format + model: Model name for token counting and summarization + api_key: API key for summarization calls + base_url: Base URL for summarization calls Returns: - Summarized text + CompressResult with compacted messages and metadata """ - # Format messages for summarization - conversation = [] - for msg in messages: - role = msg.get("role", "") - content = msg.get("content", "") - # Include user, assistant, and tool messages (tool outputs are important context) - if content and role in ("user", "assistant", "tool"): - conversation.append(f"{role.upper()}: {content}") - - conversation_text = "\n\n".join(conversation) - - # Handle empty conversation - if not conversation_text: - return "No conversation history available." - - # Truncate conversation to fit within summarization model's context - # gpt-4o-mini has 128k context, but we limit to ~25k tokens (~100k chars) for safety - MAX_CHARS = 100_000 - if len(conversation_text) > MAX_CHARS: - conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]" - - # Call LLM to summarize import openai - summarization_client = openai.AsyncOpenAI( - api_key=api_key, base_url=base_url, timeout=timeout - ) + from backend.util.prompt import compress_context - response = await summarization_client.chat.completions.create( - model=model, - messages=[ - { - "role": "system", - "content": ( - "Create a detailed summary of the conversation so far. " - "This summary will be used as context when continuing the conversation.\n\n" - "Before writing the summary, analyze each message chronologically to identify:\n" - "- User requests and their explicit goals\n" - "- Your approach and key decisions made\n" - "- Technical specifics (file names, tool outputs, function signatures)\n" - "- Errors encountered and resolutions applied\n\n" - "You MUST include ALL of the following sections:\n\n" - "## 1. Primary Request and Intent\n" - "The user's explicit goals and what they are trying to accomplish.\n\n" - "## 2. Key Technical Concepts\n" - "Technologies, frameworks, tools, and patterns being used or discussed.\n\n" - "## 3. Files and Resources Involved\n" - "Specific files examined or modified, with relevant snippets and identifiers.\n\n" - "## 4. Errors and Fixes\n" - "Problems encountered, error messages, and their resolutions. " - "Include any user feedback on fixes.\n\n" - "## 5. Problem Solving\n" - "Issues that have been resolved and how they were addressed.\n\n" - "## 6. All User Messages\n" - "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n" - "## 7. Pending Tasks\n" - "Work items the user explicitly requested that have not yet been completed.\n\n" - "## 8. Current Work\n" - "Precise description of what was being worked on most recently, including relevant context.\n\n" - "## 9. Next Steps\n" - "What should happen next, aligned with the user's most recent requests. " - "Include verbatim quotes of recent instructions if relevant." - ), - }, - {"role": "user", "content": f"Summarize:\n\n{conversation_text}"}, - ], - max_tokens=1500, - temperature=0.3, - ) + # Convert messages to dict format + messages_dict = [] + for msg in messages: + if isinstance(msg, dict): + msg_dict = {k: v for k, v in msg.items() if v is not None} + else: + msg_dict = dict(msg) + messages_dict.append(msg_dict) - summary = response.choices[0].message.content - return summary or "No summary available." - - -def _ensure_tool_pairs_intact( - recent_messages: list[dict], - all_messages: list[dict], - start_index: int, -) -> list[dict]: - """ - Ensure tool_call/tool_response pairs stay together after slicing. - - When slicing messages for context compaction, a naive slice can separate - an assistant message containing tool_calls from its corresponding tool - response messages. This causes API validation errors (e.g., Anthropic's - "unexpected tool_use_id found in tool_result blocks"). - - This function checks for orphan tool responses in the slice and extends - backwards to include their corresponding assistant messages. - - Args: - recent_messages: The sliced messages to validate - all_messages: The complete message list (for looking up missing assistants) - start_index: The index in all_messages where recent_messages begins - - Returns: - A potentially extended list of messages with tool pairs intact - """ - if not recent_messages: - return recent_messages - - # Collect all tool_call_ids from assistant messages in the slice - available_tool_call_ids: set[str] = set() - for msg in recent_messages: - if msg.get("role") == "assistant" and msg.get("tool_calls"): - for tc in msg["tool_calls"]: - tc_id = tc.get("id") - if tc_id: - available_tool_call_ids.add(tc_id) - - # Find orphan tool responses (tool messages whose tool_call_id is missing) - orphan_tool_call_ids: set[str] = set() - for msg in recent_messages: - if msg.get("role") == "tool": - tc_id = msg.get("tool_call_id") - if tc_id and tc_id not in available_tool_call_ids: - orphan_tool_call_ids.add(tc_id) - - if not orphan_tool_call_ids: - # No orphans, slice is valid - return recent_messages - - # Find the assistant messages that contain the orphan tool_call_ids - # Search backwards from start_index in all_messages - messages_to_prepend: list[dict] = [] - for i in range(start_index - 1, -1, -1): - msg = all_messages[i] - if msg.get("role") == "assistant" and msg.get("tool_calls"): - msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")} - if msg_tool_ids & orphan_tool_call_ids: - # This assistant message has tool_calls we need - # Also collect its contiguous tool responses that follow it - assistant_and_responses: list[dict] = [msg] - - # Scan forward from this assistant to collect tool responses - for j in range(i + 1, start_index): - following_msg = all_messages[j] - if following_msg.get("role") == "tool": - tool_id = following_msg.get("tool_call_id") - if tool_id and tool_id in msg_tool_ids: - assistant_and_responses.append(following_msg) - else: - # Stop at first non-tool message - break - - # Prepend the assistant and its tool responses (maintain order) - messages_to_prepend = assistant_and_responses + messages_to_prepend - # Mark these as found - orphan_tool_call_ids -= msg_tool_ids - # Also add this assistant's tool_call_ids to available set - available_tool_call_ids |= msg_tool_ids - - if not orphan_tool_call_ids: - # Found all missing assistants - break - - if orphan_tool_call_ids: - # Some tool_call_ids couldn't be resolved - remove those tool responses - # This shouldn't happen in normal operation but handles edge cases - logger.warning( - f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. " - "Removing orphan tool responses." - ) - recent_messages = [ - msg - for msg in recent_messages - if not ( - msg.get("role") == "tool" - and msg.get("tool_call_id") in orphan_tool_call_ids + # Only create client if api_key is provided (enables summarization) + # Use context manager to avoid socket leaks + if api_key: + async with openai.AsyncOpenAI( + api_key=api_key, base_url=base_url, timeout=30.0 + ) as client: + return await compress_context( + messages=messages_dict, + model=model, + client=client, ) - ] - - if messages_to_prepend: - logger.info( - f"Extended recent messages by {len(messages_to_prepend)} to preserve " - f"tool_call/tool_response pairs" + else: + # No API key - use truncation-only mode + return await compress_context( + messages=messages_dict, + model=model, + client=None, ) - return messages_to_prepend + recent_messages - - return recent_messages async def _stream_chat_chunks( @@ -1022,11 +884,8 @@ async def _stream_chat_chunks( logger.info("Starting pure chat stream") - # Build messages with system prompt prepended messages = session.to_openai_messages() if system_prompt: - from openai.types.chat import ChatCompletionSystemMessageParam - system_message = ChatCompletionSystemMessageParam( role="system", content=system_prompt, @@ -1034,314 +893,38 @@ async def _stream_chat_chunks( messages = [system_message] + messages # Apply context window management - token_count = 0 # Initialize for exception handler - try: - from backend.util.prompt import estimate_token_count + context_result = await _manage_context_window( + messages=messages, + model=model, + api_key=config.api_key, + base_url=config.base_url, + ) - # Convert to dict for token counting - # OpenAI message types are TypedDicts, so they're already dict-like - messages_dict = [] - for msg in messages: - # TypedDict objects are already dicts, just filter None values - if isinstance(msg, dict): - msg_dict = {k: v for k, v in msg.items() if v is not None} - else: - # Fallback for unexpected types - msg_dict = dict(msg) - messages_dict.append(msg_dict) - - # Estimate tokens using appropriate tokenizer - # Normalize model name for token counting (tiktoken only supports OpenAI models) - token_count_model = model - if "/" in model: - # Strip provider prefix (e.g., "anthropic/claude-opus-4.5" -> "claude-opus-4.5") - token_count_model = model.split("/")[-1] - - # For Claude and other non-OpenAI models, approximate with gpt-4o tokenizer - # Most modern LLMs have similar tokenization (~1 token per 4 chars) - if "claude" in token_count_model.lower() or not any( - known in token_count_model.lower() - for known in ["gpt", "o1", "chatgpt", "text-"] - ): - token_count_model = "gpt-4o" - - # Attempt token counting with error handling - try: - token_count = estimate_token_count(messages_dict, model=token_count_model) - except Exception as token_error: - # If token counting fails, use gpt-4o as fallback approximation - logger.warning( - f"Token counting failed for model {token_count_model}: {token_error}. " - "Using gpt-4o approximation." - ) - token_count = estimate_token_count(messages_dict, model="gpt-4o") - - # If over threshold, summarize old messages - if token_count > 120_000: - KEEP_RECENT = 15 - - # Check if we have a system prompt at the start - has_system_prompt = ( - len(messages) > 0 and messages[0].get("role") == "system" - ) - - # Always attempt mitigation when over limit, even with few messages - if messages: - # Split messages based on whether system prompt exists - # Calculate start index for the slice - slice_start = max(0, len(messages_dict) - KEEP_RECENT) - recent_messages = messages_dict[-KEEP_RECENT:] - - # Ensure tool_call/tool_response pairs stay together - # This prevents API errors from orphan tool responses - recent_messages = _ensure_tool_pairs_intact( - recent_messages, messages_dict, slice_start - ) - - if has_system_prompt: - # Keep system prompt separate, summarize everything between system and recent - system_msg = messages[0] - old_messages_dict = messages_dict[1:-KEEP_RECENT] - else: - # No system prompt, summarize everything except recent - system_msg = None - old_messages_dict = messages_dict[:-KEEP_RECENT] - - # Summarize any non-empty old messages (no minimum threshold) - # If we're over the token limit, we need to compress whatever we can - if old_messages_dict: - # Summarize old messages using the same model as chat - summary_text = await _summarize_messages( - old_messages_dict, - model=model, - api_key=config.api_key, - base_url=config.base_url, - ) - - # Build new message list - # Use assistant role (not system) to prevent privilege escalation - # of user-influenced content to instruction-level authority - from openai.types.chat import ChatCompletionAssistantMessageParam - - summary_msg = ChatCompletionAssistantMessageParam( - role="assistant", - content=( - "[Previous conversation summary — for context only]: " - f"{summary_text}" - ), - ) - - # Rebuild messages based on whether we have a system prompt - if has_system_prompt: - # system_prompt + summary + recent_messages - messages = [system_msg, summary_msg] + recent_messages - else: - # summary + recent_messages (no original system prompt) - messages = [summary_msg] + recent_messages - - logger.info( - f"Context summarized: {token_count} tokens, " - f"summarized {len(old_messages_dict)} old messages, " - f"kept last {KEEP_RECENT} messages" - ) - - # Fallback: If still over limit after summarization, progressively drop recent messages - # This handles edge cases where recent messages are extremely large - new_messages_dict = [] - for msg in messages: - if isinstance(msg, dict): - msg_dict = {k: v for k, v in msg.items() if v is not None} - else: - msg_dict = dict(msg) - new_messages_dict.append(msg_dict) - - new_token_count = estimate_token_count( - new_messages_dict, model=token_count_model - ) - - if new_token_count > 120_000: - # Still over limit - progressively reduce KEEP_RECENT - logger.warning( - f"Still over limit after summarization: {new_token_count} tokens. " - "Reducing number of recent messages kept." - ) - - for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]: - if keep_count == 0: - # Try with just system prompt + summary (no recent messages) - if has_system_prompt: - messages = [system_msg, summary_msg] - else: - messages = [summary_msg] - logger.info( - "Trying with 0 recent messages (system + summary only)" - ) - else: - # Slice from ORIGINAL recent_messages to avoid duplicating summary - reduced_recent = ( - recent_messages[-keep_count:] - if len(recent_messages) >= keep_count - else recent_messages - ) - # Ensure tool pairs stay intact in the reduced slice - reduced_slice_start = max( - 0, len(recent_messages) - keep_count - ) - reduced_recent = _ensure_tool_pairs_intact( - reduced_recent, recent_messages, reduced_slice_start - ) - if has_system_prompt: - messages = [ - system_msg, - summary_msg, - ] + reduced_recent - else: - messages = [summary_msg] + reduced_recent - - new_messages_dict = [] - for msg in messages: - if isinstance(msg, dict): - msg_dict = { - k: v for k, v in msg.items() if v is not None - } - else: - msg_dict = dict(msg) - new_messages_dict.append(msg_dict) - - new_token_count = estimate_token_count( - new_messages_dict, model=token_count_model - ) - - if new_token_count <= 120_000: - logger.info( - f"Reduced to {keep_count} recent messages, " - f"now {new_token_count} tokens" - ) - break - else: - logger.error( - f"Unable to reduce token count below threshold even with 0 messages. " - f"Final count: {new_token_count} tokens" - ) - # ABSOLUTE LAST RESORT: Drop system prompt - # This should only happen if summary itself is massive - if has_system_prompt and len(messages) > 1: - messages = messages[1:] # Drop system prompt - logger.critical( - "CRITICAL: Dropped system prompt as absolute last resort. " - "Behavioral consistency may be affected." - ) - # Yield error to user - yield StreamError( - errorText=( - "Warning: System prompt dropped due to size constraints. " - "Assistant behavior may be affected." - ) - ) - else: - # No old messages to summarize - all messages are "recent" - # Apply progressive truncation to reduce token count - logger.warning( - f"Token count {token_count} exceeds threshold but no old messages to summarize. " - f"Applying progressive truncation to recent messages." - ) - - # Create a base list excluding system prompt to avoid duplication - # This is the pool of messages we'll slice from in the loop - # Use messages_dict for type consistency with _ensure_tool_pairs_intact - base_msgs = ( - messages_dict[1:] if has_system_prompt else messages_dict - ) - - # Try progressively smaller keep counts - new_token_count = token_count # Initialize with current count - for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]: - if keep_count == 0: - # Try with just system prompt (no recent messages) - if has_system_prompt: - messages = [system_msg] - logger.info( - "Trying with 0 recent messages (system prompt only)" - ) - else: - # No system prompt and no recent messages = empty messages list - # This is invalid, skip this iteration - continue - else: - if len(base_msgs) < keep_count: - continue # Skip if we don't have enough messages - - # Slice from base_msgs to get recent messages (without system prompt) - recent_messages = base_msgs[-keep_count:] - - # Ensure tool pairs stay intact in the reduced slice - reduced_slice_start = max(0, len(base_msgs) - keep_count) - recent_messages = _ensure_tool_pairs_intact( - recent_messages, base_msgs, reduced_slice_start - ) - - if has_system_prompt: - messages = [system_msg] + recent_messages - else: - messages = recent_messages - - new_messages_dict = [] - for msg in messages: - if msg is None: - continue # Skip None messages (type safety) - if isinstance(msg, dict): - msg_dict = { - k: v for k, v in msg.items() if v is not None - } - else: - msg_dict = dict(msg) - new_messages_dict.append(msg_dict) - - new_token_count = estimate_token_count( - new_messages_dict, model=token_count_model - ) - - if new_token_count <= 120_000: - logger.info( - f"Reduced to {keep_count} recent messages, " - f"now {new_token_count} tokens" - ) - break - else: - # Even with 0 messages still over limit - logger.error( - f"Unable to reduce token count below threshold even with 0 messages. " - f"Final count: {new_token_count} tokens. Messages may be extremely large." - ) - # ABSOLUTE LAST RESORT: Drop system prompt - if has_system_prompt and len(messages) > 1: - messages = messages[1:] # Drop system prompt - logger.critical( - "CRITICAL: Dropped system prompt as absolute last resort. " - "Behavioral consistency may be affected." - ) - # Yield error to user - yield StreamError( - errorText=( - "Warning: System prompt dropped due to size constraints. " - "Assistant behavior may be affected." - ) - ) - - except Exception as e: - logger.error(f"Context summarization failed: {e}", exc_info=True) - # If we were over the token limit, yield error to user - # Don't silently continue with oversized messages that will fail - if token_count > 120_000: + if context_result.error: + if "System prompt dropped" in context_result.error: + # Warning only - continue with reduced context yield StreamError( errorText=( - f"Unable to manage context window (token limit exceeded: {token_count} tokens). " - "Context summarization failed. Please start a new conversation." + "Warning: System prompt dropped due to size constraints. " + "Assistant behavior may be affected." + ) + ) + else: + # Any other error - abort to prevent failed LLM calls + yield StreamError( + errorText=( + f"Context window management failed: {context_result.error}. " + "Please start a new conversation." ) ) yield StreamFinish() return - # Otherwise, continue with original messages (under limit) + + messages = context_result.messages + if context_result.was_compacted: + logger.info( + f"Context compacted for streaming: {context_result.token_count} tokens" + ) # Loop to handle tool calls and continue conversation while True: @@ -1369,14 +952,6 @@ async def _stream_chat_chunks( :128 ] # OpenRouter limit - # Create the stream with proper types - from typing import cast - - from openai.types.chat import ( - ChatCompletionMessageParam, - ChatCompletionStreamOptionsParam, - ) - stream = await client.chat.completions.create( model=model, messages=cast(list[ChatCompletionMessageParam], messages), @@ -1610,8 +1185,9 @@ async def _yield_tool_call( ) return - # Generate operation ID + # Generate operation ID and task ID operation_id = str(uuid_module.uuid4()) + task_id = str(uuid_module.uuid4()) # Build a user-friendly message based on tool and arguments if tool_name == "create_agent": @@ -1654,6 +1230,16 @@ async def _yield_tool_call( # Wrap session save and task creation in try-except to release lock on failure try: + # Create task in stream registry for SSE reconnection support + await stream_registry.create_task( + task_id=task_id, + session_id=session.session_id, + user_id=session.user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + # Save assistant message with tool_call FIRST (required by LLM) assistant_message = ChatMessage( role="assistant", @@ -1675,23 +1261,27 @@ async def _yield_tool_call( session.messages.append(pending_message) await upsert_chat_session(session) logger.info( - f"Saved pending operation {operation_id} for tool {tool_name} " - f"in session {session.session_id}" + f"Saved pending operation {operation_id} (task_id={task_id}) " + f"for tool {tool_name} in session {session.session_id}" ) # Store task reference in module-level set to prevent GC before completion - task = asyncio.create_task( - _execute_long_running_tool( + bg_task = asyncio.create_task( + _execute_long_running_tool_with_streaming( tool_name=tool_name, parameters=arguments, tool_call_id=tool_call_id, operation_id=operation_id, + task_id=task_id, session_id=session.session_id, user_id=session.user_id, ) ) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) + _background_tasks.add(bg_task) + bg_task.add_done_callback(_background_tasks.discard) + + # Associate the asyncio task with the stream registry task + await stream_registry.set_task_asyncio_task(task_id, bg_task) except Exception as e: # Roll back appended messages to prevent data corruption on subsequent saves if ( @@ -1709,6 +1299,11 @@ async def _yield_tool_call( # Release the Redis lock since the background task won't be spawned await _mark_operation_completed(tool_call_id) + # Mark stream registry task as failed if it was created + try: + await stream_registry.mark_task_completed(task_id, status="failed") + except Exception: + pass logger.error( f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True ) @@ -1722,6 +1317,7 @@ async def _yield_tool_call( message=started_msg, operation_id=operation_id, tool_name=tool_name, + task_id=task_id, # Include task_id for SSE reconnection ).model_dump_json(), success=True, ) @@ -1791,6 +1387,9 @@ async def _execute_long_running_tool( This function runs independently of the SSE connection, so the operation survives if the user closes their browser tab. + + NOTE: This is the legacy function without stream registry support. + Use _execute_long_running_tool_with_streaming for new implementations. """ try: # Load fresh session (not stale reference) @@ -1834,10 +1433,142 @@ async def _execute_long_running_tool( tool_call_id=tool_call_id, result=error_response.model_dump_json(), ) + # Generate LLM continuation so user sees explanation even for errors + try: + await _generate_llm_continuation(session_id=session_id, user_id=user_id) + except Exception as llm_err: + logger.warning(f"Failed to generate LLM continuation for error: {llm_err}") finally: await _mark_operation_completed(tool_call_id) +async def _execute_long_running_tool_with_streaming( + tool_name: str, + parameters: dict[str, Any], + tool_call_id: str, + operation_id: str, + task_id: str, + session_id: str, + user_id: str | None, +) -> None: + """Execute a long-running tool with stream registry support for SSE reconnection. + + This function runs independently of the SSE connection, publishes progress + to the stream registry, and survives if the user closes their browser tab. + Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming. + + If the external service returns a 202 Accepted (async), this function exits + early and lets the Redis Streams completion consumer handle the rest. + """ + # Track whether we delegated to async processing - if so, the Redis Streams + # completion consumer (stream_registry / completion_consumer) will handle cleanup, not us + delegated_to_async = False + + try: + # Load fresh session (not stale reference) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for background tool") + await stream_registry.mark_task_completed(task_id, status="failed") + return + + # Pass operation_id and task_id to the tool for async processing + enriched_parameters = { + **parameters, + "_operation_id": operation_id, + "_task_id": task_id, + } + + # Execute the actual tool + result = await execute_tool( + tool_name=tool_name, + parameters=enriched_parameters, + tool_call_id=tool_call_id, + user_id=user_id, + session=session, + ) + + # Check if the tool result indicates async processing + # (e.g., Agent Generator returned 202 Accepted) + try: + if isinstance(result.output, dict): + result_data = result.output + elif result.output: + result_data = orjson.loads(result.output) + else: + result_data = {} + if result_data.get("status") == "accepted": + logger.info( + f"Tool {tool_name} delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id}). " + f"Redis Streams completion consumer will handle the rest." + ) + # Don't publish result, don't continue with LLM, and don't cleanup + # The Redis Streams consumer (completion_consumer) will handle + # everything when the external service completes via webhook + delegated_to_async = True + return + except (orjson.JSONDecodeError, TypeError): + pass # Not JSON or not async - continue normally + + # Publish tool result to stream registry + await stream_registry.publish_chunk(task_id, result) + + # Update the pending message with result + result_str = ( + result.output + if isinstance(result.output, str) + else orjson.dumps(result.output).decode("utf-8") + ) + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=result_str, + ) + + logger.info( + f"Background tool {tool_name} completed for session {session_id} " + f"(task_id={task_id})" + ) + + # Generate LLM continuation and stream chunks to registry + await _generate_llm_continuation_with_streaming( + session_id=session_id, + user_id=user_id, + task_id=task_id, + ) + + # Mark task as completed in stream registry + await stream_registry.mark_task_completed(task_id, status="completed") + + except Exception as e: + logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True) + error_response = ErrorResponse( + message=f"Tool {tool_name} failed: {str(e)}", + ) + + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=str(e)), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) + + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=error_response.model_dump_json(), + ) + + # Mark task as failed in stream registry + await stream_registry.mark_task_completed(task_id, status="failed") + finally: + # Only cleanup if we didn't delegate to async processing + # For async path, the Redis Streams completion consumer handles cleanup + if not delegated_to_async: + await _mark_operation_completed(tool_call_id) + + async def _update_pending_operation( session_id: str, tool_call_id: str, @@ -1895,17 +1626,36 @@ async def _generate_llm_continuation( # Build system prompt system_prompt, _ = await _build_system_prompt(user_id) - # Build messages in OpenAI format messages = session.to_openai_messages() if system_prompt: - from openai.types.chat import ChatCompletionSystemMessageParam - system_message = ChatCompletionSystemMessageParam( role="system", content=system_prompt, ) messages = [system_message] + messages + # Apply context window management to prevent oversized requests + context_result = await _manage_context_window( + messages=messages, + model=config.model, + api_key=config.api_key, + base_url=config.base_url, + ) + + if context_result.error and "System prompt dropped" not in context_result.error: + logger.error( + f"Context window management failed for session {session_id}: " + f"{context_result.error} (tokens={context_result.token_count})" + ) + return + + messages = context_result.messages + if context_result.was_compacted: + logger.info( + f"Context compacted for LLM continuation: " + f"{context_result.token_count} tokens" + ) + # Build extra_body for tracing extra_body: dict[str, Any] = { "posthogProperties": { @@ -1918,19 +1668,54 @@ async def _generate_llm_continuation( if session_id: extra_body["session_id"] = session_id[:128] - # Make non-streaming LLM call (no tools - just text response) - from typing import cast + retry_count = 0 + last_error: Exception | None = None + response = None - from openai.types.chat import ChatCompletionMessageParam + while retry_count <= MAX_RETRIES: + try: + logger.info( + f"Generating LLM continuation for session {session_id}" + f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}" + ) - # No tools parameter = text-only response (no tool calls) - response = await client.chat.completions.create( - model=config.model, - messages=cast(list[ChatCompletionMessageParam], messages), - extra_body=extra_body, - ) + response = await client.chat.completions.create( + model=config.model, + messages=cast(list[ChatCompletionMessageParam], messages), + extra_body=extra_body, + ) + last_error = None # Clear any previous error on success + break # Success, exit retry loop + except Exception as e: + last_error = e + if _is_retryable_error(e) and retry_count < MAX_RETRIES: + retry_count += 1 + delay = min( + BASE_DELAY_SECONDS * (2 ** (retry_count - 1)), + MAX_DELAY_SECONDS, + ) + logger.warning( + f"Retryable error in LLM continuation: {e!s}. " + f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})" + ) + await asyncio.sleep(delay) + continue + else: + # Non-retryable error - log and exit gracefully + logger.error( + f"Non-retryable error in LLM continuation: {e!s}", + exc_info=True, + ) + return - if response.choices and response.choices[0].message.content: + if last_error: + logger.error( + f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. " + f"Last error: {last_error!s}" + ) + return + + if response and response.choices and response.choices[0].message.content: assistant_content = response.choices[0].message.content # Reload session from DB to avoid race condition with user messages @@ -1964,3 +1749,128 @@ async def _generate_llm_continuation( except Exception as e: logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) + + +async def _generate_llm_continuation_with_streaming( + session_id: str, + user_id: str | None, + task_id: str, +) -> None: + """Generate an LLM response with streaming to the stream registry. + + This is called by background tasks to continue the conversation + after a tool result is saved. Chunks are published to the stream registry + so reconnecting clients can receive them. + """ + import uuid as uuid_module + + try: + # Load fresh session from DB (bypass cache to get the updated tool result) + await invalidate_session_cache(session_id) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for LLM continuation") + return + + # Build system prompt + system_prompt, _ = await _build_system_prompt(user_id) + + # Build messages in OpenAI format + messages = session.to_openai_messages() + if system_prompt: + from openai.types.chat import ChatCompletionSystemMessageParam + + system_message = ChatCompletionSystemMessageParam( + role="system", + content=system_prompt, + ) + messages = [system_message] + messages + + # Build extra_body for tracing + extra_body: dict[str, Any] = { + "posthogProperties": { + "environment": settings.config.app_env.value, + }, + } + if user_id: + extra_body["user"] = user_id[:128] + extra_body["posthogDistinctId"] = user_id + if session_id: + extra_body["session_id"] = session_id[:128] + + # Make streaming LLM call (no tools - just text response) + from typing import cast + + from openai.types.chat import ChatCompletionMessageParam + + # Generate unique IDs for AI SDK protocol + message_id = str(uuid_module.uuid4()) + text_block_id = str(uuid_module.uuid4()) + + # Publish start event + await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id)) + await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id)) + + # Stream the response + stream = await client.chat.completions.create( + model=config.model, + messages=cast(list[ChatCompletionMessageParam], messages), + extra_body=extra_body, + stream=True, + ) + + assistant_content = "" + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + delta = chunk.choices[0].delta.content + assistant_content += delta + # Publish delta to stream registry + await stream_registry.publish_chunk( + task_id, + StreamTextDelta(id=text_block_id, delta=delta), + ) + + # Publish end events + await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id)) + + if assistant_content: + # Reload session from DB to avoid race condition with user messages + fresh_session = await get_chat_session(session_id, user_id) + if not fresh_session: + logger.error( + f"Session {session_id} disappeared during LLM continuation" + ) + return + + # Save assistant message to database + assistant_message = ChatMessage( + role="assistant", + content=assistant_content, + ) + fresh_session.messages.append(assistant_message) + + # Save to database (not cache) to persist the response + await upsert_chat_session(fresh_session) + + # Invalidate cache so next poll/refresh gets fresh data + await invalidate_session_cache(session_id) + + logger.info( + f"Generated streaming LLM continuation for session {session_id} " + f"(task_id={task_id}), response length: {len(assistant_content)}" + ) + else: + logger.warning( + f"Streaming LLM continuation returned empty response for {session_id}" + ) + + except Exception as e: + logger.error( + f"Failed to generate streaming LLM continuation: {e}", exc_info=True + ) + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=f"Failed to generate response: {e}"), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py new file mode 100644 index 0000000000..88a5023e2b --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -0,0 +1,704 @@ +"""Stream registry for managing reconnectable SSE streams. + +This module provides a registry for tracking active streaming tasks and their +messages. It uses Redis for all state management (no in-memory state), making +pods stateless and horizontally scalable. + +Architecture: +- Redis Stream: Persists all messages for replay and real-time delivery +- Redis Hash: Task metadata (status, session_id, etc.) + +Subscribers: +1. Replay missed messages from Redis Stream (XREAD) +2. Listen for live updates via blocking XREAD +3. No in-memory state required on the subscribing pod +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + +import orjson + +from backend.data.redis_client import get_redis_async + +from .config import ChatConfig +from .response_model import StreamBaseResponse, StreamError, StreamFinish + +logger = logging.getLogger(__name__) +config = ChatConfig() + +# Track background tasks for this pod (just the asyncio.Task reference, not subscribers) +_local_tasks: dict[str, asyncio.Task] = {} + +# Track listener tasks per subscriber queue for cleanup +# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe +_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {} + +# Timeout for putting chunks into subscriber queues (seconds) +# If the queue is full and doesn't drain within this time, send an overflow error +QUEUE_PUT_TIMEOUT = 5.0 + +# Lua script for atomic compare-and-swap status update (idempotent completion) +# Returns 1 if status was updated, 0 if already completed/failed +COMPLETE_TASK_SCRIPT = """ +local current = redis.call("HGET", KEYS[1], "status") +if current == "running" then + redis.call("HSET", KEYS[1], "status", ARGV[1]) + return 1 +end +return 0 +""" + + +@dataclass +class ActiveTask: + """Represents an active streaming task (metadata only, no in-memory queues).""" + + task_id: str + session_id: str + user_id: str | None + tool_call_id: str + tool_name: str + operation_id: str + status: Literal["running", "completed", "failed"] = "running" + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + asyncio_task: asyncio.Task | None = None + + +def _get_task_meta_key(task_id: str) -> str: + """Get Redis key for task metadata.""" + return f"{config.task_meta_prefix}{task_id}" + + +def _get_task_stream_key(task_id: str) -> str: + """Get Redis key for task message stream.""" + return f"{config.task_stream_prefix}{task_id}" + + +def _get_operation_mapping_key(operation_id: str) -> str: + """Get Redis key for operation_id to task_id mapping.""" + return f"{config.task_op_prefix}{operation_id}" + + +async def create_task( + task_id: str, + session_id: str, + user_id: str | None, + tool_call_id: str, + tool_name: str, + operation_id: str, +) -> ActiveTask: + """Create a new streaming task in Redis. + + Args: + task_id: Unique identifier for the task + session_id: Chat session ID + user_id: User ID (may be None for anonymous) + tool_call_id: Tool call ID from the LLM + tool_name: Name of the tool being executed + operation_id: Operation ID for webhook callbacks + + Returns: + The created ActiveTask instance (metadata only) + """ + task = ActiveTask( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + + # Store metadata in Redis + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + op_key = _get_operation_mapping_key(operation_id) + + await redis.hset( # type: ignore[misc] + meta_key, + mapping={ + "task_id": task_id, + "session_id": session_id, + "user_id": user_id or "", + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "operation_id": operation_id, + "status": task.status, + "created_at": task.created_at.isoformat(), + }, + ) + await redis.expire(meta_key, config.stream_ttl) + + # Create operation_id -> task_id mapping for webhook lookups + await redis.set(op_key, task_id, ex=config.stream_ttl) + + logger.debug(f"Created task {task_id} for session {session_id}") + + return task + + +async def publish_chunk( + task_id: str, + chunk: StreamBaseResponse, +) -> str: + """Publish a chunk to Redis Stream. + + All delivery is via Redis Streams - no in-memory state. + + Args: + task_id: Task ID to publish to + chunk: The stream response chunk to publish + + Returns: + The Redis Stream message ID + """ + chunk_json = chunk.model_dump_json() + message_id = "0-0" + + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + + # Write to Redis Stream for persistence and real-time delivery + raw_id = await redis.xadd( + stream_key, + {"data": chunk_json}, + maxlen=config.stream_max_length, + ) + message_id = raw_id if isinstance(raw_id, str) else raw_id.decode() + + # Set TTL on stream to match task metadata TTL + await redis.expire(stream_key, config.stream_ttl) + except Exception as e: + logger.error( + f"Failed to publish chunk for task {task_id}: {e}", + exc_info=True, + ) + + return message_id + + +async def subscribe_to_task( + task_id: str, + user_id: str | None, + last_message_id: str = "0-0", +) -> asyncio.Queue[StreamBaseResponse] | None: + """Subscribe to a task's stream with replay of missed messages. + + This is fully stateless - uses Redis Stream for replay and pub/sub for live updates. + + Args: + task_id: Task ID to subscribe to + user_id: User ID for ownership validation + last_message_id: Last Redis Stream message ID received ("0-0" for full replay) + + Returns: + An asyncio Queue that will receive stream chunks, or None if task not found + or user doesn't have access + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + logger.debug(f"Task {task_id} not found in Redis") + return None + + # Note: Redis client uses decode_responses=True, so keys are strings + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + + # Validate ownership - if task has an owner, requester must match + if task_user_id: + if user_id != task_user_id: + logger.warning( + f"User {user_id} denied access to task {task_id} " + f"owned by {task_user_id}" + ) + return None + + subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue() + stream_key = _get_task_stream_key(task_id) + + # Step 1: Replay messages from Redis Stream + messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) + + replayed_count = 0 + replay_last_id = last_message_id + if messages: + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + # Note: Redis client uses decode_responses=True, so keys are strings + if "data" in msg_data: + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + await subscriber_queue.put(chunk) + replayed_count += 1 + except Exception as e: + logger.warning(f"Failed to replay message: {e}") + + logger.debug(f"Task {task_id}: replayed {replayed_count} messages") + + # Step 2: If task is still running, start stream listener for live updates + if task_status == "running": + listener_task = asyncio.create_task( + _stream_listener(task_id, subscriber_queue, replay_last_id) + ) + # Track listener task for cleanup on unsubscribe + _listener_tasks[id(subscriber_queue)] = (task_id, listener_task) + else: + # Task is completed/failed - add finish marker + await subscriber_queue.put(StreamFinish()) + + return subscriber_queue + + +async def _stream_listener( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], + last_replayed_id: str, +) -> None: + """Listen to Redis Stream for new messages using blocking XREAD. + + This approach avoids the duplicate message issue that can occur with pub/sub + when messages are published during the gap between replay and subscription. + + Args: + task_id: Task ID to listen for + subscriber_queue: Queue to deliver messages to + last_replayed_id: Last message ID from replay (continue from here) + """ + queue_id = id(subscriber_queue) + # Track the last successfully delivered message ID for recovery hints + last_delivered_id = last_replayed_id + + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + current_id = last_replayed_id + + while True: + # Block for up to 30 seconds waiting for new messages + # This allows periodic checking if task is still running + messages = await redis.xread( + {stream_key: current_id}, block=30000, count=100 + ) + + if not messages: + # Timeout - check if task is still running + meta_key = _get_task_meta_key(task_id) + status = await redis.hget(meta_key, "status") # type: ignore[misc] + if status and status != "running": + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except asyncio.TimeoutError: + logger.warning( + f"Timeout delivering finish event for task {task_id}" + ) + break + continue + + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + current_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + + if "data" not in msg_data: + continue + + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + try: + await asyncio.wait_for( + subscriber_queue.put(chunk), + timeout=QUEUE_PUT_TIMEOUT, + ) + # Update last delivered ID on successful delivery + last_delivered_id = current_id + except asyncio.TimeoutError: + logger.warning( + f"Subscriber queue full for task {task_id}, " + f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s" + ) + # Send overflow error with recovery info + try: + overflow_error = StreamError( + errorText="Message delivery timeout - some messages may have been missed", + code="QUEUE_OVERFLOW", + details={ + "last_delivered_id": last_delivered_id, + "recovery_hint": f"Reconnect with last_message_id={last_delivered_id}", + }, + ) + subscriber_queue.put_nowait(overflow_error) + except asyncio.QueueFull: + # Queue is completely stuck, nothing more we can do + logger.error( + f"Cannot deliver overflow error for task {task_id}, " + "queue completely blocked" + ) + + # Stop listening on finish + if isinstance(chunk, StreamFinish): + return + except Exception as e: + logger.warning(f"Error processing stream message: {e}") + + except asyncio.CancelledError: + logger.debug(f"Stream listener cancelled for task {task_id}") + raise # Re-raise to propagate cancellation + except Exception as e: + logger.error(f"Stream listener error for task {task_id}: {e}") + # On error, send finish to unblock subscriber + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except (asyncio.TimeoutError, asyncio.QueueFull): + logger.warning( + f"Could not deliver finish event for task {task_id} after error" + ) + finally: + # Clean up listener task mapping on exit + _listener_tasks.pop(queue_id, None) + + +async def mark_task_completed( + task_id: str, + status: Literal["completed", "failed"] = "completed", +) -> bool: + """Mark a task as completed and publish finish event. + + This is idempotent - calling multiple times with the same task_id is safe. + Uses atomic compare-and-swap via Lua script to prevent race conditions. + Status is updated first (source of truth), then finish event is published (best-effort). + + Args: + task_id: Task ID to mark as completed + status: Final status ("completed" or "failed") + + Returns: + True if task was newly marked completed, False if already completed/failed + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + + # Atomic compare-and-swap: only update if status is "running" + # This prevents race conditions when multiple callers try to complete simultaneously + result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc] + + if result == 0: + logger.debug(f"Task {task_id} already completed/failed, skipping") + return False + + # THEN publish finish event (best-effort - listeners can detect via status polling) + try: + await publish_chunk(task_id, StreamFinish()) + except Exception as e: + logger.error( + f"Failed to publish finish event for task {task_id}: {e}. " + "Listeners will detect completion via status polling." + ) + + # Clean up local task reference if exists + _local_tasks.pop(task_id, None) + return True + + +async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None: + """Find a task by its operation ID. + + Used by webhook callbacks to locate the task to update. + + Args: + operation_id: Operation ID to search for + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + op_key = _get_operation_mapping_key(operation_id) + task_id = await redis.get(op_key) + + if not task_id: + return None + + task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id + return await get_task(task_id_str) + + +async def get_task(task_id: str) -> ActiveTask | None: + """Get a task by its ID from Redis. + + Args: + task_id: Task ID to look up + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + return None + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ) + + +async def get_task_with_expiry_info( + task_id: str, +) -> tuple[ActiveTask | None, str | None]: + """Get a task by its ID with expiration detection. + + Returns (task, error_code) where error_code is: + - None if task found + - "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired) + - "TASK_NOT_FOUND" if neither exists + + Args: + task_id: Task ID to look up + + Returns: + Tuple of (ActiveTask or None, error_code or None) + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + stream_key = _get_task_stream_key(task_id) + + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + # Check if stream still has data (metadata expired but stream hasn't) + stream_len = await redis.xlen(stream_key) + if stream_len > 0: + return None, "TASK_EXPIRED" + return None, "TASK_NOT_FOUND" + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ( + ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ), + None, + ) + + +async def get_active_task_for_session( + session_id: str, + user_id: str | None = None, +) -> tuple[ActiveTask | None, str]: + """Get the active (running) task for a session, if any. + + Scans Redis for tasks matching the session_id with status="running". + + Args: + session_id: Session ID to look up + user_id: User ID for ownership validation (optional) + + Returns: + Tuple of (ActiveTask if found and running, last_message_id from Redis Stream) + """ + + redis = await get_redis_async() + + # Scan Redis for task metadata keys + cursor = 0 + tasks_checked = 0 + + while True: + cursor, keys = await redis.scan( + cursor, match=f"{config.task_meta_prefix}*", count=100 + ) + + for key in keys: + tasks_checked += 1 + meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc] + if not meta: + continue + + # Note: Redis client uses decode_responses=True, so keys/values are strings + task_session_id = meta.get("session_id", "") + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + task_id = meta.get("task_id", "") + + if task_session_id == session_id and task_status == "running": + # Validate ownership - if task has an owner, requester must match + if task_user_id and user_id != task_user_id: + continue + + # Get the last message ID from Redis Stream + stream_key = _get_task_stream_key(task_id) + last_id = "0-0" + try: + messages = await redis.xrevrange(stream_key, count=1) + if messages: + msg_id = messages[0][0] + last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + except Exception as e: + logger.warning(f"Failed to get last message ID: {e}") + + return ( + ActiveTask( + task_id=task_id, + session_id=task_session_id, + user_id=task_user_id, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status="running", + ), + last_id, + ) + + if cursor == 0: + break + + return None, "0-0" + + +def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None: + """Reconstruct a StreamBaseResponse from JSON data. + + Args: + chunk_data: Parsed JSON data from Redis + + Returns: + Reconstructed response object, or None if unknown type + """ + from .response_model import ( + ResponseType, + StreamError, + StreamFinish, + StreamHeartbeat, + StreamStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, + StreamUsage, + ) + + # Map response types to their corresponding classes + type_to_class: dict[str, type[StreamBaseResponse]] = { + ResponseType.START.value: StreamStart, + ResponseType.FINISH.value: StreamFinish, + ResponseType.TEXT_START.value: StreamTextStart, + ResponseType.TEXT_DELTA.value: StreamTextDelta, + ResponseType.TEXT_END.value: StreamTextEnd, + ResponseType.TOOL_INPUT_START.value: StreamToolInputStart, + ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable, + ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable, + ResponseType.ERROR.value: StreamError, + ResponseType.USAGE.value: StreamUsage, + ResponseType.HEARTBEAT.value: StreamHeartbeat, + } + + chunk_type = chunk_data.get("type") + chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type] + + if chunk_class is None: + logger.warning(f"Unknown chunk type: {chunk_type}") + return None + + try: + return chunk_class(**chunk_data) + except Exception as e: + logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}") + return None + + +async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None: + """Track the asyncio.Task for a task (local reference only). + + This is just for cleanup purposes - the task state is in Redis. + + Args: + task_id: Task ID + asyncio_task: The asyncio Task to track + """ + _local_tasks[task_id] = asyncio_task + + +async def unsubscribe_from_task( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], +) -> None: + """Clean up when a subscriber disconnects. + + Cancels the XREAD-based listener task associated with this subscriber queue + to prevent resource leaks. + + Args: + task_id: Task ID + subscriber_queue: The subscriber's queue used to look up the listener task + """ + queue_id = id(subscriber_queue) + listener_entry = _listener_tasks.pop(queue_id, None) + + if listener_entry is None: + logger.debug( + f"No listener task found for task {task_id} queue {queue_id} " + "(may have already completed)" + ) + return + + stored_task_id, listener_task = listener_entry + + if stored_task_id != task_id: + logger.warning( + f"Task ID mismatch in unsubscribe: expected {task_id}, " + f"found {stored_task_id}" + ) + + if listener_task.done(): + logger.debug(f"Listener task for task {task_id} already completed") + return + + # Cancel the listener task + listener_task.cancel() + + try: + # Wait for the task to be cancelled with a timeout + await asyncio.wait_for(listener_task, timeout=5.0) + except asyncio.CancelledError: + # Expected - the task was successfully cancelled + pass + except asyncio.TimeoutError: + logger.warning( + f"Timeout waiting for listener task cancellation for task {task_id}" + ) + except Exception as e: + logger.error(f"Error during listener task cancellation for task {task_id}: {e}") + + logger.debug(f"Successfully unsubscribed from task {task_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/IDEAS.md b/autogpt_platform/backend/backend/api/features/chat/tools/IDEAS.md new file mode 100644 index 0000000000..656aac61c4 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/IDEAS.md @@ -0,0 +1,79 @@ +# CoPilot Tools - Future Ideas + +## Multimodal Image Support for CoPilot + +**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality). + +**Backend Solution:** +When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks: + +```python +# Before sending to LLM, scan for workspace image references +# and inject them as image content parts + +# Example message transformation: +# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"} +# TO: {"role": "assistant", "content": [ +# {"type": "text", "text": "Generated image: workspace://abc123"}, +# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}} +# ]} +``` + +**Where to implement:** +- In the chat stream handler before calling the LLM +- Or in a message preprocessing step +- Need to fetch image from workspace, convert to base64, add as image content + +**Considerations:** +- Only do this for image MIME types (image/png, image/jpeg, etc.) +- May want a size limit (don't pass 10MB images) +- Track which images were "shown" to the AI for frontend indicator +- Cost implications - vision API calls are more expensive + +**Frontend Solution:** +Show visual indicator on workspace files in chat: +- If AI saw the image: normal display +- If AI didn't see it: overlay icon saying "AI can't see this image" + +Requires response metadata indicating which `workspace://` refs were passed to the model. + +--- + +## Output Post-Processing Layer for run_block + +**Problem:** Many blocks produce large outputs that: +- Consume massive context (100KB base64 image = ~133KB tokens) +- Can't fit in conversation +- Break things and cause high LLM costs + +**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot. + +**Benefits:** +1. **Centralized** - one place to handle all output processing +2. **Future-proof** - new blocks automatically get output processing +3. **Keeps blocks pure** - they don't need to know about context constraints +4. **Handles all large outputs** - not just images + +**Processing Rules:** +- Detect base64 data URIs → save to workspace, return `workspace://` reference +- Truncate very long strings (>N chars) with truncation note +- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]") +- Handle nested large outputs in dicts recursively +- Cap total output size + +**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse` + +**Example:** +```python +def _process_outputs_for_context( + outputs: dict[str, list[Any]], + workspace_manager: WorkspaceManager, + max_string_length: int = 10000, + max_array_preview: int = 5, +) -> dict[str, list[Any]]: + """Process block outputs to prevent context bloat.""" + processed = {} + for name, values in outputs.items(): + processed[name] = [_process_value(v, workspace_manager) for v in values] + return processed +``` diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py index beeb128ae9..dcbc35ef37 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py @@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool from .agent_output import AgentOutputTool from .base import BaseTool from .create_agent import CreateAgentTool +from .customize_agent import CustomizeAgentTool from .edit_agent import EditAgentTool from .find_agent import FindAgentTool from .find_block import FindBlockTool @@ -18,6 +19,12 @@ from .get_doc_page import GetDocPageTool from .run_agent import RunAgentTool from .run_block import RunBlockTool from .search_docs import SearchDocsTool +from .workspace_files import ( + DeleteWorkspaceFileTool, + ListWorkspaceFilesTool, + ReadWorkspaceFileTool, + WriteWorkspaceFileTool, +) if TYPE_CHECKING: from backend.api.features.chat.response_model import StreamToolOutputAvailable @@ -28,6 +35,7 @@ logger = logging.getLogger(__name__) TOOL_REGISTRY: dict[str, BaseTool] = { "add_understanding": AddUnderstandingTool(), "create_agent": CreateAgentTool(), + "customize_agent": CustomizeAgentTool(), "edit_agent": EditAgentTool(), "find_agent": FindAgentTool(), "find_block": FindBlockTool(), @@ -37,6 +45,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "view_agent_output": AgentOutputTool(), "search_docs": SearchDocsTool(), "get_doc_page": GetDocPageTool(), + # Workspace tools for CoPilot file operations + "list_workspace_files": ListWorkspaceFilesTool(), + "read_workspace_file": ReadWorkspaceFileTool(), + "write_workspace_file": WriteWorkspaceFileTool(), + "delete_workspace_file": DeleteWorkspaceFileTool(), } # Export individual tool instances for backwards compatibility diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py index 392f642c41..4266834220 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py @@ -2,27 +2,58 @@ from .core import ( AgentGeneratorNotConfiguredError, + AgentJsonValidationError, + AgentSummary, + DecompositionResult, + DecompositionStep, + LibraryAgentSummary, + MarketplaceAgentSummary, + customize_template, decompose_goal, + enrich_library_agents_from_steps, + extract_search_terms_from_steps, + extract_uuids_from_text, generate_agent, generate_agent_patch, get_agent_as_json, + get_all_relevant_agents_for_generation, + get_library_agent_by_graph_id, + get_library_agent_by_id, + get_library_agents_for_generation, + graph_to_json, json_to_graph, save_agent_to_library, + search_marketplace_agents_for_generation, ) +from .errors import get_user_message_for_error from .service import health_check as check_external_service_health from .service import is_external_service_configured __all__ = [ - # Core functions + "AgentGeneratorNotConfiguredError", + "AgentJsonValidationError", + "AgentSummary", + "DecompositionResult", + "DecompositionStep", + "LibraryAgentSummary", + "MarketplaceAgentSummary", + "check_external_service_health", + "customize_template", "decompose_goal", + "enrich_library_agents_from_steps", + "extract_search_terms_from_steps", + "extract_uuids_from_text", "generate_agent", "generate_agent_patch", - "save_agent_to_library", "get_agent_as_json", - "json_to_graph", - # Exceptions - "AgentGeneratorNotConfiguredError", - # Service + "get_all_relevant_agents_for_generation", + "get_library_agent_by_graph_id", + "get_library_agent_by_id", + "get_library_agents_for_generation", + "get_user_message_for_error", + "graph_to_json", "is_external_service_configured", - "check_external_service_health", + "json_to_graph", + "save_agent_to_library", + "search_marketplace_agents_for_generation", ] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py index fc15587110..b88b9b2924 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py @@ -1,13 +1,25 @@ """Core agent generation functions.""" import logging +import re import uuid -from typing import Any +from typing import Any, NotRequired, TypedDict from backend.api.features.library import db as library_db -from backend.data.graph import Graph, Link, Node, create_graph +from backend.api.features.store import db as store_db +from backend.data.graph import ( + Graph, + Link, + Node, + create_graph, + get_graph, + get_graph_all_versions, + get_store_listed_graphs, +) +from backend.util.exceptions import DatabaseError, NotFoundError from .service import ( + customize_template_external, decompose_goal_external, generate_agent_external, generate_agent_patch_external, @@ -16,6 +28,74 @@ from .service import ( logger = logging.getLogger(__name__) +AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565" + + +class ExecutionSummary(TypedDict): + """Summary of a single execution for quality assessment.""" + + status: str + correctness_score: NotRequired[float] + activity_summary: NotRequired[str] + + +class LibraryAgentSummary(TypedDict): + """Summary of a library agent for sub-agent composition. + + Includes recent executions to help the LLM decide whether to use this agent. + Each execution shows status, correctness_score (0-1), and activity_summary. + """ + + graph_id: str + graph_version: int + name: str + description: str + input_schema: dict[str, Any] + output_schema: dict[str, Any] + recent_executions: NotRequired[list[ExecutionSummary]] + + +class MarketplaceAgentSummary(TypedDict): + """Summary of a marketplace agent for sub-agent composition.""" + + name: str + description: str + sub_heading: str + creator: str + is_marketplace_agent: bool + + +class DecompositionStep(TypedDict, total=False): + """A single step in decomposed instructions.""" + + description: str + action: str + block_name: str + tool: str + name: str + + +class DecompositionResult(TypedDict, total=False): + """Result from decompose_goal - can be instructions, questions, or error.""" + + type: str + steps: list[DecompositionStep] + questions: list[dict[str, Any]] + error: str + error_type: str + + +AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any] + + +def _to_dict_list( + agents: list[AgentSummary] | list[dict[str, Any]] | None, +) -> list[dict[str, Any]] | None: + """Convert typed agent summaries to plain dicts for external service calls.""" + if agents is None: + return None + return [dict(a) for a in agents] + class AgentGeneratorNotConfiguredError(Exception): """Raised when the external Agent Generator service is not configured.""" @@ -36,15 +116,422 @@ def _check_service_configured() -> None: ) -async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None: +_UUID_PATTERN = re.compile( + r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}", + re.IGNORECASE, +) + + +def extract_uuids_from_text(text: str) -> list[str]: + """Extract all UUID v4 strings from text. + + Args: + text: Text that may contain UUIDs (e.g., user's goal description) + + Returns: + List of unique UUIDs found in the text (lowercase) + """ + matches = _UUID_PATTERN.findall(text) + return list({m.lower() for m in matches}) + + +async def get_library_agent_by_id( + user_id: str, agent_id: str +) -> LibraryAgentSummary | None: + """Fetch a specific library agent by its ID (library agent ID or graph_id). + + This function tries multiple lookup strategies: + 1. First tries to find by graph_id (AgentGraph primary key) + 2. If not found, tries to find by library agent ID (LibraryAgent primary key) + + This handles both cases: + - User provides graph_id (e.g., from AgentExecutorBlock) + - User provides library agent ID (e.g., from library URL) + + Args: + user_id: The user ID + agent_id: The ID to look up (can be graph_id or library agent ID) + + Returns: + LibraryAgentSummary if found, None otherwise + """ + try: + agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id) + if agent: + logger.debug(f"Found library agent by graph_id: {agent.name}") + return LibraryAgentSummary( + graph_id=agent.graph_id, + graph_version=agent.graph_version, + name=agent.name, + description=agent.description, + input_schema=agent.input_schema, + output_schema=agent.output_schema, + ) + except DatabaseError: + raise + except Exception as e: + logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}") + + try: + agent = await library_db.get_library_agent(agent_id, user_id) + if agent: + logger.debug(f"Found library agent by library_id: {agent.name}") + return LibraryAgentSummary( + graph_id=agent.graph_id, + graph_version=agent.graph_version, + name=agent.name, + description=agent.description, + input_schema=agent.input_schema, + output_schema=agent.output_schema, + ) + except NotFoundError: + logger.debug(f"Library agent not found by library_id: {agent_id}") + except DatabaseError: + raise + except Exception as e: + logger.warning( + f"Could not fetch library agent by library_id {agent_id}: {e}", + exc_info=True, + ) + + return None + + +get_library_agent_by_graph_id = get_library_agent_by_id + + +async def get_library_agents_for_generation( + user_id: str, + search_query: str | None = None, + exclude_graph_id: str | None = None, + max_results: int = 15, +) -> list[LibraryAgentSummary]: + """Fetch user's library agents formatted for Agent Generator. + + Uses search-based fetching to return relevant agents instead of all agents. + This is more scalable for users with large libraries. + + Includes recent_executions list to help the LLM assess agent quality: + - Each execution has status, correctness_score (0-1), and activity_summary + - This gives the LLM concrete examples of recent performance + + Args: + user_id: The user ID + search_query: Optional search term to find relevant agents (user's goal/description) + exclude_graph_id: Optional graph ID to exclude (prevents circular references) + max_results: Maximum number of agents to return (default 15) + + Returns: + List of LibraryAgentSummary with schemas and recent executions for sub-agent composition + """ + try: + response = await library_db.list_library_agents( + user_id=user_id, + search_term=search_query, + page=1, + page_size=max_results, + include_executions=True, + ) + + results: list[LibraryAgentSummary] = [] + for agent in response.agents: + if exclude_graph_id is not None and agent.graph_id == exclude_graph_id: + continue + + summary = LibraryAgentSummary( + graph_id=agent.graph_id, + graph_version=agent.graph_version, + name=agent.name, + description=agent.description, + input_schema=agent.input_schema, + output_schema=agent.output_schema, + ) + if agent.recent_executions: + exec_summaries: list[ExecutionSummary] = [] + for ex in agent.recent_executions: + exec_sum = ExecutionSummary(status=ex.status) + if ex.correctness_score is not None: + exec_sum["correctness_score"] = ex.correctness_score + if ex.activity_summary: + exec_sum["activity_summary"] = ex.activity_summary + exec_summaries.append(exec_sum) + summary["recent_executions"] = exec_summaries + results.append(summary) + return results + except DatabaseError: + raise + except Exception as e: + logger.warning(f"Failed to fetch library agents: {e}") + return [] + + +async def search_marketplace_agents_for_generation( + search_query: str, + max_results: int = 10, +) -> list[LibraryAgentSummary]: + """Search marketplace agents formatted for Agent Generator. + + Fetches marketplace agents and their full schemas so they can be used + as sub-agents in generated workflows. + + Args: + search_query: Search term to find relevant public agents + max_results: Maximum number of agents to return (default 10) + + Returns: + List of LibraryAgentSummary with full input/output schemas + """ + try: + response = await store_db.get_store_agents( + search_query=search_query, + page=1, + page_size=max_results, + ) + + agents_with_graphs = [ + agent for agent in response.agents if agent.agent_graph_id + ] + + if not agents_with_graphs: + return [] + + graph_ids = [agent.agent_graph_id for agent in agents_with_graphs] + graphs = await get_store_listed_graphs(*graph_ids) + + results: list[LibraryAgentSummary] = [] + for agent in agents_with_graphs: + graph_id = agent.agent_graph_id + if graph_id and graph_id in graphs: + graph = graphs[graph_id] + results.append( + LibraryAgentSummary( + graph_id=graph.id, + graph_version=graph.version, + name=agent.agent_name, + description=agent.description, + input_schema=graph.input_schema, + output_schema=graph.output_schema, + ) + ) + return results + except Exception as e: + logger.warning(f"Failed to search marketplace agents: {e}") + return [] + + +async def get_all_relevant_agents_for_generation( + user_id: str, + search_query: str | None = None, + exclude_graph_id: str | None = None, + include_library: bool = True, + include_marketplace: bool = True, + max_library_results: int = 15, + max_marketplace_results: int = 10, +) -> list[AgentSummary]: + """Fetch relevant agents from library and/or marketplace. + + Searches both user's library and marketplace by default. + Explicitly mentioned UUIDs in the search query are always looked up. + + Args: + user_id: The user ID + search_query: Search term to find relevant agents (user's goal/description) + exclude_graph_id: Optional graph ID to exclude (prevents circular references) + include_library: Whether to search user's library (default True) + include_marketplace: Whether to also search marketplace (default True) + max_library_results: Max library agents to return (default 15) + max_marketplace_results: Max marketplace agents to return (default 10) + + Returns: + List of AgentSummary with full schemas (both library and marketplace agents) + """ + agents: list[AgentSummary] = [] + seen_graph_ids: set[str] = set() + + if search_query: + mentioned_uuids = extract_uuids_from_text(search_query) + for graph_id in mentioned_uuids: + if graph_id == exclude_graph_id: + continue + agent = await get_library_agent_by_graph_id(user_id, graph_id) + agent_graph_id = agent.get("graph_id") if agent else None + if agent and agent_graph_id and agent_graph_id not in seen_graph_ids: + agents.append(agent) + seen_graph_ids.add(agent_graph_id) + logger.debug( + f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}" + ) + + if include_library: + library_agents = await get_library_agents_for_generation( + user_id=user_id, + search_query=search_query, + exclude_graph_id=exclude_graph_id, + max_results=max_library_results, + ) + for agent in library_agents: + graph_id = agent.get("graph_id") + if graph_id and graph_id not in seen_graph_ids: + agents.append(agent) + seen_graph_ids.add(graph_id) + + if include_marketplace and search_query: + marketplace_agents = await search_marketplace_agents_for_generation( + search_query=search_query, + max_results=max_marketplace_results, + ) + for agent in marketplace_agents: + graph_id = agent.get("graph_id") + if graph_id and graph_id not in seen_graph_ids: + agents.append(agent) + seen_graph_ids.add(graph_id) + + return agents + + +def extract_search_terms_from_steps( + decomposition_result: DecompositionResult | dict[str, Any], +) -> list[str]: + """Extract search terms from decomposed instruction steps. + + Analyzes the decomposition result to extract relevant keywords + for additional library agent searches. + + Args: + decomposition_result: Result from decompose_goal containing steps + + Returns: + List of unique search terms extracted from steps + """ + search_terms: list[str] = [] + + if decomposition_result.get("type") != "instructions": + return search_terms + + steps = decomposition_result.get("steps", []) + if not steps: + return search_terms + + step_keys: list[str] = ["description", "action", "block_name", "tool", "name"] + + for step in steps: + for key in step_keys: + value = step.get(key) # type: ignore[union-attr] + if isinstance(value, str) and len(value) > 3: + search_terms.append(value) + + seen: set[str] = set() + unique_terms: list[str] = [] + for term in search_terms: + term_lower = term.lower() + if term_lower not in seen: + seen.add(term_lower) + unique_terms.append(term) + + return unique_terms + + +async def enrich_library_agents_from_steps( + user_id: str, + decomposition_result: DecompositionResult | dict[str, Any], + existing_agents: list[AgentSummary] | list[dict[str, Any]], + exclude_graph_id: str | None = None, + include_marketplace: bool = True, + max_additional_results: int = 10, +) -> list[AgentSummary] | list[dict[str, Any]]: + """Enrich library agents list with additional searches based on decomposed steps. + + This implements two-phase search: after decomposition, we search for additional + relevant agents based on the specific steps identified. + + Args: + user_id: The user ID + decomposition_result: Result from decompose_goal containing steps + existing_agents: Already fetched library agents from initial search + exclude_graph_id: Optional graph ID to exclude + include_marketplace: Whether to also search marketplace + max_additional_results: Max additional agents per search term (default 10) + + Returns: + Combined list of library agents (existing + newly discovered) + """ + search_terms = extract_search_terms_from_steps(decomposition_result) + + if not search_terms: + return existing_agents + + existing_ids: set[str] = set() + existing_names: set[str] = set() + + for agent in existing_agents: + agent_name = agent.get("name") + if agent_name and isinstance(agent_name, str): + existing_names.add(agent_name.lower()) + graph_id = agent.get("graph_id") # type: ignore[call-overload] + if graph_id and isinstance(graph_id, str): + existing_ids.add(graph_id) + + all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents) + + for term in search_terms[:3]: + try: + additional_agents = await get_all_relevant_agents_for_generation( + user_id=user_id, + search_query=term, + exclude_graph_id=exclude_graph_id, + include_marketplace=include_marketplace, + max_library_results=max_additional_results, + max_marketplace_results=5, + ) + + for agent in additional_agents: + agent_name = agent.get("name") + if not agent_name or not isinstance(agent_name, str): + continue + agent_name_lower = agent_name.lower() + + if agent_name_lower in existing_names: + continue + + graph_id = agent.get("graph_id") # type: ignore[call-overload] + if graph_id and graph_id in existing_ids: + continue + + all_agents.append(agent) + existing_names.add(agent_name_lower) + if graph_id and isinstance(graph_id, str): + existing_ids.add(graph_id) + + except DatabaseError: + logger.error(f"Database error searching for agents with term '{term}'") + raise + except Exception as e: + logger.warning( + f"Failed to search for additional agents with term '{term}': {e}" + ) + + logger.debug( + f"Enriched library agents: {len(existing_agents)} initial + " + f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total" + ) + + return all_agents + + +async def decompose_goal( + description: str, + context: str = "", + library_agents: list[AgentSummary] | None = None, +) -> DecompositionResult | None: """Break down a goal into steps or return clarifying questions. Args: description: Natural language goal description context: Additional context (e.g., answers to previous questions) + library_agents: User's library agents available for sub-agent composition Returns: - Dict with either: + DecompositionResult with either: - {"type": "clarifying_questions", "questions": [...]} - {"type": "instructions", "steps": [...]} Or None on error @@ -54,26 +541,47 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any] """ _check_service_configured() logger.info("Calling external Agent Generator service for decompose_goal") - return await decompose_goal_external(description, context) + result = await decompose_goal_external( + description, context, _to_dict_list(library_agents) + ) + return result # type: ignore[return-value] -async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None: +async def generate_agent( + instructions: DecompositionResult | dict[str, Any], + library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, +) -> dict[str, Any] | None: """Generate agent JSON from instructions. Args: instructions: Structured instructions from decompose_goal + library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams + completion notification) + task_id: Task ID for async processing (enables Redis Streams persistence + and SSE delivery) Returns: - Agent JSON dict or None on error + Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. """ _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent") - result = await generate_agent_external(instructions) + result = await generate_agent_external( + dict(instructions), _to_dict_list(library_agents), operation_id, task_id + ) + + # Don't modify async response + if result and result.get("status") == "accepted": + return result + if result: - # Ensure required fields + if isinstance(result, dict) and result.get("type") == "error": + return result if "id" not in result: result["id"] = str(uuid.uuid4()) if "version" not in result: @@ -83,6 +591,12 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None: return result +class AgentJsonValidationError(Exception): + """Raised when agent JSON is invalid or missing required fields.""" + + pass + + def json_to_graph(agent_json: dict[str, Any]) -> Graph: """Convert agent JSON dict to Graph model. @@ -91,25 +605,55 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph: Returns: Graph ready for saving + + Raises: + AgentJsonValidationError: If required fields are missing from nodes or links """ nodes = [] - for n in agent_json.get("nodes", []): + for idx, n in enumerate(agent_json.get("nodes", [])): + block_id = n.get("block_id") + if not block_id: + node_id = n.get("id", f"index_{idx}") + raise AgentJsonValidationError( + f"Node '{node_id}' is missing required field 'block_id'" + ) node = Node( id=n.get("id", str(uuid.uuid4())), - block_id=n["block_id"], + block_id=block_id, input_default=n.get("input_default", {}), metadata=n.get("metadata", {}), ) nodes.append(node) links = [] - for link_data in agent_json.get("links", []): + for idx, link_data in enumerate(agent_json.get("links", [])): + source_id = link_data.get("source_id") + sink_id = link_data.get("sink_id") + source_name = link_data.get("source_name") + sink_name = link_data.get("sink_name") + + missing_fields = [] + if not source_id: + missing_fields.append("source_id") + if not sink_id: + missing_fields.append("sink_id") + if not source_name: + missing_fields.append("source_name") + if not sink_name: + missing_fields.append("sink_name") + + if missing_fields: + link_id = link_data.get("id", f"index_{idx}") + raise AgentJsonValidationError( + f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}" + ) + link = Link( id=link_data.get("id", str(uuid.uuid4())), - source_id=link_data["source_id"], - sink_id=link_data["sink_id"], - source_name=link_data["source_name"], - sink_name=link_data["sink_name"], + source_id=source_id, + sink_id=sink_id, + source_name=source_name, + sink_name=sink_name, is_static=link_data.get("is_static", False), ) links.append(link) @@ -130,22 +674,40 @@ def _reassign_node_ids(graph: Graph) -> None: This is needed when creating a new version to avoid unique constraint violations. """ - # Create mapping from old node IDs to new UUIDs id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes} - # Reassign node IDs for node in graph.nodes: node.id = id_map[node.id] - # Update link references to use new node IDs for link in graph.links: - link.id = str(uuid.uuid4()) # Also give links new IDs + link.id = str(uuid.uuid4()) if link.source_id in id_map: link.source_id = id_map[link.source_id] if link.sink_id in id_map: link.sink_id = id_map[link.sink_id] +def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None: + """Populate user_id in AgentExecutorBlock nodes. + + The external agent generator creates AgentExecutorBlock nodes with empty user_id. + This function fills in the actual user_id so sub-agents run with correct permissions. + + Args: + agent_json: Agent JSON dict (modified in place) + user_id: User ID to set + """ + for node in agent_json.get("nodes", []): + if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID: + input_default = node.get("input_default") or {} + if not input_default.get("user_id"): + input_default["user_id"] = user_id + node["input_default"] = input_default + logger.debug( + f"Set user_id for AgentExecutorBlock node {node.get('id')}" + ) + + async def save_agent_to_library( agent_json: dict[str, Any], user_id: str, is_update: bool = False ) -> tuple[Graph, Any]: @@ -159,33 +721,27 @@ async def save_agent_to_library( Returns: Tuple of (created Graph, LibraryAgent) """ - from backend.data.graph import get_graph_all_versions + # Populate user_id in AgentExecutorBlock nodes before conversion + _populate_agent_executor_user_ids(agent_json, user_id) graph = json_to_graph(agent_json) if is_update: - # For updates, keep the same graph ID but increment version - # and reassign node/link IDs to avoid conflicts if graph.id: existing_versions = await get_graph_all_versions(graph.id, user_id) if existing_versions: latest_version = max(v.version for v in existing_versions) graph.version = latest_version + 1 - # Reassign node IDs (but keep graph ID the same) _reassign_node_ids(graph) logger.info(f"Updating agent {graph.id} to version {graph.version}") else: - # For new agents, always generate a fresh UUID to avoid collisions graph.id = str(uuid.uuid4()) graph.version = 1 - # Reassign all node IDs as well _reassign_node_ids(graph) logger.info(f"Creating new agent with ID {graph.id}") - # Save to database created_graph = await create_graph(graph, user_id) - # Add to user's library (or update existing library agent) library_agents = await library_db.create_library_agent( graph=created_graph, user_id=user_id, @@ -196,26 +752,15 @@ async def save_agent_to_library( return created_graph, library_agents[0] -async def get_agent_as_json( - graph_id: str, user_id: str | None -) -> dict[str, Any] | None: - """Fetch an agent and convert to JSON format for editing. +def graph_to_json(graph: Graph) -> dict[str, Any]: + """Convert a Graph object to JSON format for the agent generator. Args: - graph_id: Graph ID or library agent ID - user_id: User ID + graph: Graph object to convert Returns: - Agent as JSON dict or None if not found + Agent as JSON dict """ - from backend.data.graph import get_graph - - # Try to get the graph (version=None gets the active version) - graph = await get_graph(graph_id, version=None, user_id=user_id) - if not graph: - return None - - # Convert to JSON format nodes = [] for node in graph.nodes: nodes.append( @@ -252,8 +797,41 @@ async def get_agent_as_json( } +async def get_agent_as_json( + agent_id: str, user_id: str | None +) -> dict[str, Any] | None: + """Fetch an agent and convert to JSON format for editing. + + Args: + agent_id: Graph ID or library agent ID + user_id: User ID + + Returns: + Agent as JSON dict or None if not found + """ + graph = await get_graph(agent_id, version=None, user_id=user_id) + + if not graph and user_id: + try: + library_agent = await library_db.get_library_agent(agent_id, user_id) + graph = await get_graph( + library_agent.graph_id, version=None, user_id=user_id + ) + except NotFoundError: + pass + + if not graph: + return None + + return graph_to_json(graph) + + async def generate_agent_patch( - update_request: str, current_agent: dict[str, Any] + update_request: str, + current_agent: dict[str, Any], + library_agents: list[AgentSummary] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Update an existing agent using natural language. @@ -265,13 +843,57 @@ async def generate_agent_patch( Args: update_request: Natural language description of changes current_agent: Current agent JSON + library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Updated agent JSON, clarifying questions dict, or None on error + Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, + {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. """ _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent_patch") - return await generate_agent_patch_external(update_request, current_agent) + return await generate_agent_patch_external( + update_request, + current_agent, + _to_dict_list(library_agents), + operation_id, + task_id, + ) + + +async def customize_template( + template_agent: dict[str, Any], + modification_request: str, + context: str = "", +) -> dict[str, Any] | None: + """Customize a template/marketplace agent using natural language. + + This is used when users want to modify a template or marketplace agent + to fit their specific needs before adding it to their library. + + The external Agent Generator service handles: + - Understanding the modification request + - Applying changes to the template + - Fixing and validating the result + + Args: + template_agent: The template agent JSON to customize + modification_request: Natural language description of customizations + context: Additional context (e.g., answers to previous questions) + + Returns: + Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, + error dict {"type": "error", ...}, or None on unexpected error + + Raises: + AgentGeneratorNotConfiguredError: If the external service is not configured. + """ + _check_service_configured() + logger.info("Calling external Agent Generator service for customize_template") + return await customize_template_external( + template_agent, modification_request, context + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py new file mode 100644 index 0000000000..282d8cf9aa --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py @@ -0,0 +1,95 @@ +"""Error handling utilities for agent generator.""" + +import re + + +def _sanitize_error_details(details: str) -> str: + """Sanitize error details to remove sensitive information. + + Strips common patterns that could expose internal system info: + - File paths (Unix and Windows) + - Database connection strings + - URLs with credentials + - Stack trace internals + + Args: + details: Raw error details string + + Returns: + Sanitized error details safe for user display + """ + sanitized = re.sub( + r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details + ) + sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized) + sanitized = re.sub( + r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized + ) + sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized) + sanitized = re.sub(r", line \d+", "", sanitized) + sanitized = re.sub(r'File "[^"]+",?', "", sanitized) + + return sanitized.strip() + + +def get_user_message_for_error( + error_type: str, + operation: str = "process the request", + llm_parse_message: str | None = None, + validation_message: str | None = None, + error_details: str | None = None, +) -> str: + """Get a user-friendly error message based on error type. + + This function maps internal error types to user-friendly messages, + providing a consistent experience across different agent operations. + + Args: + error_type: The error type from the external service + (e.g., "llm_parse_error", "timeout", "rate_limit") + operation: Description of what operation failed, used in the default + message (e.g., "analyze the goal", "generate the agent") + llm_parse_message: Custom message for llm_parse_error type + validation_message: Custom message for validation_error type + error_details: Optional additional details about the error + + Returns: + User-friendly error message suitable for display to the user + """ + base_message = "" + + if error_type == "llm_parse_error": + base_message = ( + llm_parse_message + or "The AI had trouble processing this request. Please try again." + ) + elif error_type == "validation_error": + base_message = ( + validation_message + or "The generated agent failed validation. " + "This usually happens when the agent structure doesn't match " + "what the platform expects. Please try simplifying your goal " + "or breaking it into smaller parts." + ) + elif error_type == "patch_error": + base_message = ( + "Failed to apply the changes. The modification couldn't be " + "validated. Please try a different approach or simplify the change." + ) + elif error_type in ("timeout", "llm_timeout"): + base_message = ( + "The request took too long to process. This can happen with " + "complex agents. Please try again or simplify your goal." + ) + elif error_type in ("rate_limit", "llm_rate_limit"): + base_message = "The service is currently busy. Please try again in a moment." + else: + base_message = f"Failed to {operation}. Please try again." + + if error_details: + details = _sanitize_error_details(error_details) + if len(details) > 200: + details = details[:200] + "..." + base_message += f"\n\nTechnical details: {details}" + + return base_message diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py index a4d2f1af15..62411b4e1b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py @@ -14,6 +14,70 @@ from backend.util.settings import Settings logger = logging.getLogger(__name__) + +def _create_error_response( + error_message: str, + error_type: str = "unknown", + details: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Create a standardized error response dict. + + Args: + error_message: Human-readable error message + error_type: Machine-readable error type + details: Optional additional error details + + Returns: + Error dict with type="error" and error details + """ + response: dict[str, Any] = { + "type": "error", + "error": error_message, + "error_type": error_type, + } + if details: + response["details"] = details + return response + + +def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]: + """Classify an HTTP error into error_type and message. + + Args: + e: The HTTP status error + + Returns: + Tuple of (error_type, error_message) + """ + status = e.response.status_code + if status == 429: + return "rate_limit", f"Agent Generator rate limited: {e}" + elif status == 503: + return "service_unavailable", f"Agent Generator unavailable: {e}" + elif status == 504 or status == 408: + return "timeout", f"Agent Generator timed out: {e}" + else: + return "http_error", f"HTTP error calling Agent Generator: {e}" + + +def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]: + """Classify a request error into error_type and message. + + Args: + e: The request error + + Returns: + Tuple of (error_type, error_message) + """ + error_str = str(e).lower() + if "timeout" in error_str or "timed out" in error_str: + return "timeout", f"Agent Generator request timed out: {e}" + elif "connect" in error_str: + return "connection_error", f"Could not connect to Agent Generator: {e}" + else: + return "request_error", f"Request error calling Agent Generator: {e}" + + _client: httpx.AsyncClient | None = None _settings: Settings | None = None @@ -53,13 +117,16 @@ def _get_client() -> httpx.AsyncClient: async def decompose_goal_external( - description: str, context: str = "" + description: str, + context: str = "", + library_agents: list[dict[str, Any]] | None = None, ) -> dict[str, Any] | None: """Call the external service to decompose a goal. Args: description: Natural language goal description context: Additional context (e.g., answers to previous questions) + library_agents: User's library agents available for sub-agent composition Returns: Dict with either: @@ -67,15 +134,17 @@ async def decompose_goal_external( - {"type": "instructions", "steps": [...]} - {"type": "unachievable_goal", ...} - {"type": "vague_goal", ...} - Or None on error + - {"type": "error", "error": "...", "error_type": "..."} on error + Or None on unexpected error """ client = _get_client() - # Build the request payload - payload: dict[str, Any] = {"description": description} if context: - # The external service uses user_instruction for additional context - payload["user_instruction"] = context + description = f"{description}\n\nAdditional context from user:\n{context}" + + payload: dict[str, Any] = {"description": description} + if library_agents: + payload["library_agents"] = library_agents try: response = await client.post("/api/decompose-description", json=payload) @@ -83,8 +152,13 @@ async def decompose_goal_external( data = response.json() if not data.get("success"): - logger.error(f"External service returned error: {data.get('error')}") - return None + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator decomposition failed: {error_msg} " + f"(type: {error_type})" + ) + return _create_error_response(error_msg, error_type) # Map the response to the expected format response_type = data.get("type") @@ -106,88 +180,162 @@ async def decompose_goal_external( "type": "vague_goal", "suggested_goal": data.get("suggested_goal"), } + elif response_type == "error": + # Pass through error from the service + return _create_error_response( + data.get("error", "Unknown error"), + data.get("error_type", "unknown"), + ) else: logger.error( f"Unknown response type from external service: {response_type}" ) - return None + return _create_error_response( + f"Unknown response type from Agent Generator: {response_type}", + "invalid_response", + ) except httpx.HTTPStatusError as e: - logger.error(f"HTTP error calling external agent generator: {e}") - return None + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) except httpx.RequestError as e: - logger.error(f"Request error calling external agent generator: {e}") - return None + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) except Exception as e: - logger.error(f"Unexpected error calling external agent generator: {e}") - return None + error_msg = f"Unexpected error calling Agent Generator: {e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") async def generate_agent_external( - instructions: dict[str, Any] + instructions: dict[str, Any], + library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate an agent from instructions. Args: instructions: Structured instructions from decompose_goal + library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Agent JSON dict or None on error + Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error """ client = _get_client() + # Build request payload + payload: dict[str, Any] = {"instructions": instructions} + if library_agents: + payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id + try: - response = await client.post( - "/api/generate-agent", json={"instructions": instructions} - ) + response = await client.post("/api/generate-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() if not data.get("success"): - logger.error(f"External service returned error: {data.get('error')}") - return None + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator generation failed: {error_msg} (type: {error_type})" + ) + return _create_error_response(error_msg, error_type) return data.get("agent_json") except httpx.HTTPStatusError as e: - logger.error(f"HTTP error calling external agent generator: {e}") - return None + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) except httpx.RequestError as e: - logger.error(f"Request error calling external agent generator: {e}") - return None + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) except Exception as e: - logger.error(f"Unexpected error calling external agent generator: {e}") - return None + error_msg = f"Unexpected error calling Agent Generator: {e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") async def generate_agent_patch_external( - update_request: str, current_agent: dict[str, Any] + update_request: str, + current_agent: dict[str, Any], + library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate a patch for an existing agent. Args: update_request: Natural language description of changes current_agent: Current agent JSON + library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Updated agent JSON, clarifying questions dict, or None on error + Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error """ client = _get_client() + # Build request payload + payload: dict[str, Any] = { + "update_request": update_request, + "current_agent_json": current_agent, + } + if library_agents: + payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id + try: - response = await client.post( - "/api/update-agent", - json={ - "update_request": update_request, - "current_agent_json": current_agent, - }, - ) + response = await client.post("/api/update-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async update request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() if not data.get("success"): - logger.error(f"External service returned error: {data.get('error')}") - return None + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator patch generation failed: {error_msg} " + f"(type: {error_type})" + ) + return _create_error_response(error_msg, error_type) # Check if it's clarifying questions if data.get("type") == "clarifying_questions": @@ -196,18 +344,99 @@ async def generate_agent_patch_external( "questions": data.get("questions", []), } + # Check if it's an error passed through + if data.get("type") == "error": + return _create_error_response( + data.get("error", "Unknown error"), + data.get("error_type", "unknown"), + ) + # Otherwise return the updated agent JSON return data.get("agent_json") except httpx.HTTPStatusError as e: - logger.error(f"HTTP error calling external agent generator: {e}") - return None + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) except httpx.RequestError as e: - logger.error(f"Request error calling external agent generator: {e}") - return None + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) except Exception as e: - logger.error(f"Unexpected error calling external agent generator: {e}") - return None + error_msg = f"Unexpected error calling Agent Generator: {e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") + + +async def customize_template_external( + template_agent: dict[str, Any], + modification_request: str, + context: str = "", +) -> dict[str, Any] | None: + """Call the external service to customize a template/marketplace agent. + + Args: + template_agent: The template agent JSON to customize + modification_request: Natural language description of customizations + context: Additional context (e.g., answers to previous questions) + + Returns: + Customized agent JSON, clarifying questions dict, or error dict on error + """ + client = _get_client() + + request = modification_request + if context: + request = f"{modification_request}\n\nAdditional context from user:\n{context}" + + payload: dict[str, Any] = { + "template_agent_json": template_agent, + "modification_request": request, + } + + try: + response = await client.post("/api/template-modification", json=payload) + response.raise_for_status() + data = response.json() + + if not data.get("success"): + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator template customization failed: {error_msg} " + f"(type: {error_type})" + ) + return _create_error_response(error_msg, error_type) + + # Check if it's clarifying questions + if data.get("type") == "clarifying_questions": + return { + "type": "clarifying_questions", + "questions": data.get("questions", []), + } + + # Check if it's an error passed through + if data.get("type") == "error": + return _create_error_response( + data.get("error", "Unknown error"), + data.get("error_type", "unknown"), + ) + + # Otherwise return the customized agent JSON + return data.get("agent_json") + + except httpx.HTTPStatusError as e: + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except httpx.RequestError as e: + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except Exception as e: + error_msg = f"Unexpected error calling Agent Generator: {e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") async def get_blocks_external() -> list[dict[str, Any]] | None: diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py index 5fa74ba04e..62d59c470e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py @@ -1,6 +1,7 @@ """Shared agent search functionality for find_agent and find_library_agent tools.""" import logging +import re from typing import Literal from backend.api.features.library import db as library_db @@ -19,6 +20,85 @@ logger = logging.getLogger(__name__) SearchSource = Literal["marketplace", "library"] +_UUID_PATTERN = re.compile( + r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$", + re.IGNORECASE, +) + + +def _is_uuid(text: str) -> bool: + """Check if text is a valid UUID v4.""" + return bool(_UUID_PATTERN.match(text.strip())) + + +async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None: + """Fetch a library agent by ID (library agent ID or graph_id). + + Tries multiple lookup strategies: + 1. First by graph_id (AgentGraph primary key) + 2. Then by library agent ID (LibraryAgent primary key) + + Args: + user_id: The user ID + agent_id: The ID to look up (can be graph_id or library agent ID) + + Returns: + AgentInfo if found, None otherwise + """ + try: + agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id) + if agent: + logger.debug(f"Found library agent by graph_id: {agent.name}") + return AgentInfo( + id=agent.id, + name=agent.name, + description=agent.description or "", + source="library", + in_library=True, + creator=agent.creator_name, + status=agent.status.value, + can_access_graph=agent.can_access_graph, + has_external_trigger=agent.has_external_trigger, + new_output=agent.new_output, + graph_id=agent.graph_id, + ) + except DatabaseError: + raise + except Exception as e: + logger.warning( + f"Could not fetch library agent by graph_id {agent_id}: {e}", + exc_info=True, + ) + + try: + agent = await library_db.get_library_agent(agent_id, user_id) + if agent: + logger.debug(f"Found library agent by library_id: {agent.name}") + return AgentInfo( + id=agent.id, + name=agent.name, + description=agent.description or "", + source="library", + in_library=True, + creator=agent.creator_name, + status=agent.status.value, + can_access_graph=agent.can_access_graph, + has_external_trigger=agent.has_external_trigger, + new_output=agent.new_output, + graph_id=agent.graph_id, + ) + except NotFoundError: + logger.debug(f"Library agent not found by library_id: {agent_id}") + except DatabaseError: + raise + except Exception as e: + logger.warning( + f"Could not fetch library agent by library_id {agent_id}: {e}", + exc_info=True, + ) + + return None + async def search_agents( query: str, @@ -69,29 +149,37 @@ async def search_agents( is_featured=False, ) ) - else: # library - logger.info(f"Searching user library for: {query}") - results = await library_db.list_library_agents( - user_id=user_id, # type: ignore[arg-type] - search_term=query, - page_size=10, - ) - for agent in results.agents: - agents.append( - AgentInfo( - id=agent.id, - name=agent.name, - description=agent.description or "", - source="library", - in_library=True, - creator=agent.creator_name, - status=agent.status.value, - can_access_graph=agent.can_access_graph, - has_external_trigger=agent.has_external_trigger, - new_output=agent.new_output, - graph_id=agent.graph_id, - ) + else: + if _is_uuid(query): + logger.info(f"Query looks like UUID, trying direct lookup: {query}") + agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type] + if agent: + agents.append(agent) + logger.info(f"Found agent by direct ID lookup: {agent.name}") + + if not agents: + logger.info(f"Searching user library for: {query}") + results = await library_db.list_library_agents( + user_id=user_id, # type: ignore[arg-type] + search_term=query, + page_size=10, ) + for agent in results.agents: + agents.append( + AgentInfo( + id=agent.id, + name=agent.name, + description=agent.description or "", + source="library", + in_library=True, + creator=agent.creator_name, + status=agent.status.value, + can_access_graph=agent.can_access_graph, + has_external_trigger=agent.has_external_trigger, + new_output=agent.new_output, + graph_id=agent.graph_id, + ) + ) logger.info(f"Found {len(agents)} agents in {source}") except NotFoundError: pass diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py index 6b3784e323..7333851a5b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py @@ -8,13 +8,17 @@ from backend.api.features.chat.model import ChatSession from .agent_generator import ( AgentGeneratorNotConfiguredError, decompose_goal, + enrich_library_agents_from_steps, generate_agent, + get_all_relevant_agents_for_generation, + get_user_message_for_error, save_agent_to_library, ) from .base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -95,6 +99,10 @@ class CreateAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not description: return ErrorResponse( message="Please provide a description of what the agent should do.", @@ -102,9 +110,24 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Step 1: Decompose goal into steps + library_agents = None + if user_id: + try: + library_agents = await get_all_relevant_agents_for_generation( + user_id=user_id, + search_query=description, + include_marketplace=True, + ) + logger.debug( + f"Found {len(library_agents)} relevant agents for sub-agent composition" + ) + except Exception as e: + logger.warning(f"Failed to fetch library agents: {e}") + try: - decomposition_result = await decompose_goal(description, context) + decomposition_result = await decompose_goal( + description, context, library_agents + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -117,15 +140,31 @@ class CreateAgentTool(BaseTool): if decomposition_result is None: return ErrorResponse( - message="Failed to analyze the goal. The agent generation service may be unavailable or timed out. Please try again.", + message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.", error="decomposition_failed", - details={ - "description": description[:100] - }, # Include context for debugging + details={"description": description[:100]}, + session_id=session_id, + ) + + if decomposition_result.get("type") == "error": + error_msg = decomposition_result.get("error", "Unknown error") + error_type = decomposition_result.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="analyze the goal", + llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.", + ) + return ErrorResponse( + message=user_message, + error=f"decomposition_failed:{error_type}", + details={ + "description": description[:100], + "service_error": error_msg, + "error_type": error_type, + }, session_id=session_id, ) - # Check if LLM returned clarifying questions if decomposition_result.get("type") == "clarifying_questions": questions = decomposition_result.get("questions", []) return ClarificationNeededResponse( @@ -144,7 +183,6 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Check for unachievable/vague goals if decomposition_result.get("type") == "unachievable_goal": suggested = decomposition_result.get("suggested_goal", "") reason = decomposition_result.get("reason", "") @@ -171,9 +209,27 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Step 2: Generate agent JSON (external service handles fixing and validation) + if user_id and library_agents is not None: + try: + library_agents = await enrich_library_agents_from_steps( + user_id=user_id, + decomposition_result=decomposition_result, + existing_agents=library_agents, + include_marketplace=True, + ) + logger.debug( + f"After enrichment: {len(library_agents)} total agents for sub-agent composition" + ) + except Exception as e: + logger.warning(f"Failed to enrich library agents from steps: {e}") + try: - agent_json = await generate_agent(decomposition_result) + agent_json = await generate_agent( + decomposition_result, + library_agents, + operation_id=operation_id, + task_id=task_id, + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -186,11 +242,47 @@ class CreateAgentTool(BaseTool): if agent_json is None: return ErrorResponse( - message="Failed to generate the agent. The agent generation service may be unavailable or timed out. Please try again.", + message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.", error="generation_failed", + details={"description": description[:100]}, + session_id=session_id, + ) + + if isinstance(agent_json, dict) and agent_json.get("type") == "error": + error_msg = agent_json.get("error", "Unknown error") + error_type = agent_json.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="generate the agent", + llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.", + validation_message=( + "I wasn't able to create a valid agent for this request. " + "The generated workflow had some structural issues. " + "Please try simplifying your goal or breaking it into smaller steps." + ), + error_details=error_msg, + ) + return ErrorResponse( + message=user_message, + error=f"generation_failed:{error_type}", details={ - "description": description[:100] - }, # Include context for debugging + "description": description[:100], + "service_error": error_msg, + "error_type": error_type, + }, + session_id=session_id, + ) + + # Check if Agent Generator accepted for async processing + if agent_json.get("status") == "accepted": + logger.info( + f"Agent generation delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent generation started. You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, session_id=session_id, ) @@ -199,7 +291,6 @@ class CreateAgentTool(BaseTool): node_count = len(agent_json.get("nodes", [])) link_count = len(agent_json.get("links", [])) - # Step 3: Preview or save if not save: return AgentPreviewResponse( message=( @@ -214,7 +305,6 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Save to library if not user_id: return ErrorResponse( message="You must be logged in to save agents.", @@ -232,7 +322,7 @@ class CreateAgentTool(BaseTool): agent_id=created_graph.id, agent_name=created_graph.name, library_agent_id=library_agent.id, - library_agent_link=f"/library/{library_agent.id}", + library_agent_link=f"/library/agents/{library_agent.id}", agent_page_link=f"/build?flowID={created_graph.id}", session_id=session_id, ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py new file mode 100644 index 0000000000..c0568bd936 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py @@ -0,0 +1,337 @@ +"""CustomizeAgentTool - Customizes marketplace/template agents using natural language.""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.store import db as store_db +from backend.api.features.store.exceptions import AgentNotFoundError + +from .agent_generator import ( + AgentGeneratorNotConfiguredError, + customize_template, + get_user_message_for_error, + graph_to_json, + save_agent_to_library, +) +from .base import BaseTool +from .models import ( + AgentPreviewResponse, + AgentSavedResponse, + ClarificationNeededResponse, + ClarifyingQuestion, + ErrorResponse, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + + +class CustomizeAgentTool(BaseTool): + """Tool for customizing marketplace/template agents using natural language.""" + + @property + def name(self) -> str: + return "customize_agent" + + @property + def description(self) -> str: + return ( + "Customize a marketplace or template agent using natural language. " + "Takes an existing agent from the marketplace and modifies it based on " + "the user's requirements before adding to their library." + ) + + @property + def requires_auth(self) -> bool: + return True + + @property + def is_long_running(self) -> bool: + return True + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "agent_id": { + "type": "string", + "description": ( + "The marketplace agent ID in format 'creator/slug' " + "(e.g., 'autogpt/newsletter-writer'). " + "Get this from find_agent results." + ), + }, + "modifications": { + "type": "string", + "description": ( + "Natural language description of how to customize the agent. " + "Be specific about what changes you want to make." + ), + }, + "context": { + "type": "string", + "description": ( + "Additional context or answers to previous clarifying questions." + ), + }, + "save": { + "type": "boolean", + "description": ( + "Whether to save the customized agent to the user's library. " + "Default is true. Set to false for preview only." + ), + "default": True, + }, + }, + "required": ["agent_id", "modifications"], + } + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + """Execute the customize_agent tool. + + Flow: + 1. Parse the agent ID to get creator/slug + 2. Fetch the template agent from the marketplace + 3. Call customize_template with the modification request + 4. Preview or save based on the save parameter + """ + agent_id = kwargs.get("agent_id", "").strip() + modifications = kwargs.get("modifications", "").strip() + context = kwargs.get("context", "") + save = kwargs.get("save", True) + session_id = session.session_id if session else None + + if not agent_id: + return ErrorResponse( + message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').", + error="missing_agent_id", + session_id=session_id, + ) + + if not modifications: + return ErrorResponse( + message="Please describe how you want to customize this agent.", + error="missing_modifications", + session_id=session_id, + ) + + # Parse agent_id in format "creator/slug" + parts = [p.strip() for p in agent_id.split("/")] + if len(parts) != 2 or not parts[0] or not parts[1]: + return ErrorResponse( + message=( + f"Invalid agent ID format: '{agent_id}'. " + "Expected format is 'creator/agent-name' " + "(e.g., 'autogpt/newsletter-writer')." + ), + error="invalid_agent_id_format", + session_id=session_id, + ) + + creator_username, agent_slug = parts + + # Fetch the marketplace agent details + try: + agent_details = await store_db.get_store_agent_details( + username=creator_username, agent_name=agent_slug + ) + except AgentNotFoundError: + return ErrorResponse( + message=( + f"Could not find marketplace agent '{agent_id}'. " + "Please check the agent ID and try again." + ), + error="agent_not_found", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error fetching marketplace agent {agent_id}: {e}") + return ErrorResponse( + message="Failed to fetch the marketplace agent. Please try again.", + error="fetch_error", + session_id=session_id, + ) + + if not agent_details.store_listing_version_id: + return ErrorResponse( + message=( + f"The agent '{agent_id}' does not have an available version. " + "Please try a different agent." + ), + error="no_version_available", + session_id=session_id, + ) + + # Get the full agent graph + try: + graph = await store_db.get_agent(agent_details.store_listing_version_id) + template_agent = graph_to_json(graph) + except Exception as e: + logger.error(f"Error fetching agent graph for {agent_id}: {e}") + return ErrorResponse( + message="Failed to fetch the agent configuration. Please try again.", + error="graph_fetch_error", + session_id=session_id, + ) + + # Call customize_template + try: + result = await customize_template( + template_agent=template_agent, + modification_request=modifications, + context=context, + ) + except AgentGeneratorNotConfiguredError: + return ErrorResponse( + message=( + "Agent customization is not available. " + "The Agent Generator service is not configured." + ), + error="service_not_configured", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error calling customize_template for {agent_id}: {e}") + return ErrorResponse( + message=( + "Failed to customize the agent due to a service error. " + "Please try again." + ), + error="customization_service_error", + session_id=session_id, + ) + + if result is None: + return ErrorResponse( + message=( + "Failed to customize the agent. " + "The agent generation service may be unavailable or timed out. " + "Please try again." + ), + error="customization_failed", + session_id=session_id, + ) + + # Handle error response + if isinstance(result, dict) and result.get("type") == "error": + error_msg = result.get("error", "Unknown error") + error_type = result.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="customize the agent", + llm_parse_message=( + "The AI had trouble customizing the agent. " + "Please try again or simplify your request." + ), + validation_message=( + "The customized agent failed validation. " + "Please try rephrasing your request." + ), + error_details=error_msg, + ) + return ErrorResponse( + message=user_message, + error=f"customization_failed:{error_type}", + session_id=session_id, + ) + + # Handle clarifying questions + if isinstance(result, dict) and result.get("type") == "clarifying_questions": + questions = result.get("questions") or [] + if not isinstance(questions, list): + logger.error( + f"Unexpected clarifying questions format: {type(questions)}" + ) + questions = [] + return ClarificationNeededResponse( + message=( + "I need some more information to customize this agent. " + "Please answer the following questions:" + ), + questions=[ + ClarifyingQuestion( + question=q.get("question", ""), + keyword=q.get("keyword", ""), + example=q.get("example"), + ) + for q in questions + if isinstance(q, dict) + ], + session_id=session_id, + ) + + # Result should be the customized agent JSON + if not isinstance(result, dict): + logger.error(f"Unexpected customize_template response type: {type(result)}") + return ErrorResponse( + message="Failed to customize the agent due to an unexpected response.", + error="unexpected_response_type", + session_id=session_id, + ) + + customized_agent = result + + agent_name = customized_agent.get( + "name", f"Customized {agent_details.agent_name}" + ) + agent_description = customized_agent.get("description", "") + nodes = customized_agent.get("nodes") + links = customized_agent.get("links") + node_count = len(nodes) if isinstance(nodes, list) else 0 + link_count = len(links) if isinstance(links, list) else 0 + + if not save: + return AgentPreviewResponse( + message=( + f"I've customized the agent '{agent_details.agent_name}'. " + f"The customized agent has {node_count} blocks. " + f"Review it and call customize_agent with save=true to save it." + ), + agent_json=customized_agent, + agent_name=agent_name, + description=agent_description, + node_count=node_count, + link_count=link_count, + session_id=session_id, + ) + + if not user_id: + return ErrorResponse( + message="You must be logged in to save agents.", + error="auth_required", + session_id=session_id, + ) + + # Save to user's library + try: + created_graph, library_agent = await save_agent_to_library( + customized_agent, user_id, is_update=False + ) + + return AgentSavedResponse( + message=( + f"Customized agent '{created_graph.name}' " + f"(based on '{agent_details.agent_name}') " + f"has been saved to your library!" + ), + agent_id=created_graph.id, + agent_name=created_graph.name, + library_agent_id=library_agent.id, + library_agent_link=f"/library/agents/{library_agent.id}", + agent_page_link=f"/build?flowID={created_graph.id}", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error saving customized agent: {e}") + return ErrorResponse( + message="Failed to save the customized agent. Please try again.", + error="save_failed", + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py index 7c4da8ad43..3ae56407a7 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py @@ -9,12 +9,15 @@ from .agent_generator import ( AgentGeneratorNotConfiguredError, generate_agent_patch, get_agent_as_json, + get_all_relevant_agents_for_generation, + get_user_message_for_error, save_agent_to_library, ) from .base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -102,6 +105,10 @@ class EditAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not agent_id: return ErrorResponse( message="Please provide the agent ID to edit.", @@ -116,7 +123,6 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Step 1: Fetch current agent current_agent = await get_agent_as_json(agent_id, user_id) if current_agent is None: @@ -126,14 +132,34 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Build the update request with context + library_agents = None + if user_id: + try: + graph_id = current_agent.get("id") + library_agents = await get_all_relevant_agents_for_generation( + user_id=user_id, + search_query=changes, + exclude_graph_id=graph_id, + include_marketplace=True, + ) + logger.debug( + f"Found {len(library_agents)} relevant agents for sub-agent composition" + ) + except Exception as e: + logger.warning(f"Failed to fetch library agents: {e}") + update_request = changes if context: update_request = f"{changes}\n\nAdditional context:\n{context}" - # Step 2: Generate updated agent (external service handles fixing and validation) try: - result = await generate_agent_patch(update_request, current_agent) + result = await generate_agent_patch( + update_request, + current_agent, + library_agents, + operation_id=operation_id, + task_id=task_id, + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -152,7 +178,42 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Check if LLM returned clarifying questions + # Check if Agent Generator accepted for async processing + if result.get("status") == "accepted": + logger.info( + f"Agent edit delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent edit started. You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + ) + + # Check if the result is an error from the external service + if isinstance(result, dict) and result.get("type") == "error": + error_msg = result.get("error", "Unknown error") + error_type = result.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="generate the changes", + llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.", + validation_message="The generated changes failed validation. Please try rephrasing your request.", + error_details=error_msg, + ) + return ErrorResponse( + message=user_message, + error=f"update_generation_failed:{error_type}", + details={ + "agent_id": agent_id, + "changes": changes[:100], + "service_error": error_msg, + "error_type": error_type, + }, + session_id=session_id, + ) + if result.get("type") == "clarifying_questions": questions = result.get("questions", []) return ClarificationNeededResponse( @@ -171,7 +232,6 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Result is the updated agent JSON updated_agent = result agent_name = updated_agent.get("name", "Updated Agent") @@ -179,7 +239,6 @@ class EditAgentTool(BaseTool): node_count = len(updated_agent.get("nodes", [])) link_count = len(updated_agent.get("links", [])) - # Step 3: Preview or save if not save: return AgentPreviewResponse( message=( @@ -195,7 +254,6 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Save to library (creates a new version) if not user_id: return ErrorResponse( message="You must be logged in to save agents.", @@ -213,7 +271,7 @@ class EditAgentTool(BaseTool): agent_id=created_graph.id, agent_name=created_graph.name, library_agent_id=library_agent.id, - library_agent_link=f"/library/{library_agent.id}", + library_agent_link=f"/library/agents/{library_agent.id}", agent_page_link=f"/build?flowID={created_graph.id}", session_id=session_id, ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index 8552681d03..69c8c6c684 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -28,10 +28,18 @@ class ResponseType(str, Enum): BLOCK_OUTPUT = "block_output" DOC_SEARCH_RESULTS = "doc_search_results" DOC_PAGE = "doc_page" + # Workspace response types + WORKSPACE_FILE_LIST = "workspace_file_list" + WORKSPACE_FILE_CONTENT = "workspace_file_content" + WORKSPACE_FILE_METADATA = "workspace_file_metadata" + WORKSPACE_FILE_WRITTEN = "workspace_file_written" + WORKSPACE_FILE_DELETED = "workspace_file_deleted" # Long-running operation types OPERATION_STARTED = "operation_started" OPERATION_PENDING = "operation_pending" OPERATION_IN_PROGRESS = "operation_in_progress" + # Input validation + INPUT_VALIDATION_ERROR = "input_validation_error" # Base response model @@ -62,6 +70,10 @@ class AgentInfo(BaseModel): has_external_trigger: bool | None = None new_output: bool | None = None graph_id: str | None = None + inputs: dict[str, Any] | None = Field( + default=None, + description="Input schema for the agent, including field names, types, and defaults", + ) class AgentsFoundResponse(ToolResponseBase): @@ -188,6 +200,20 @@ class ErrorResponse(ToolResponseBase): details: dict[str, Any] | None = None +class InputValidationErrorResponse(ToolResponseBase): + """Response when run_agent receives unknown input fields.""" + + type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR + unrecognized_fields: list[str] = Field( + description="List of input field names that were not recognized" + ) + inputs: dict[str, Any] = Field( + description="The agent's valid input schema for reference" + ) + graph_id: str | None = None + graph_version: int | None = None + + # Agent output models class ExecutionOutputInfo(BaseModel): """Summary of a single execution's outputs.""" @@ -346,11 +372,15 @@ class OperationStartedResponse(ToolResponseBase): This is returned immediately to the client while the operation continues to execute. The user can close the tab and check back later. + + The task_id can be used to reconnect to the SSE stream via + GET /chat/tasks/{task_id}/stream?last_idx=0 """ type: ResponseType = ResponseType.OPERATION_STARTED operation_id: str tool_name: str + task_id: str | None = None # For SSE reconnection class OperationPendingResponse(ToolResponseBase): @@ -374,3 +404,20 @@ class OperationInProgressResponse(ToolResponseBase): type: ResponseType = ResponseType.OPERATION_IN_PROGRESS tool_call_id: str + + +class AsyncProcessingResponse(ToolResponseBase): + """Response when an operation has been delegated to async processing. + + This is returned by tools when the external service accepts the request + for async processing (HTTP 202 Accepted). The Redis Streams completion + consumer will handle the result when the external service completes. + + The status field is specifically "accepted" to allow the long-running tool + handler to detect this response and skip LLM continuation. + """ + + type: ResponseType = ResponseType.OPERATION_STARTED + status: str = "accepted" # Must be "accepted" for detection + operation_id: str | None = None + task_id: str | None = None diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py index a7fa65348a..73d4cf81f2 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py @@ -30,6 +30,7 @@ from .models import ( ErrorResponse, ExecutionOptions, ExecutionStartedResponse, + InputValidationErrorResponse, SetupInfo, SetupRequirementsResponse, ToolResponseBase, @@ -273,6 +274,22 @@ class RunAgentTool(BaseTool): input_properties = graph.input_schema.get("properties", {}) required_fields = set(graph.input_schema.get("required", [])) provided_inputs = set(params.inputs.keys()) + valid_fields = set(input_properties.keys()) + + # Check for unknown input fields + unrecognized_fields = provided_inputs - valid_fields + if unrecognized_fields: + return InputValidationErrorResponse( + message=( + f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. " + f"Agent was not executed. Please use the correct field names from the schema." + ), + session_id=session_id, + unrecognized_fields=sorted(unrecognized_fields), + inputs=graph.input_schema, + graph_id=graph.id, + graph_version=graph.version, + ) # If agent has inputs but none were provided AND use_defaults is not set, # always show what's available first so user can decide diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py index 404df2adb6..d5da394fa6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py @@ -402,3 +402,42 @@ async def test_run_agent_schedule_without_name(setup_test_data): # Should return error about missing schedule_name assert result_data.get("type") == "error" assert "schedule_name" in result_data["message"].lower() + + +@pytest.mark.asyncio(loop_scope="session") +async def test_run_agent_rejects_unknown_input_fields(setup_test_data): + """Test that run_agent returns input_validation_error for unknown input fields.""" + user = setup_test_data["user"] + store_submission = setup_test_data["store_submission"] + + tool = RunAgentTool() + agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}" + session = make_session(user_id=user.id) + + # Execute with unknown input field names + response = await tool.execute( + user_id=user.id, + session_id=str(uuid.uuid4()), + tool_call_id=str(uuid.uuid4()), + username_agent_slug=agent_marketplace_id, + inputs={ + "unknown_field": "some value", + "another_unknown": "another value", + }, + session=session, + ) + + assert response is not None + assert hasattr(response, "output") + assert isinstance(response.output, str) + result_data = orjson.loads(response.output) + + # Should return input_validation_error type with unrecognized fields + assert result_data.get("type") == "input_validation_error" + assert "unrecognized_fields" in result_data + assert set(result_data["unrecognized_fields"]) == { + "another_unknown", + "unknown_field", + } + assert "inputs" in result_data # Contains the valid schema + assert "Agent was not executed" in result_data["message"] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index 9d7da6d8f3..5685664db9 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -1,9 +1,12 @@ """Tool for executing blocks directly.""" import logging +import uuid from collections import defaultdict from typing import Any +from pydantic_core import PydanticUndefined + from backend.api.features.chat.model import ChatSession from backend.api.features.chat.tools.find_block import ( EXCLUDED_BLOCK_IDS, @@ -12,6 +15,7 @@ from backend.api.features.chat.tools.find_block import ( from backend.data.block import get_block from backend.data.execution import ExecutionContext from backend.data.model import CredentialsMetaInput +from backend.data.workspace import get_or_create_workspace from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import BlockError @@ -77,15 +81,22 @@ class RunBlockTool(BaseTool): self, user_id: str, block: Any, + input_data: dict[str, Any] | None = None, ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: """ Check if user has required credentials for a block. + Args: + user_id: User ID + block: Block to check credentials for + input_data: Input data for the block (used to determine provider via discriminator) + Returns: tuple[matched_credentials, missing_credentials] """ matched_credentials: dict[str, CredentialsMetaInput] = {} missing_credentials: list[CredentialsMetaInput] = [] + input_data = input_data or {} # Get credential field info from block's input schema credentials_fields_info = block.input_schema.get_credentials_fields_info() @@ -98,14 +109,33 @@ class RunBlockTool(BaseTool): available_creds = await creds_manager.store.get_all_creds(user_id) for field_name, field_info in credentials_fields_info.items(): - # field_info.provider is a frozenset of acceptable providers - # field_info.supported_types is a frozenset of acceptable types + effective_field_info = field_info + if field_info.discriminator and field_info.discriminator_mapping: + # Get discriminator from input, falling back to schema default + discriminator_value = input_data.get(field_info.discriminator) + if discriminator_value is None: + field = block.input_schema.model_fields.get( + field_info.discriminator + ) + if field and field.default is not PydanticUndefined: + discriminator_value = field.default + + if ( + discriminator_value + and discriminator_value in field_info.discriminator_mapping + ): + effective_field_info = field_info.discriminate(discriminator_value) + logger.debug( + f"Discriminated provider for {field_name}: " + f"{discriminator_value} -> {effective_field_info.provider}" + ) + matching_cred = next( ( cred for cred in available_creds - if cred.provider in field_info.provider - and cred.type in field_info.supported_types + if cred.provider in effective_field_info.provider + and cred.type in effective_field_info.supported_types ), None, ) @@ -119,8 +149,8 @@ class RunBlockTool(BaseTool): ) else: # Create a placeholder for the missing credential - provider = next(iter(field_info.provider), "unknown") - cred_type = next(iter(field_info.supported_types), "api_key") + provider = next(iter(effective_field_info.provider), "unknown") + cred_type = next(iter(effective_field_info.supported_types), "api_key") missing_credentials.append( CredentialsMetaInput( id=field_name, @@ -198,10 +228,9 @@ class RunBlockTool(BaseTool): logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") - # Check credentials creds_manager = IntegrationCredentialsManager() matched_credentials, missing_credentials = await self._check_block_credentials( - user_id, block + user_id, block, input_data ) if missing_credentials: @@ -237,11 +266,48 @@ class RunBlockTool(BaseTool): ) try: - # Fetch actual credentials and prepare kwargs for block execution - # Create execution context with defaults (blocks may require it) + # Get or create user's workspace for CoPilot file operations + workspace = await get_or_create_workspace(user_id) + + # Generate synthetic IDs for CoPilot context + # Each chat session is treated as its own agent with one continuous run + # This means: + # - graph_id (agent) = session (memories scoped to session when limit_to_agent=True) + # - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True) + # - node_exec_id = unique per block execution + synthetic_graph_id = f"copilot-session-{session.session_id}" + synthetic_graph_exec_id = f"copilot-session-{session.session_id}" + synthetic_node_id = f"copilot-node-{block_id}" + synthetic_node_exec_id = ( + f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}" + ) + + # Create unified execution context with all required fields + execution_context = ExecutionContext( + # Execution identity + user_id=user_id, + graph_id=synthetic_graph_id, + graph_exec_id=synthetic_graph_exec_id, + graph_version=1, # Versions are 1-indexed + node_id=synthetic_node_id, + node_exec_id=synthetic_node_exec_id, + # Workspace with session scoping + workspace_id=workspace.id, + session_id=session.session_id, + ) + + # Prepare kwargs for block execution + # Keep individual kwargs for backwards compatibility with existing blocks exec_kwargs: dict[str, Any] = { "user_id": user_id, - "execution_context": ExecutionContext(), + "execution_context": execution_context, + # Legacy: individual kwargs for blocks not yet using execution_context + "workspace_id": workspace.id, + "graph_exec_id": synthetic_graph_exec_id, + "node_exec_id": synthetic_node_exec_id, + "node_id": synthetic_node_id, + "graph_version": 1, # Versions are 1-indexed + "graph_id": synthetic_graph_id, } for field_name, cred_meta in matched_credentials.items(): diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index a2ac91dc65..bd25594b8a 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -8,7 +8,12 @@ from backend.api.features.library import model as library_model from backend.api.features.store import db as store_db from backend.data import graph as graph_db from backend.data.graph import GraphModel -from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput +from backend.data.model import ( + CredentialsFieldInfo, + CredentialsMetaInput, + HostScopedCredentials, + OAuth2Credentials, +) from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import NotFoundError @@ -266,13 +271,21 @@ async def match_user_credentials_to_graph( credential_requirements, _node_fields, ) in aggregated_creds.items(): - # Find first matching credential by provider and type + # Find first matching credential by provider, type, and scopes matching_cred = next( ( cred for cred in available_creds if cred.provider in credential_requirements.provider and cred.type in credential_requirements.supported_types + and ( + cred.type != "oauth2" + or _credential_has_required_scopes(cred, credential_requirements) + ) + and ( + cred.type != "host_scoped" + or _credential_is_for_host(cred, credential_requirements) + ) ), None, ) @@ -296,10 +309,17 @@ async def match_user_credentials_to_graph( f"{credential_field_name} (validation failed: {e})" ) else: + # Build a helpful error message including scope requirements + error_parts = [ + f"provider in {list(credential_requirements.provider)}", + f"type in {list(credential_requirements.supported_types)}", + ] + if credential_requirements.required_scopes: + error_parts.append( + f"scopes including {list(credential_requirements.required_scopes)}" + ) missing_creds.append( - f"{credential_field_name} " - f"(requires provider in {list(credential_requirements.provider)}, " - f"type in {list(credential_requirements.supported_types)})" + f"{credential_field_name} (requires {', '.join(error_parts)})" ) logger.info( @@ -309,6 +329,35 @@ async def match_user_credentials_to_graph( return graph_credentials_inputs, missing_creds +def _credential_has_required_scopes( + credential: OAuth2Credentials, + requirements: CredentialsFieldInfo, +) -> bool: + """Check if an OAuth2 credential has all the scopes required by the input.""" + # If no scopes are required, any credential matches + if not requirements.required_scopes: + return True + + # Check that credential scopes are a superset of required scopes + return set(credential.scopes).issuperset(requirements.required_scopes) + + +def _credential_is_for_host( + credential: HostScopedCredentials, + requirements: CredentialsFieldInfo, +) -> bool: + """Check if a host-scoped credential matches the host required by the input.""" + # We need to know the host to match host-scoped credentials to. + # Graph.aggregate_credentials_inputs() adds the node's set URL value (if any) + # to discriminator_values. No discriminator_values -> no host to match against. + if not requirements.discriminator_values: + return True + + # Check that credential host matches required host. + # Host-scoped credential inputs are grouped by host, so any item from the set works. + return credential.matches_url(list(requirements.discriminator_values)[0]) + + async def check_user_has_required_credentials( user_id: str, required_credentials: list[CredentialsMetaInput], diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py new file mode 100644 index 0000000000..03532c8fee --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py @@ -0,0 +1,620 @@ +"""CoPilot tools for workspace file operations.""" + +import base64 +import logging +from typing import Any, Optional + +from pydantic import BaseModel + +from backend.api.features.chat.model import ChatSession +from backend.data.workspace import get_or_create_workspace +from backend.util.settings import Config +from backend.util.virus_scanner import scan_content_safe +from backend.util.workspace import WorkspaceManager + +from .base import BaseTool +from .models import ErrorResponse, ResponseType, ToolResponseBase + +logger = logging.getLogger(__name__) + + +class WorkspaceFileInfoData(BaseModel): + """Data model for workspace file information (not a response itself).""" + + file_id: str + name: str + path: str + mime_type: str + size_bytes: int + + +class WorkspaceFileListResponse(ToolResponseBase): + """Response containing list of workspace files.""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_LIST + files: list[WorkspaceFileInfoData] + total_count: int + + +class WorkspaceFileContentResponse(ToolResponseBase): + """Response containing workspace file content (legacy, for small text files).""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT + file_id: str + name: str + path: str + mime_type: str + content_base64: str + + +class WorkspaceFileMetadataResponse(ToolResponseBase): + """Response containing workspace file metadata and download URL (prevents context bloat).""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA + file_id: str + name: str + path: str + mime_type: str + size_bytes: int + download_url: str + preview: str | None = None # First 500 chars for text files + + +class WorkspaceWriteResponse(ToolResponseBase): + """Response after writing a file to workspace.""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN + file_id: str + name: str + path: str + size_bytes: int + + +class WorkspaceDeleteResponse(ToolResponseBase): + """Response after deleting a file from workspace.""" + + type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED + file_id: str + success: bool + + +class ListWorkspaceFilesTool(BaseTool): + """Tool for listing files in user's workspace.""" + + @property + def name(self) -> str: + return "list_workspace_files" + + @property + def description(self) -> str: + return ( + "List files in the user's workspace. " + "Returns file names, paths, sizes, and metadata. " + "Optionally filter by path prefix." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "path_prefix": { + "type": "string", + "description": ( + "Optional path prefix to filter files " + "(e.g., '/documents/' to list only files in documents folder). " + "By default, only files from the current session are listed." + ), + }, + "limit": { + "type": "integer", + "description": "Maximum number of files to return (default 50, max 100)", + "minimum": 1, + "maximum": 100, + }, + "include_all_sessions": { + "type": "boolean", + "description": ( + "If true, list files from all sessions. " + "Default is false (only current session's files)." + ), + }, + }, + "required": [], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + session_id = session.session_id + + if not user_id: + return ErrorResponse( + message="Authentication required", + session_id=session_id, + ) + + path_prefix: Optional[str] = kwargs.get("path_prefix") + limit = min(kwargs.get("limit", 50), 100) + include_all_sessions: bool = kwargs.get("include_all_sessions", False) + + try: + workspace = await get_or_create_workspace(user_id) + # Pass session_id for session-scoped file access + manager = WorkspaceManager(user_id, workspace.id, session_id) + + files = await manager.list_files( + path=path_prefix, + limit=limit, + include_all_sessions=include_all_sessions, + ) + total = await manager.get_file_count( + path=path_prefix, + include_all_sessions=include_all_sessions, + ) + + file_infos = [ + WorkspaceFileInfoData( + file_id=f.id, + name=f.name, + path=f.path, + mime_type=f.mimeType, + size_bytes=f.sizeBytes, + ) + for f in files + ] + + scope_msg = "all sessions" if include_all_sessions else "current session" + return WorkspaceFileListResponse( + files=file_infos, + total_count=total, + message=f"Found {len(files)} files in workspace ({scope_msg})", + session_id=session_id, + ) + + except Exception as e: + logger.error(f"Error listing workspace files: {e}", exc_info=True) + return ErrorResponse( + message=f"Failed to list workspace files: {str(e)}", + error=str(e), + session_id=session_id, + ) + + +class ReadWorkspaceFileTool(BaseTool): + """Tool for reading file content from workspace.""" + + # Size threshold for returning full content vs metadata+URL + # Files larger than this return metadata with download URL to prevent context bloat + MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB + # Preview size for text files + PREVIEW_SIZE = 500 + + @property + def name(self) -> str: + return "read_workspace_file" + + @property + def description(self) -> str: + return ( + "Read a file from the user's workspace. " + "Specify either file_id or path to identify the file. " + "For small text files, returns content directly. " + "For large or binary files, returns metadata and a download URL. " + "Paths are scoped to the current session by default. " + "Use /sessions//... for cross-session access." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "file_id": { + "type": "string", + "description": "The file's unique ID (from list_workspace_files)", + }, + "path": { + "type": "string", + "description": ( + "The virtual file path (e.g., '/documents/report.pdf'). " + "Scoped to current session by default." + ), + }, + "force_download_url": { + "type": "boolean", + "description": ( + "If true, always return metadata+URL instead of inline content. " + "Default is false (auto-selects based on file size/type)." + ), + }, + }, + "required": [], # At least one must be provided + } + + @property + def requires_auth(self) -> bool: + return True + + def _is_text_mime_type(self, mime_type: str) -> bool: + """Check if the MIME type is a text-based type.""" + text_types = [ + "text/", + "application/json", + "application/xml", + "application/javascript", + "application/x-python", + "application/x-sh", + ] + return any(mime_type.startswith(t) for t in text_types) + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + session_id = session.session_id + + if not user_id: + return ErrorResponse( + message="Authentication required", + session_id=session_id, + ) + + file_id: Optional[str] = kwargs.get("file_id") + path: Optional[str] = kwargs.get("path") + force_download_url: bool = kwargs.get("force_download_url", False) + + if not file_id and not path: + return ErrorResponse( + message="Please provide either file_id or path", + session_id=session_id, + ) + + try: + workspace = await get_or_create_workspace(user_id) + # Pass session_id for session-scoped file access + manager = WorkspaceManager(user_id, workspace.id, session_id) + + # Get file info + if file_id: + file_info = await manager.get_file_info(file_id) + if file_info is None: + return ErrorResponse( + message=f"File not found: {file_id}", + session_id=session_id, + ) + target_file_id = file_id + else: + # path is guaranteed to be non-None here due to the check above + assert path is not None + file_info = await manager.get_file_info_by_path(path) + if file_info is None: + return ErrorResponse( + message=f"File not found at path: {path}", + session_id=session_id, + ) + target_file_id = file_info.id + + # Decide whether to return inline content or metadata+URL + is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES + is_text_file = self._is_text_mime_type(file_info.mimeType) + + # Return inline content for small text files (unless force_download_url) + if is_small_file and is_text_file and not force_download_url: + content = await manager.read_file_by_id(target_file_id) + content_b64 = base64.b64encode(content).decode("utf-8") + + return WorkspaceFileContentResponse( + file_id=file_info.id, + name=file_info.name, + path=file_info.path, + mime_type=file_info.mimeType, + content_base64=content_b64, + message=f"Successfully read file: {file_info.name}", + session_id=session_id, + ) + + # Return metadata + workspace:// reference for large or binary files + # This prevents context bloat (100KB file = ~133KB as base64) + # Use workspace:// format so frontend urlTransform can add proxy prefix + download_url = f"workspace://{target_file_id}" + + # Generate preview for text files + preview: str | None = None + if is_text_file: + try: + content = await manager.read_file_by_id(target_file_id) + preview_text = content[: self.PREVIEW_SIZE].decode( + "utf-8", errors="replace" + ) + if len(content) > self.PREVIEW_SIZE: + preview_text += "..." + preview = preview_text + except Exception: + pass # Preview is optional + + return WorkspaceFileMetadataResponse( + file_id=file_info.id, + name=file_info.name, + path=file_info.path, + mime_type=file_info.mimeType, + size_bytes=file_info.sizeBytes, + download_url=download_url, + preview=preview, + message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.", + session_id=session_id, + ) + + except FileNotFoundError as e: + return ErrorResponse( + message=str(e), + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error reading workspace file: {e}", exc_info=True) + return ErrorResponse( + message=f"Failed to read workspace file: {str(e)}", + error=str(e), + session_id=session_id, + ) + + +class WriteWorkspaceFileTool(BaseTool): + """Tool for writing files to workspace.""" + + @property + def name(self) -> str: + return "write_workspace_file" + + @property + def description(self) -> str: + return ( + "Write or create a file in the user's workspace. " + "Provide the content as a base64-encoded string. " + f"Maximum file size is {Config().max_file_size_mb}MB. " + "Files are saved to the current session's folder by default. " + "Use /sessions//... for cross-session access." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "Name for the file (e.g., 'report.pdf')", + }, + "content_base64": { + "type": "string", + "description": "Base64-encoded file content", + }, + "path": { + "type": "string", + "description": ( + "Optional virtual path where to save the file " + "(e.g., '/documents/report.pdf'). " + "Defaults to '/{filename}'. Scoped to current session." + ), + }, + "mime_type": { + "type": "string", + "description": ( + "Optional MIME type of the file. " + "Auto-detected from filename if not provided." + ), + }, + "overwrite": { + "type": "boolean", + "description": "Whether to overwrite if file exists at path (default: false)", + }, + }, + "required": ["filename", "content_base64"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + session_id = session.session_id + + if not user_id: + return ErrorResponse( + message="Authentication required", + session_id=session_id, + ) + + filename: str = kwargs.get("filename", "") + content_b64: str = kwargs.get("content_base64", "") + path: Optional[str] = kwargs.get("path") + mime_type: Optional[str] = kwargs.get("mime_type") + overwrite: bool = kwargs.get("overwrite", False) + + if not filename: + return ErrorResponse( + message="Please provide a filename", + session_id=session_id, + ) + + if not content_b64: + return ErrorResponse( + message="Please provide content_base64", + session_id=session_id, + ) + + # Decode content + try: + content = base64.b64decode(content_b64) + except Exception: + return ErrorResponse( + message="Invalid base64-encoded content", + session_id=session_id, + ) + + # Check size + max_file_size = Config().max_file_size_mb * 1024 * 1024 + if len(content) > max_file_size: + return ErrorResponse( + message=f"File too large. Maximum size is {Config().max_file_size_mb}MB", + session_id=session_id, + ) + + try: + # Virus scan + await scan_content_safe(content, filename=filename) + + workspace = await get_or_create_workspace(user_id) + # Pass session_id for session-scoped file access + manager = WorkspaceManager(user_id, workspace.id, session_id) + + file_record = await manager.write_file( + content=content, + filename=filename, + path=path, + mime_type=mime_type, + overwrite=overwrite, + ) + + return WorkspaceWriteResponse( + file_id=file_record.id, + name=file_record.name, + path=file_record.path, + size_bytes=file_record.sizeBytes, + message=f"Successfully wrote file: {file_record.name}", + session_id=session_id, + ) + + except ValueError as e: + return ErrorResponse( + message=str(e), + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error writing workspace file: {e}", exc_info=True) + return ErrorResponse( + message=f"Failed to write workspace file: {str(e)}", + error=str(e), + session_id=session_id, + ) + + +class DeleteWorkspaceFileTool(BaseTool): + """Tool for deleting files from workspace.""" + + @property + def name(self) -> str: + return "delete_workspace_file" + + @property + def description(self) -> str: + return ( + "Delete a file from the user's workspace. " + "Specify either file_id or path to identify the file. " + "Paths are scoped to the current session by default. " + "Use /sessions//... for cross-session access." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "file_id": { + "type": "string", + "description": "The file's unique ID (from list_workspace_files)", + }, + "path": { + "type": "string", + "description": ( + "The virtual file path (e.g., '/documents/report.pdf'). " + "Scoped to current session by default." + ), + }, + }, + "required": [], # At least one must be provided + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + session_id = session.session_id + + if not user_id: + return ErrorResponse( + message="Authentication required", + session_id=session_id, + ) + + file_id: Optional[str] = kwargs.get("file_id") + path: Optional[str] = kwargs.get("path") + + if not file_id and not path: + return ErrorResponse( + message="Please provide either file_id or path", + session_id=session_id, + ) + + try: + workspace = await get_or_create_workspace(user_id) + # Pass session_id for session-scoped file access + manager = WorkspaceManager(user_id, workspace.id, session_id) + + # Determine the file_id to delete + target_file_id: str + if file_id: + target_file_id = file_id + else: + # path is guaranteed to be non-None here due to the check above + assert path is not None + file_info = await manager.get_file_info_by_path(path) + if file_info is None: + return ErrorResponse( + message=f"File not found at path: {path}", + session_id=session_id, + ) + target_file_id = file_info.id + + success = await manager.delete_file(target_file_id) + + if not success: + return ErrorResponse( + message=f"File not found: {target_file_id}", + session_id=session_id, + ) + + return WorkspaceDeleteResponse( + file_id=target_file_id, + success=True, + message="File deleted successfully", + session_id=session_id, + ) + + except Exception as e: + logger.error(f"Error deleting workspace file: {e}", exc_info=True) + return ErrorResponse( + message=f"Failed to delete workspace file: {str(e)}", + error=str(e), + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index 872fe66b28..394f959953 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -39,6 +39,7 @@ async def list_library_agents( sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT, page: int = 1, page_size: int = 50, + include_executions: bool = False, ) -> library_model.LibraryAgentResponse: """ Retrieves a paginated list of LibraryAgent records for a given user. @@ -49,6 +50,9 @@ async def list_library_agents( sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser). page: Current page (1-indexed). page_size: Number of items per page. + include_executions: Whether to include execution data for status calculation. + Defaults to False for performance (UI fetches status separately). + Set to True when accurate status/metrics are needed (e.g., agent generator). Returns: A LibraryAgentResponse containing the list of agents and pagination details. @@ -76,7 +80,6 @@ async def list_library_agents( "isArchived": False, } - # Build search filter if applicable if search_term: where_clause["OR"] = [ { @@ -93,7 +96,6 @@ async def list_library_agents( }, ] - # Determine sorting order_by: prisma.types.LibraryAgentOrderByInput | None = None if sort_by == library_model.LibraryAgentSort.CREATED_AT: @@ -105,7 +107,7 @@ async def list_library_agents( library_agents = await prisma.models.LibraryAgent.prisma().find_many( where=where_clause, include=library_agent_include( - user_id, include_nodes=False, include_executions=False + user_id, include_nodes=False, include_executions=include_executions ), order=order_by, skip=(page - 1) * page_size, diff --git a/autogpt_platform/backend/backend/api/features/library/model.py b/autogpt_platform/backend/backend/api/features/library/model.py index 14d7c7be81..c6bc0e0427 100644 --- a/autogpt_platform/backend/backend/api/features/library/model.py +++ b/autogpt_platform/backend/backend/api/features/library/model.py @@ -9,6 +9,7 @@ import pydantic from backend.data.block import BlockInput from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo from backend.data.model import CredentialsMetaInput, is_credentials_field_name +from backend.util.json import loads as json_loads from backend.util.models import Pagination if TYPE_CHECKING: @@ -16,10 +17,10 @@ if TYPE_CHECKING: class LibraryAgentStatus(str, Enum): - COMPLETED = "COMPLETED" # All runs completed - HEALTHY = "HEALTHY" # Agent is running (not all runs have completed) - WAITING = "WAITING" # Agent is queued or waiting to start - ERROR = "ERROR" # Agent is in an error state + COMPLETED = "COMPLETED" + HEALTHY = "HEALTHY" + WAITING = "WAITING" + ERROR = "ERROR" class MarketplaceListingCreator(pydantic.BaseModel): @@ -39,6 +40,30 @@ class MarketplaceListing(pydantic.BaseModel): creator: MarketplaceListingCreator +class RecentExecution(pydantic.BaseModel): + """Summary of a recent execution for quality assessment. + + Used by the LLM to understand the agent's recent performance with specific examples + rather than just aggregate statistics. + """ + + status: str + correctness_score: float | None = None + activity_summary: str | None = None + + +def _parse_settings(settings: dict | str | None) -> GraphSettings: + """Parse settings from database, handling both dict and string formats.""" + if settings is None: + return GraphSettings() + try: + if isinstance(settings, str): + settings = json_loads(settings) + return GraphSettings.model_validate(settings) + except Exception: + return GraphSettings() + + class LibraryAgent(pydantic.BaseModel): """ Represents an agent in the library, including metadata for display and @@ -48,7 +73,7 @@ class LibraryAgent(pydantic.BaseModel): id: str graph_id: str graph_version: int - owner_user_id: str # ID of user who owns/created this agent graph + owner_user_id: str image_url: str | None @@ -64,7 +89,7 @@ class LibraryAgent(pydantic.BaseModel): description: str instructions: str | None = None - input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend + input_schema: dict[str, Any] output_schema: dict[str, Any] credentials_input_schema: dict[str, Any] | None = pydantic.Field( description="Input schema for credentials required by the agent", @@ -81,25 +106,19 @@ class LibraryAgent(pydantic.BaseModel): ) trigger_setup_info: Optional[GraphTriggerInfo] = None - # Indicates whether there's a new output (based on recent runs) new_output: bool - - # Whether the user can access the underlying graph + execution_count: int = 0 + success_rate: float | None = None + avg_correctness_score: float | None = None + recent_executions: list[RecentExecution] = pydantic.Field( + default_factory=list, + description="List of recent executions with status, score, and summary", + ) can_access_graph: bool - - # Indicates if this agent is the latest version is_latest_version: bool - - # Whether the agent is marked as favorite by the user is_favorite: bool - - # Recommended schedule cron (from marketplace agents) recommended_schedule_cron: str | None = None - - # User-specific settings for this library agent settings: GraphSettings = pydantic.Field(default_factory=GraphSettings) - - # Marketplace listing information if the agent has been published marketplace_listing: Optional["MarketplaceListing"] = None @staticmethod @@ -123,7 +142,6 @@ class LibraryAgent(pydantic.BaseModel): agent_updated_at = agent.AgentGraph.updatedAt lib_agent_updated_at = agent.updatedAt - # Compute updated_at as the latest between library agent and graph updated_at = ( max(agent_updated_at, lib_agent_updated_at) if agent_updated_at @@ -136,7 +154,6 @@ class LibraryAgent(pydantic.BaseModel): creator_name = agent.Creator.name or "Unknown" creator_image_url = agent.Creator.avatarUrl or "" - # Logic to calculate status and new_output week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( days=7 ) @@ -145,13 +162,55 @@ class LibraryAgent(pydantic.BaseModel): status = status_result.status new_output = status_result.new_output - # Check if user can access the graph - can_access_graph = agent.AgentGraph.userId == agent.userId + execution_count = len(executions) + success_rate: float | None = None + avg_correctness_score: float | None = None + if execution_count > 0: + success_count = sum( + 1 + for e in executions + if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED + ) + success_rate = (success_count / execution_count) * 100 - # Hard-coded to True until a method to check is implemented + correctness_scores = [] + for e in executions: + if e.stats and isinstance(e.stats, dict): + score = e.stats.get("correctness_score") + if score is not None and isinstance(score, (int, float)): + correctness_scores.append(float(score)) + if correctness_scores: + avg_correctness_score = sum(correctness_scores) / len( + correctness_scores + ) + + recent_executions: list[RecentExecution] = [] + for e in executions: + exec_score: float | None = None + exec_summary: str | None = None + if e.stats and isinstance(e.stats, dict): + score = e.stats.get("correctness_score") + if score is not None and isinstance(score, (int, float)): + exec_score = float(score) + summary = e.stats.get("activity_status") + if summary is not None and isinstance(summary, str): + exec_summary = summary + exec_status = ( + e.executionStatus.value + if hasattr(e.executionStatus, "value") + else str(e.executionStatus) + ) + recent_executions.append( + RecentExecution( + status=exec_status, + correctness_score=exec_score, + activity_summary=exec_summary, + ) + ) + + can_access_graph = agent.AgentGraph.userId == agent.userId is_latest_version = True - # Build marketplace_listing if available marketplace_listing_data = None if store_listing and store_listing.ActiveVersion and profile: creator_data = MarketplaceListingCreator( @@ -190,11 +249,15 @@ class LibraryAgent(pydantic.BaseModel): has_sensitive_action=graph.has_sensitive_action, trigger_setup_info=graph.trigger_setup_info, new_output=new_output, + execution_count=execution_count, + success_rate=success_rate, + avg_correctness_score=avg_correctness_score, + recent_executions=recent_executions, can_access_graph=can_access_graph, is_latest_version=is_latest_version, is_favorite=agent.isFavorite, recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron, - settings=GraphSettings.model_validate(agent.settings), + settings=_parse_settings(agent.settings), marketplace_listing=marketplace_listing_data, ) @@ -220,18 +283,15 @@ def _calculate_agent_status( if not executions: return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False) - # Track how many times each execution status appears status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus} new_output = False for execution in executions: - # Check if there's a completed run more recent than `recent_threshold` if execution.createdAt >= recent_threshold: if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED: new_output = True status_counts[execution.executionStatus] += 1 - # Determine the final status based on counts if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0: return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output) elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0: diff --git a/autogpt_platform/backend/backend/api/features/store/db.py b/autogpt_platform/backend/backend/api/features/store/db.py index 956fdfa7da..850a2bc3e9 100644 --- a/autogpt_platform/backend/backend/api/features/store/db.py +++ b/autogpt_platform/backend/backend/api/features/store/db.py @@ -112,6 +112,7 @@ async def get_store_agents( description=agent["description"], runs=agent["runs"], rating=agent["rating"], + agent_graph_id=agent.get("agentGraphId", ""), ) store_agents.append(store_agent) except Exception as e: @@ -170,6 +171,7 @@ async def get_store_agents( description=agent.description, runs=agent.runs, rating=agent.rating, + agent_graph_id=agent.agentGraphId, ) # Add to the list only if creation was successful store_agents.append(store_agent) diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py index bae5b97cd6..86af457f50 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py @@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination( cleanup_embeddings: list, ): """Test unified search pagination works correctly.""" + # Use a unique search term to avoid matching other test data + unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}" + # Create multiple items content_ids = [] for i in range(5): @@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination( content_type=ContentType.BLOCK, content_id=content_id, embedding=mock_embedding, - searchable_text=f"pagination test item number {i}", + searchable_text=f"{unique_term} item number {i}", metadata={"index": i}, user_id=None, ) # Get first page page1_results, total1 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=1, page_size=2, @@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination( # Get second page page2_results, total2 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=2, page_size=2, diff --git a/autogpt_platform/backend/backend/api/features/store/hybrid_search.py b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py index 8b0884bb24..e1b8f402c8 100644 --- a/autogpt_platform/backend/backend/api/features/store/hybrid_search.py +++ b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py @@ -600,6 +600,7 @@ async def hybrid_search( sa.featured, sa.is_available, sa.updated_at, + sa."agentGraphId", -- Searchable text for BM25 reranking COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text, -- Semantic score @@ -659,6 +660,7 @@ async def hybrid_search( featured, is_available, updated_at, + "agentGraphId", searchable_text, semantic_score, lexical_score, diff --git a/autogpt_platform/backend/backend/api/features/store/model.py b/autogpt_platform/backend/backend/api/features/store/model.py index a3310b96fc..d66b91807d 100644 --- a/autogpt_platform/backend/backend/api/features/store/model.py +++ b/autogpt_platform/backend/backend/api/features/store/model.py @@ -38,6 +38,7 @@ class StoreAgent(pydantic.BaseModel): description: str runs: int rating: float + agent_graph_id: str class StoreAgentsResponse(pydantic.BaseModel): diff --git a/autogpt_platform/backend/backend/api/features/store/model_test.py b/autogpt_platform/backend/backend/api/features/store/model_test.py index fd09a0cf77..c4109f4603 100644 --- a/autogpt_platform/backend/backend/api/features/store/model_test.py +++ b/autogpt_platform/backend/backend/api/features/store/model_test.py @@ -26,11 +26,13 @@ def test_store_agent(): description="Test description", runs=50, rating=4.5, + agent_graph_id="test-graph-id", ) assert agent.slug == "test-agent" assert agent.agent_name == "Test Agent" assert agent.runs == 50 assert agent.rating == 4.5 + assert agent.agent_graph_id == "test-graph-id" def test_store_agents_response(): @@ -46,6 +48,7 @@ def test_store_agents_response(): description="Test description", runs=50, rating=4.5, + agent_graph_id="test-graph-id", ) ], pagination=store_model.Pagination( diff --git a/autogpt_platform/backend/backend/api/features/store/routes_test.py b/autogpt_platform/backend/backend/api/features/store/routes_test.py index 36431c20ec..fcef3f845a 100644 --- a/autogpt_platform/backend/backend/api/features/store/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/store/routes_test.py @@ -82,6 +82,7 @@ def test_get_agents_featured( description="Featured agent description", runs=100, rating=4.5, + agent_graph_id="test-graph-1", ) ], pagination=store_model.Pagination( @@ -127,6 +128,7 @@ def test_get_agents_by_creator( description="Creator agent description", runs=50, rating=4.0, + agent_graph_id="test-graph-2", ) ], pagination=store_model.Pagination( @@ -172,6 +174,7 @@ def test_get_agents_sorted( description="Top agent description", runs=1000, rating=5.0, + agent_graph_id="test-graph-3", ) ], pagination=store_model.Pagination( @@ -217,6 +220,7 @@ def test_get_agents_search( description="Specific search term description", runs=75, rating=4.2, + agent_graph_id="test-graph-search", ) ], pagination=store_model.Pagination( @@ -262,6 +266,7 @@ def test_get_agents_category( description="Category agent description", runs=60, rating=4.1, + agent_graph_id="test-graph-category", ) ], pagination=store_model.Pagination( @@ -306,6 +311,7 @@ def test_get_agents_pagination( description=f"Agent {i} description", runs=i * 10, rating=4.0, + agent_graph_id="test-graph-2", ) for i in range(5) ], diff --git a/autogpt_platform/backend/backend/api/features/store/test_cache_delete.py b/autogpt_platform/backend/backend/api/features/store/test_cache_delete.py index dd9be1f4ab..298c51d47c 100644 --- a/autogpt_platform/backend/backend/api/features/store/test_cache_delete.py +++ b/autogpt_platform/backend/backend/api/features/store/test_cache_delete.py @@ -33,6 +33,7 @@ class TestCacheDeletion: description="Test description", runs=100, rating=4.5, + agent_graph_id="test-graph-id", ) ], pagination=Pagination( diff --git a/autogpt_platform/backend/backend/api/features/workspace/__init__.py b/autogpt_platform/backend/backend/api/features/workspace/__init__.py new file mode 100644 index 0000000000..688ada9937 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/workspace/__init__.py @@ -0,0 +1 @@ +# Workspace API feature module diff --git a/autogpt_platform/backend/backend/api/features/workspace/routes.py b/autogpt_platform/backend/backend/api/features/workspace/routes.py new file mode 100644 index 0000000000..b6d0c84572 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/workspace/routes.py @@ -0,0 +1,122 @@ +""" +Workspace API routes for managing user file storage. +""" + +import logging +import re +from typing import Annotated +from urllib.parse import quote + +import fastapi +from autogpt_libs.auth.dependencies import get_user_id, requires_user +from fastapi.responses import Response + +from backend.data.workspace import get_workspace, get_workspace_file +from backend.util.workspace_storage import get_workspace_storage + + +def _sanitize_filename_for_header(filename: str) -> str: + """ + Sanitize filename for Content-Disposition header to prevent header injection. + + Removes/replaces characters that could break the header or inject new headers. + Uses RFC5987 encoding for non-ASCII characters. + """ + # Remove CR, LF, and null bytes (header injection prevention) + sanitized = re.sub(r"[\r\n\x00]", "", filename) + # Escape quotes + sanitized = sanitized.replace('"', '\\"') + # For non-ASCII, use RFC5987 filename* parameter + # Check if filename has non-ASCII characters + try: + sanitized.encode("ascii") + return f'attachment; filename="{sanitized}"' + except UnicodeEncodeError: + # Use RFC5987 encoding for UTF-8 filenames + encoded = quote(sanitized, safe="") + return f"attachment; filename*=UTF-8''{encoded}" + + +logger = logging.getLogger(__name__) + +router = fastapi.APIRouter( + dependencies=[fastapi.Security(requires_user)], +) + + +def _create_streaming_response(content: bytes, file) -> Response: + """Create a streaming response for file content.""" + return Response( + content=content, + media_type=file.mimeType, + headers={ + "Content-Disposition": _sanitize_filename_for_header(file.name), + "Content-Length": str(len(content)), + }, + ) + + +async def _create_file_download_response(file) -> Response: + """ + Create a download response for a workspace file. + + Handles both local storage (direct streaming) and GCS (signed URL redirect + with fallback to streaming). + """ + storage = await get_workspace_storage() + + # For local storage, stream the file directly + if file.storagePath.startswith("local://"): + content = await storage.retrieve(file.storagePath) + return _create_streaming_response(content, file) + + # For GCS, try to redirect to signed URL, fall back to streaming + try: + url = await storage.get_download_url(file.storagePath, expires_in=300) + # If we got back an API path (fallback), stream directly instead + if url.startswith("/api/"): + content = await storage.retrieve(file.storagePath) + return _create_streaming_response(content, file) + return fastapi.responses.RedirectResponse(url=url, status_code=302) + except Exception as e: + # Log the signed URL failure with context + logger.error( + f"Failed to get signed URL for file {file.id} " + f"(storagePath={file.storagePath}): {e}", + exc_info=True, + ) + # Fall back to streaming directly from GCS + try: + content = await storage.retrieve(file.storagePath) + return _create_streaming_response(content, file) + except Exception as fallback_error: + logger.error( + f"Fallback streaming also failed for file {file.id} " + f"(storagePath={file.storagePath}): {fallback_error}", + exc_info=True, + ) + raise + + +@router.get( + "/files/{file_id}/download", + summary="Download file by ID", +) +async def download_file( + user_id: Annotated[str, fastapi.Security(get_user_id)], + file_id: str, +) -> Response: + """ + Download a file by its ID. + + Returns the file content directly or redirects to a signed URL for GCS. + """ + workspace = await get_workspace(user_id) + if workspace is None: + raise fastapi.HTTPException(status_code=404, detail="Workspace not found") + + file = await get_workspace_file(file_id, workspace.id) + if file is None: + raise fastapi.HTTPException(status_code=404, detail="File not found") + + return await _create_file_download_response(file) diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index e9556e992f..0eef76193e 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -32,6 +32,7 @@ import backend.api.features.postmark.postmark import backend.api.features.store.model import backend.api.features.store.routes import backend.api.features.v1 +import backend.api.features.workspace.routes as workspace_routes import backend.data.block import backend.data.db import backend.data.graph @@ -39,6 +40,10 @@ import backend.data.user import backend.integrations.webhooks.utils import backend.util.service import backend.util.settings +from backend.api.features.chat.completion_consumer import ( + start_completion_consumer, + stop_completion_consumer, +) from backend.blocks.llm import DEFAULT_LLM_MODEL from backend.data.model import Credentials from backend.integrations.providers import ProviderName @@ -52,6 +57,7 @@ from backend.util.exceptions import ( ) from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly from backend.util.service import UnhealthyServiceError +from backend.util.workspace_storage import shutdown_workspace_storage from .external.fastapi_app import external_api from .features.analytics import router as analytics_router @@ -116,14 +122,31 @@ async def lifespan_context(app: fastapi.FastAPI): await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL) await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs() + # Start chat completion consumer for Redis Streams notifications + try: + await start_completion_consumer() + except Exception as e: + logger.warning(f"Could not start chat completion consumer: {e}") + with launch_darkly_context(): yield + # Stop chat completion consumer + try: + await stop_completion_consumer() + except Exception as e: + logger.warning(f"Error stopping chat completion consumer: {e}") + try: await shutdown_cloud_storage_handler() except Exception as e: logger.warning(f"Error shutting down cloud storage handler: {e}") + try: + await shutdown_workspace_storage() + except Exception as e: + logger.warning(f"Error shutting down workspace storage: {e}") + await backend.data.db.disconnect() @@ -315,6 +338,11 @@ app.include_router( tags=["v2", "chat"], prefix="/api/chat", ) +app.include_router( + workspace_routes.router, + tags=["workspace"], + prefix="/api/workspace", +) app.include_router( backend.api.features.oauth.router, tags=["oauth"], diff --git a/autogpt_platform/backend/backend/api/ws_api.py b/autogpt_platform/backend/backend/api/ws_api.py index b71fdb3526..e254d4b4db 100644 --- a/autogpt_platform/backend/backend/api/ws_api.py +++ b/autogpt_platform/backend/backend/api/ws_api.py @@ -66,18 +66,24 @@ async def event_broadcaster(manager: ConnectionManager): execution_bus = AsyncRedisExecutionEventBus() notification_bus = AsyncRedisNotificationEventBus() - async def execution_worker(): - async for event in execution_bus.listen("*"): - await manager.send_execution_update(event) + try: - async def notification_worker(): - async for notification in notification_bus.listen("*"): - await manager.send_notification( - user_id=notification.user_id, - payload=notification.payload, - ) + async def execution_worker(): + async for event in execution_bus.listen("*"): + await manager.send_execution_update(event) - await asyncio.gather(execution_worker(), notification_worker()) + async def notification_worker(): + async for notification in notification_bus.listen("*"): + await manager.send_notification( + user_id=notification.user_id, + payload=notification.payload, + ) + + await asyncio.gather(execution_worker(), notification_worker()) + finally: + # Ensure PubSub connections are closed on any exit to prevent leaks + await execution_bus.close() + await notification_bus.close() async def authenticate_websocket(websocket: WebSocket) -> str: diff --git a/autogpt_platform/backend/backend/blocks/ai_image_customizer.py b/autogpt_platform/backend/backend/blocks/ai_image_customizer.py index 83178e924d..91be33a60e 100644 --- a/autogpt_platform/backend/backend/blocks/ai_image_customizer.py +++ b/autogpt_platform/backend/backend/blocks/ai_image_customizer.py @@ -13,6 +13,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block): "credentials": TEST_CREDENTIALS_INPUT, }, test_output=[ - ("image_url", "https://replicate.delivery/generated-image.jpg"), + # Output will be a workspace ref or data URI depending on context + ("image_url", lambda x: x.startswith(("workspace://", "data:"))), ], test_mock={ + # Use data URI to avoid HTTP requests during tests "run_model": lambda *args, **kwargs: MediaFileType( - "https://replicate.delivery/generated-image.jpg" + "" ), }, test_credentials=TEST_CREDENTIALS, @@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block): input_data: Input, *, credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: try: @@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block): processed_images = await asyncio.gather( *( store_media_file( - graph_exec_id=graph_exec_id, file=img, - user_id=user_id, - return_content=True, + execution_context=execution_context, + return_format="for_external_api", # Get content for Replicate API ) for img in input_data.images ) @@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block): aspect_ratio=input_data.aspect_ratio.value, output_format=input_data.output_format.value, ) - yield "image_url", result + + # Store the generated image to the user's workspace for persistence + stored_url = await store_media_file( + file=result, + execution_context=execution_context, + return_format="for_block_output", + ) + yield "image_url", stored_url except Exception as e: yield "error", str(e) diff --git a/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py b/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py index 8c7b6e6102..e40731cd97 100644 --- a/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py +++ b/autogpt_platform/backend/backend/blocks/ai_image_generator_block.py @@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient from replicate.helpers import FileOutput from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -13,6 +14,8 @@ from backend.data.model import ( SchemaField, ) from backend.integrations.providers import ProviderName +from backend.util.file import store_media_file +from backend.util.type import MediaFileType class ImageSize(str, Enum): @@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block): test_output=[ ( "image_url", - "https://replicate.delivery/generated-image.webp", + # Test output is a data URI since we now store images + lambda x: x.startswith("" }, ) @@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block): style_text = style_map.get(style, "") return f"{style_text} of" if style_text else "" - async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs): + async def run( + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, + ): try: url = await self.generate_image(input_data, credentials) if url: - yield "image_url", url + # Store the generated image to the user's workspace/execution folder + stored_url = await store_media_file( + file=MediaFileType(url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "image_url", stored_url else: yield "error", "Image generation returned an empty result." except Exception as e: diff --git a/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py b/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py index a9e96890d3..eb60843185 100644 --- a/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py +++ b/autogpt_platform/backend/backend/blocks/ai_shortform_video_block.py @@ -13,6 +13,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -21,7 +22,9 @@ from backend.data.model import ( ) from backend.integrations.providers import ProviderName from backend.util.exceptions import BlockExecutionError +from backend.util.file import store_media_file from backend.util.request import Requests +from backend.util.type import MediaFileType TEST_CREDENTIALS = APIKeyCredentials( id="01234567-89ab-cdef-0123-456789abcdef", @@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block): "voice": Voice.LILY, "video_style": VisualMediaType.STOCK_VIDEOS, }, - test_output=("video_url", "https://example.com/video.mp4"), + test_output=( + "video_url", + lambda x: x.startswith(("workspace://", "data:")), + ), test_mock={ "create_webhook": lambda *args, **kwargs: ( "test_uuid", @@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block): "create_video": lambda *args, **kwargs: {"pid": "test_pid"}, "check_video_status": lambda *args, **kwargs: { "status": "ready", - "videoUrl": "https://example.com/video.mp4", + "videoUrl": "data:video/mp4;base64,AAAA", }, - "wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4", + # Use data URI to avoid HTTP requests during tests + "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA", }, test_credentials=TEST_CREDENTIALS, ) async def run( - self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, ) -> BlockOutput: # Create a new Webhook.site URL webhook_token, webhook_url = await self.create_webhook() @@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block): ) video_url = await self.wait_for_video(credentials.api_key, pid) logger.debug(f"Video ready: {video_url}") - yield "video_url", video_url + # Store the generated video to the user's workspace for persistence + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url class AIAdMakerVideoCreatorBlock(Block): @@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block): "https://cdn.revid.ai/uploads/1747076315114-image.png", ], }, - test_output=("video_url", "https://example.com/ad.mp4"), + test_output=( + "video_url", + lambda x: x.startswith(("workspace://", "data:")), + ), test_mock={ "create_webhook": lambda *args, **kwargs: ( "test_uuid", @@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block): "create_video": lambda *args, **kwargs: {"pid": "test_pid"}, "check_video_status": lambda *args, **kwargs: { "status": "ready", - "videoUrl": "https://example.com/ad.mp4", + "videoUrl": "data:video/mp4;base64,AAAA", }, - "wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4", + "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA", }, test_credentials=TEST_CREDENTIALS, ) - async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs): + async def run( + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, + ): webhook_token, webhook_url = await self.create_webhook() payload = { @@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block): raise RuntimeError("Failed to create video: No project ID returned") video_url = await self.wait_for_video(credentials.api_key, pid) - yield "video_url", video_url + # Store the generated video to the user's workspace for persistence + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url class AIScreenshotToVideoAdBlock(Block): @@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block): "script": "Amazing numbers!", "screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png", }, - test_output=("video_url", "https://example.com/screenshot.mp4"), + test_output=( + "video_url", + lambda x: x.startswith(("workspace://", "data:")), + ), test_mock={ "create_webhook": lambda *args, **kwargs: ( "test_uuid", @@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block): "create_video": lambda *args, **kwargs: {"pid": "test_pid"}, "check_video_status": lambda *args, **kwargs: { "status": "ready", - "videoUrl": "https://example.com/screenshot.mp4", + "videoUrl": "data:video/mp4;base64,AAAA", }, - "wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4", + "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA", }, test_credentials=TEST_CREDENTIALS, ) - async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs): + async def run( + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, + ): webhook_token, webhook_url = await self.create_webhook() payload = { @@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block): raise RuntimeError("Failed to create video: No project ID returned") video_url = await self.wait_for_video(credentials.api_key, pid) - yield "video_url", video_url + # Store the generated video to the user's workspace for persistence + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url diff --git a/autogpt_platform/backend/backend/blocks/bannerbear/text_overlay.py b/autogpt_platform/backend/backend/blocks/bannerbear/text_overlay.py index 16d46c0d99..62aaf63d88 100644 --- a/autogpt_platform/backend/backend/blocks/bannerbear/text_overlay.py +++ b/autogpt_platform/backend/backend/blocks/bannerbear/text_overlay.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from pydantic import SecretStr +from backend.data.execution import ExecutionContext from backend.sdk import ( APIKeyCredentials, Block, @@ -17,6 +18,8 @@ from backend.sdk import ( Requests, SchemaField, ) +from backend.util.file import store_media_file +from backend.util.type import MediaFileType from ._config import bannerbear @@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block): }, test_output=[ ("success", True), - ("image_url", "https://cdn.bannerbear.com/test-image.jpg"), + # Output will be a workspace ref or data URI depending on context + ("image_url", lambda x: x.startswith(("workspace://", "data:"))), ("uid", "test-uid-123"), ("status", "completed"), ], test_mock={ + # Use data URI to avoid HTTP requests during tests "_make_api_request": lambda *args, **kwargs: { "uid": "test-uid-123", "status": "completed", - "image_url": "https://cdn.bannerbear.com/test-image.jpg", + "image_url": "", } }, test_credentials=TEST_CREDENTIALS, @@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block): raise Exception(error_msg) async def run( - self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, ) -> BlockOutput: # Build the modifications array modifications = [] @@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block): # Synchronous request - image should be ready yield "success", True - yield "image_url", data.get("image_url", "") + + # Store the generated image to workspace for persistence + image_url = data.get("image_url", "") + if image_url: + stored_url = await store_media_file( + file=MediaFileType(image_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "image_url", stored_url + else: + yield "image_url", "" + yield "uid", data.get("uid", "") yield "status", data.get("status", "completed") diff --git a/autogpt_platform/backend/backend/blocks/basic.py b/autogpt_platform/backend/backend/blocks/basic.py index a9c77e2b93..95193b3feb 100644 --- a/autogpt_platform/backend/backend/blocks/basic.py +++ b/autogpt_platform/backend/backend/blocks/basic.py @@ -9,6 +9,7 @@ from backend.data.block import ( BlockSchemaOutput, BlockType, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.file import store_media_file from backend.util.type import MediaFileType, convert @@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert class FileStoreBlock(Block): class Input(BlockSchemaInput): file_in: MediaFileType = SchemaField( - description="The file to store in the temporary directory, it can be a URL, data URI, or local path." + description="The file to download and store. Can be a URL (https://...), data URI, or local path." ) base_64: bool = SchemaField( - description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).", + description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).", default=False, advanced=True, title="Produce Base64 Output", @@ -28,13 +29,18 @@ class FileStoreBlock(Block): class Output(BlockSchemaOutput): file_out: MediaFileType = SchemaField( - description="The relative path to the stored file in the temporary directory." + description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks." ) def __init__(self): super().__init__( id="cbb50872-625b-42f0-8203-a2ae78242d8a", - description="Stores the input file in the temporary directory.", + description=( + "Downloads and stores a file from a URL, data URI, or local path. " + "Use this to fetch images, documents, or other files for processing. " + "In CoPilot: saves to workspace (use list_workspace_files to see it). " + "In graphs: outputs a data URI to pass to other blocks." + ), categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA}, input_schema=FileStoreBlock.Input, output_schema=FileStoreBlock.Output, @@ -45,15 +51,18 @@ class FileStoreBlock(Block): self, input_data: Input, *, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: + # Determine return format based on user preference + # for_external_api: always returns data URI (base64) - honors "Produce Base64 Output" + # for_block_output: smart format - workspace:// in CoPilot, data URI in graphs + return_format = "for_external_api" if input_data.base_64 else "for_block_output" + yield "file_out", await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.file_in, - user_id=user_id, - return_content=input_data.base_64, + execution_context=execution_context, + return_format=return_format, ) diff --git a/autogpt_platform/backend/backend/blocks/discord/bot_blocks.py b/autogpt_platform/backend/backend/blocks/discord/bot_blocks.py index 5ecd730f47..4438af1955 100644 --- a/autogpt_platform/backend/backend/blocks/discord/bot_blocks.py +++ b/autogpt_platform/backend/backend/blocks/discord/bot_blocks.py @@ -15,6 +15,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import APIKeyCredentials, SchemaField from backend.util.file import store_media_file from backend.util.request import Requests @@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block): file: MediaFileType, filename: str, message_content: str, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, ) -> dict: intents = discord.Intents.default() intents.guilds = True @@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block): # Local file path - read from stored media file # This would be a path from a previous block's output stored_file = await store_media_file( - graph_exec_id=graph_exec_id, file=file, - user_id=user_id, - return_content=True, # Get as data URI + execution_context=execution_context, + return_format="for_external_api", # Get content to send to Discord ) # Now process as data URI header, encoded = stored_file.split(",", 1) @@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block): input_data: Input, *, credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: try: @@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block): file=input_data.file, filename=input_data.filename, message_content=input_data.message_content, - graph_exec_id=graph_exec_id, - user_id=user_id, + execution_context=execution_context, ) yield "status", result.get("status", "Unknown error") diff --git a/autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py b/autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py index 2a71548dcc..c2079ef159 100644 --- a/autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py +++ b/autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py @@ -17,8 +17,11 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField +from backend.util.file import store_media_file from backend.util.request import ClientResponseError, Requests +from backend.util.type import MediaFileType logger = logging.getLogger(__name__) @@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block): "credentials": TEST_CREDENTIALS_INPUT, }, test_credentials=TEST_CREDENTIALS, - test_output=[("video_url", "https://fal.media/files/example/video.mp4")], + test_output=[ + # Output will be a workspace ref or data URI depending on context + ("video_url", lambda x: x.startswith(("workspace://", "data:"))), + ], test_mock={ - "generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4" + # Use data URI to avoid HTTP requests during tests + "generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA" }, ) @@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block): raise RuntimeError(f"API request failed: {str(e)}") async def run( - self, input_data: Input, *, credentials: FalCredentials, **kwargs + self, + input_data: Input, + *, + credentials: FalCredentials, + execution_context: ExecutionContext, + **kwargs, ) -> BlockOutput: try: video_url = await self.generate_video(input_data, credentials) - yield "video_url", video_url + # Store the generated video to the user's workspace for persistence + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url except Exception as e: error_message = str(e) yield "error", error_message diff --git a/autogpt_platform/backend/backend/blocks/flux_kontext.py b/autogpt_platform/backend/backend/blocks/flux_kontext.py index dd8375c4ce..d56baa6d92 100644 --- a/autogpt_platform/backend/backend/blocks/flux_kontext.py +++ b/autogpt_platform/backend/backend/blocks/flux_kontext.py @@ -12,6 +12,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -121,10 +122,12 @@ class AIImageEditorBlock(Block): "credentials": TEST_CREDENTIALS_INPUT, }, test_output=[ - ("output_image", "https://replicate.com/output/edited-image.png"), + # Output will be a workspace ref or data URI depending on context + ("output_image", lambda x: x.startswith(("workspace://", "data:"))), ], test_mock={ - "run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png", + # Use data URI to avoid HTTP requests during tests + "run_model": lambda *args, **kwargs: "", }, test_credentials=TEST_CREDENTIALS, ) @@ -134,8 +137,7 @@ class AIImageEditorBlock(Block): input_data: Input, *, credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: result = await self.run_model( @@ -144,20 +146,25 @@ class AIImageEditorBlock(Block): prompt=input_data.prompt, input_image_b64=( await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.input_image, - user_id=user_id, - return_content=True, + execution_context=execution_context, + return_format="for_external_api", # Get content for Replicate API ) if input_data.input_image else None ), aspect_ratio=input_data.aspect_ratio.value, seed=input_data.seed, - user_id=user_id, - graph_exec_id=graph_exec_id, + user_id=execution_context.user_id or "", + graph_exec_id=execution_context.graph_exec_id or "", ) - yield "output_image", result + # Store the generated image to the user's workspace for persistence + stored_url = await store_media_file( + file=result, + execution_context=execution_context, + return_format="for_block_output", + ) + yield "output_image", stored_url async def run_model( self, diff --git a/autogpt_platform/backend/backend/blocks/google/gmail.py b/autogpt_platform/backend/backend/blocks/google/gmail.py index d1b3ecd4bf..2040cabe3f 100644 --- a/autogpt_platform/backend/backend/blocks/google/gmail.py +++ b/autogpt_platform/backend/backend/blocks/google/gmail.py @@ -21,6 +21,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.file import MediaFileType, get_exec_file_path, store_media_file from backend.util.settings import Settings @@ -95,8 +96,7 @@ def _make_mime_text( async def create_mime_message( input_data, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, ) -> str: """Create a MIME message with attachments and return base64-encoded raw message.""" @@ -117,12 +117,12 @@ async def create_mime_message( if input_data.attachments: for attach in input_data.attachments: local_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=attach, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) - abs_path = get_exec_file_path(graph_exec_id, local_path) + assert execution_context.graph_exec_id # Validated by store_media_file + abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path) part = MIMEBase("application", "octet-stream") with open(abs_path, "rb") as f: part.set_payload(f.read()) @@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) result = await self._send_email( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "result", result async def _send_email( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: if not input_data.to or not input_data.subject or not input_data.body: raise ValueError( "At least one recipient, subject, and body are required for sending an email" ) - raw_message = await create_mime_message(input_data, graph_exec_id, user_id) + raw_message = await create_mime_message(input_data, execution_context) sent_message = await asyncio.to_thread( lambda: service.users() .messages() @@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) result = await self._create_draft( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "result", GmailDraftResult( id=result["id"], message_id=result["message"]["id"], status="draft_created" ) async def _create_draft( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: if not input_data.to or not input_data.subject: raise ValueError( "At least one recipient and subject are required for creating a draft" ) - raw_message = await create_mime_message(input_data, graph_exec_id, user_id) + raw_message = await create_mime_message(input_data, execution_context) draft = await asyncio.to_thread( lambda: service.users() .drafts() @@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase): async def _build_reply_message( - service, input_data, graph_exec_id: str, user_id: str + service, input_data, execution_context: ExecutionContext ) -> tuple[str, str]: """ Builds a reply MIME message for Gmail threads. @@ -1190,12 +1186,12 @@ async def _build_reply_message( # Handle attachments for attach in input_data.attachments: local_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=attach, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) - abs_path = get_exec_file_path(graph_exec_id, local_path) + assert execution_context.graph_exec_id # Validated by store_media_file + abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path) part = MIMEBase("application", "octet-stream") with open(abs_path, "rb") as f: part.set_payload(f.read()) @@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) message = await self._reply( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "messageId", message["id"] yield "threadId", message.get("threadId", input_data.threadId) @@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase): yield "email", email async def _reply( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: # Build the reply message using the shared helper raw, thread_id = await _build_reply_message( - service, input_data, graph_exec_id, user_id + service, input_data, execution_context ) # Send the message @@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) draft = await self._create_draft_reply( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "draftId", draft["id"] yield "messageId", draft["message"]["id"] @@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase): yield "status", "draft_created" async def _create_draft_reply( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: # Build the reply message using the shared helper raw, thread_id = await _build_reply_message( - service, input_data, graph_exec_id, user_id + service, input_data, execution_context ) # Create draft with proper thread association @@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase): input_data: Input, *, credentials: GoogleCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: service = self._build_service(credentials, **kwargs) result = await self._forward_message( service, input_data, - graph_exec_id, - user_id, + execution_context, ) yield "messageId", result["id"] yield "threadId", result.get("threadId", "") yield "status", "forwarded" async def _forward_message( - self, service, input_data: Input, graph_exec_id: str, user_id: str + self, service, input_data: Input, execution_context: ExecutionContext ) -> dict: if not input_data.to: raise ValueError("At least one recipient is required for forwarding") @@ -1727,12 +1717,12 @@ To: {original_to} # Add any additional attachments for attach in input_data.additionalAttachments: local_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=attach, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) - abs_path = get_exec_file_path(graph_exec_id, local_path) + assert execution_context.graph_exec_id # Validated by store_media_file + abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path) part = MIMEBase("application", "octet-stream") with open(abs_path, "rb") as f: part.set_payload(f.read()) diff --git a/autogpt_platform/backend/backend/blocks/http.py b/autogpt_platform/backend/backend/blocks/http.py index 9b27a3b129..77e7fe243f 100644 --- a/autogpt_platform/backend/backend/blocks/http.py +++ b/autogpt_platform/backend/backend/blocks/http.py @@ -15,6 +15,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( CredentialsField, CredentialsMetaInput, @@ -116,10 +117,9 @@ class SendWebRequestBlock(Block): @staticmethod async def _prepare_files( - graph_exec_id: str, + execution_context: ExecutionContext, files_name: str, files: list[MediaFileType], - user_id: str, ) -> list[tuple[str, tuple[str, BytesIO, str]]]: """ Prepare files for the request by storing them and reading their content. @@ -127,11 +127,16 @@ class SendWebRequestBlock(Block): (files_name, (filename, BytesIO, mime_type)) """ files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = [] + graph_exec_id = execution_context.graph_exec_id + if graph_exec_id is None: + raise ValueError("graph_exec_id is required for file operations") for media in files: # Normalise to a list so we can repeat the same key rel_path = await store_media_file( - graph_exec_id, media, user_id, return_content=False + file=media, + execution_context=execution_context, + return_format="for_local_processing", ) abs_path = get_exec_file_path(graph_exec_id, rel_path) async with aiofiles.open(abs_path, "rb") as f: @@ -143,7 +148,7 @@ class SendWebRequestBlock(Block): return files_payload async def run( - self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs + self, input_data: Input, *, execution_context: ExecutionContext, **kwargs ) -> BlockOutput: # ─── Parse/normalise body ──────────────────────────────────── body = input_data.body @@ -174,7 +179,7 @@ class SendWebRequestBlock(Block): files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = [] if use_files: files_payload = await self._prepare_files( - graph_exec_id, input_data.files_name, input_data.files, user_id + execution_context, input_data.files_name, input_data.files ) # Enforce body format rules @@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock): self, input_data: Input, *, - graph_exec_id: str, + execution_context: ExecutionContext, credentials: HostScopedCredentials, - user_id: str, **kwargs, ) -> BlockOutput: # Create SendWebRequestBlock.Input from our input (removing credentials field) @@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock): # Use parent class run method async for output_name, output_data in super().run( - base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs + base_input, execution_context=execution_context, **kwargs ): yield output_name, output_data diff --git a/autogpt_platform/backend/backend/blocks/io.py b/autogpt_platform/backend/backend/blocks/io.py index 6f8e62e339..a9c3859490 100644 --- a/autogpt_platform/backend/backend/blocks/io.py +++ b/autogpt_platform/backend/backend/blocks/io.py @@ -12,6 +12,7 @@ from backend.data.block import ( BlockSchemaInput, BlockType, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.file import store_media_file from backend.util.mock import MockObject @@ -462,18 +463,21 @@ class AgentFileInputBlock(AgentInputBlock): self, input_data: Input, *, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: if not input_data.value: return + # Determine return format based on user preference + # for_external_api: always returns data URI (base64) - honors "Produce Base64 Output" + # for_block_output: smart format - workspace:// in CoPilot, data URI in graphs + return_format = "for_external_api" if input_data.base_64 else "for_block_output" + yield "result", await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.value, - user_id=user_id, - return_content=input_data.base_64, + execution_context=execution_context, + return_format=return_format, ) diff --git a/autogpt_platform/backend/backend/blocks/linear/_api.py b/autogpt_platform/backend/backend/blocks/linear/_api.py index 477b8a209c..ea609d515a 100644 --- a/autogpt_platform/backend/backend/blocks/linear/_api.py +++ b/autogpt_platform/backend/backend/blocks/linear/_api.py @@ -162,8 +162,16 @@ class LinearClient: "searchTerm": team_name, } - team_id = await self.query(query, variables) - return team_id["teams"]["nodes"][0]["id"] + result = await self.query(query, variables) + nodes = result["teams"]["nodes"] + + if not nodes: + raise LinearAPIException( + f"Team '{team_name}' not found. Check the team name or key and try again.", + status_code=404, + ) + + return nodes[0]["id"] except LinearAPIException as e: raise e @@ -240,17 +248,44 @@ class LinearClient: except LinearAPIException as e: raise e - async def try_search_issues(self, term: str) -> list[Issue]: + async def try_search_issues( + self, + term: str, + max_results: int = 10, + team_id: str | None = None, + ) -> list[Issue]: try: query = """ - query SearchIssues($term: String!, $includeComments: Boolean!) { - searchIssues(term: $term, includeComments: $includeComments) { + query SearchIssues( + $term: String!, + $first: Int, + $teamId: String + ) { + searchIssues( + term: $term, + first: $first, + teamId: $teamId + ) { nodes { id identifier title description priority + createdAt + state { + id + name + type + } + project { + id + name + } + assignee { + id + name + } } } } @@ -258,7 +293,8 @@ class LinearClient: variables: dict[str, Any] = { "term": term, - "includeComments": True, + "first": max_results, + "teamId": team_id, } issues = await self.query(query, variables) diff --git a/autogpt_platform/backend/backend/blocks/linear/issues.py b/autogpt_platform/backend/backend/blocks/linear/issues.py index baac01214c..165178f8ee 100644 --- a/autogpt_platform/backend/backend/blocks/linear/issues.py +++ b/autogpt_platform/backend/backend/blocks/linear/issues.py @@ -17,7 +17,7 @@ from ._config import ( LinearScope, linear, ) -from .models import CreateIssueResponse, Issue +from .models import CreateIssueResponse, Issue, State class LinearCreateIssueBlock(Block): @@ -135,9 +135,20 @@ class LinearSearchIssuesBlock(Block): description="Linear credentials with read permissions", required_scopes={LinearScope.READ}, ) + max_results: int = SchemaField( + description="Maximum number of results to return", + default=10, + ge=1, + le=100, + ) + team_name: str | None = SchemaField( + description="Optional team name to filter results (e.g., 'Internal', 'Open Source')", + default=None, + ) class Output(BlockSchemaOutput): issues: list[Issue] = SchemaField(description="List of issues") + error: str = SchemaField(description="Error message if the search failed") def __init__(self): super().__init__( @@ -145,8 +156,11 @@ class LinearSearchIssuesBlock(Block): description="Searches for issues on Linear", input_schema=self.Input, output_schema=self.Output, + categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING}, test_input={ "term": "Test issue", + "max_results": 10, + "team_name": None, "credentials": TEST_CREDENTIALS_INPUT_OAUTH, }, test_credentials=TEST_CREDENTIALS_OAUTH, @@ -156,10 +170,14 @@ class LinearSearchIssuesBlock(Block): [ Issue( id="abc123", - identifier="abc123", + identifier="TST-123", title="Test issue", description="Test description", priority=1, + state=State( + id="state1", name="In Progress", type="started" + ), + createdAt="2026-01-15T10:00:00.000Z", ) ], ) @@ -168,10 +186,12 @@ class LinearSearchIssuesBlock(Block): "search_issues": lambda *args, **kwargs: [ Issue( id="abc123", - identifier="abc123", + identifier="TST-123", title="Test issue", description="Test description", priority=1, + state=State(id="state1", name="In Progress", type="started"), + createdAt="2026-01-15T10:00:00.000Z", ) ] }, @@ -181,10 +201,22 @@ class LinearSearchIssuesBlock(Block): async def search_issues( credentials: OAuth2Credentials | APIKeyCredentials, term: str, + max_results: int = 10, + team_name: str | None = None, ) -> list[Issue]: client = LinearClient(credentials=credentials) - response: list[Issue] = await client.try_search_issues(term=term) - return response + + # Resolve team name to ID if provided + # Raises LinearAPIException with descriptive message if team not found + team_id: str | None = None + if team_name: + team_id = await client.try_get_team_by_name(team_name=team_name) + + return await client.try_search_issues( + term=term, + max_results=max_results, + team_id=team_id, + ) async def run( self, @@ -196,7 +228,10 @@ class LinearSearchIssuesBlock(Block): """Execute the issue search""" try: issues = await self.search_issues( - credentials=credentials, term=input_data.term + credentials=credentials, + term=input_data.term, + max_results=input_data.max_results, + team_name=input_data.team_name, ) yield "issues", issues except LinearAPIException as e: diff --git a/autogpt_platform/backend/backend/blocks/linear/models.py b/autogpt_platform/backend/backend/blocks/linear/models.py index bfeaa13656..dd1f603459 100644 --- a/autogpt_platform/backend/backend/blocks/linear/models.py +++ b/autogpt_platform/backend/backend/blocks/linear/models.py @@ -36,12 +36,21 @@ class Project(BaseModel): content: str | None = None +class State(BaseModel): + id: str + name: str + type: str | None = ( + None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled") + ) + + class Issue(BaseModel): id: str identifier: str title: str description: str | None priority: int + state: State | None = None project: Project | None = None createdAt: str | None = None comments: list[Comment] | None = None diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index fdcd7f3568..54295da1f1 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -32,7 +32,7 @@ from backend.data.model import ( from backend.integrations.providers import ProviderName from backend.util import json from backend.util.logging import TruncatedLogger -from backend.util.prompt import compress_prompt, estimate_token_count +from backend.util.prompt import compress_context, estimate_token_count from backend.util.text import TextFormatter logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]") @@ -115,7 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta): CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101" CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001" - CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219" CLAUDE_3_HAIKU = "claude-3-haiku-20240307" # AI/ML API models AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo" @@ -280,9 +279,6 @@ MODEL_METADATA = { LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata( "anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2 ), # claude-haiku-4-5-20251001 - LlmModel.CLAUDE_3_7_SONNET: ModelMetadata( - "anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2 - ), # claude-3-7-sonnet-20250219 LlmModel.CLAUDE_3_HAIKU: ModelMetadata( "anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1 ), # claude-3-haiku-20240307 @@ -638,11 +634,18 @@ async def llm_call( context_window = llm_model.context_window if compress_prompt_to_fit: - prompt = compress_prompt( + result = await compress_context( messages=prompt, target_tokens=llm_model.context_window // 2, - lossy_ok=True, + client=None, # Truncation-only, no LLM summarization + reserve=0, # Caller handles response token budget separately ) + if result.error: + logger.warning( + f"Prompt compression did not meet target: {result.error}. " + f"Proceeding with {result.token_count} tokens." + ) + prompt = result.messages # Calculate available tokens based on context window and input length estimated_input_tokens = estimate_token_count(prompt) diff --git a/autogpt_platform/backend/backend/blocks/media.py b/autogpt_platform/backend/backend/blocks/media.py index c8d4b4768f..a8d145bc64 100644 --- a/autogpt_platform/backend/backend/blocks/media.py +++ b/autogpt_platform/backend/backend/blocks/media.py @@ -1,6 +1,6 @@ import os import tempfile -from typing import Literal, Optional +from typing import Optional from moviepy.audio.io.AudioFileClip import AudioFileClip from moviepy.video.fx.Loop import Loop @@ -13,6 +13,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util.file import MediaFileType, get_exec_file_path, store_media_file @@ -46,18 +47,19 @@ class MediaDurationBlock(Block): self, input_data: Input, *, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: # 1) Store the input media locally local_media_path = await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.media_in, - user_id=user_id, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", + ) + assert execution_context.graph_exec_id is not None + media_abspath = get_exec_file_path( + execution_context.graph_exec_id, local_media_path ) - media_abspath = get_exec_file_path(graph_exec_id, local_media_path) # 2) Load the clip if input_data.is_video: @@ -88,10 +90,6 @@ class LoopVideoBlock(Block): default=None, ge=1, ) - output_return_type: Literal["file_path", "data_uri"] = SchemaField( - description="How to return the output video. Either a relative path or base64 data URI.", - default="file_path", - ) class Output(BlockSchemaOutput): video_out: str = SchemaField( @@ -111,17 +109,19 @@ class LoopVideoBlock(Block): self, input_data: Input, *, - node_exec_id: str, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: + assert execution_context.graph_exec_id is not None + assert execution_context.node_exec_id is not None + graph_exec_id = execution_context.graph_exec_id + node_exec_id = execution_context.node_exec_id + # 1) Store the input video locally local_video_path = await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.video_in, - user_id=user_id, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) input_abspath = get_exec_file_path(graph_exec_id, local_video_path) @@ -149,12 +149,11 @@ class LoopVideoBlock(Block): looped_clip = looped_clip.with_audio(clip.audio) looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac") - # Return as data URI + # Return output - for_block_output returns workspace:// if available, else data URI video_out = await store_media_file( - graph_exec_id=graph_exec_id, file=output_filename, - user_id=user_id, - return_content=input_data.output_return_type == "data_uri", + execution_context=execution_context, + return_format="for_block_output", ) yield "video_out", video_out @@ -177,10 +176,6 @@ class AddAudioToVideoBlock(Block): description="Volume scale for the newly attached audio track (1.0 = original).", default=1.0, ) - output_return_type: Literal["file_path", "data_uri"] = SchemaField( - description="Return the final output as a relative path or base64 data URI.", - default="file_path", - ) class Output(BlockSchemaOutput): video_out: MediaFileType = SchemaField( @@ -200,23 +195,24 @@ class AddAudioToVideoBlock(Block): self, input_data: Input, *, - node_exec_id: str, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: + assert execution_context.graph_exec_id is not None + assert execution_context.node_exec_id is not None + graph_exec_id = execution_context.graph_exec_id + node_exec_id = execution_context.node_exec_id + # 1) Store the inputs locally local_video_path = await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.video_in, - user_id=user_id, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) local_audio_path = await store_media_file( - graph_exec_id=graph_exec_id, file=input_data.audio_in, - user_id=user_id, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id) @@ -240,12 +236,11 @@ class AddAudioToVideoBlock(Block): output_abspath = os.path.join(abs_temp_dir, output_filename) final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac") - # 5) Return either path or data URI + # 5) Return output - for_block_output returns workspace:// if available, else data URI video_out = await store_media_file( - graph_exec_id=graph_exec_id, file=output_filename, - user_id=user_id, - return_content=input_data.output_return_type == "data_uri", + execution_context=execution_context, + return_format="for_block_output", ) yield "video_out", video_out diff --git a/autogpt_platform/backend/backend/blocks/screenshotone.py b/autogpt_platform/backend/backend/blocks/screenshotone.py index 1f8947376b..ee998f8da2 100644 --- a/autogpt_platform/backend/backend/blocks/screenshotone.py +++ b/autogpt_platform/backend/backend/blocks/screenshotone.py @@ -11,6 +11,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block): @staticmethod async def take_screenshot( credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, url: str, viewport_width: int, viewport_height: int, @@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block): return { "image": await store_media_file( - graph_exec_id=graph_exec_id, file=MediaFileType( f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}" ), - user_id=user_id, - return_content=True, + execution_context=execution_context, + return_format="for_block_output", ) } @@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block): input_data: Input, *, credentials: APIKeyCredentials, - graph_exec_id: str, - user_id: str, + execution_context: ExecutionContext, **kwargs, ) -> BlockOutput: try: screenshot_data = await self.take_screenshot( credentials=credentials, - graph_exec_id=graph_exec_id, - user_id=user_id, + execution_context=execution_context, url=input_data.url, viewport_width=input_data.viewport_width, viewport_height=input_data.viewport_height, diff --git a/autogpt_platform/backend/backend/blocks/spreadsheet.py b/autogpt_platform/backend/backend/blocks/spreadsheet.py index 211aac23f4..a13f9e2f6d 100644 --- a/autogpt_platform/backend/backend/blocks/spreadsheet.py +++ b/autogpt_platform/backend/backend/blocks/spreadsheet.py @@ -7,6 +7,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ContributorDetails, SchemaField from backend.util.file import get_exec_file_path, store_media_file from backend.util.type import MediaFileType @@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block): ) async def run( - self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs + self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs ) -> BlockOutput: import csv from io import StringIO @@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block): # Determine data source - prefer file_input if provided, otherwise use contents if input_data.file_input: stored_file_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=input_data.file_input, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) # Get full file path - file_path = get_exec_file_path(graph_exec_id, stored_file_path) + assert execution_context.graph_exec_id # Validated by store_media_file + file_path = get_exec_file_path( + execution_context.graph_exec_id, stored_file_path + ) if not Path(file_path).exists(): raise ValueError(f"File does not exist: {file_path}") diff --git a/autogpt_platform/backend/backend/blocks/stagehand/blocks.py b/autogpt_platform/backend/backend/blocks/stagehand/blocks.py index be1d736962..91c096ffe4 100644 --- a/autogpt_platform/backend/backend/blocks/stagehand/blocks.py +++ b/autogpt_platform/backend/backend/blocks/stagehand/blocks.py @@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum): GPT41_MINI = "gpt-4.1-mini-2025-04-14" # Anthropic - CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219" + CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" @property def provider_name(self) -> str: @@ -137,7 +137,7 @@ class StagehandObserveBlock(Block): model: StagehandRecommendedLlmModel = SchemaField( title="LLM Model", description="LLM to use for Stagehand (provider is inferred)", - default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET, + default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET, advanced=False, ) model_credentials: AICredentials = AICredentialsField() @@ -182,10 +182,7 @@ class StagehandObserveBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( @@ -230,7 +227,7 @@ class StagehandActBlock(Block): model: StagehandRecommendedLlmModel = SchemaField( title="LLM Model", description="LLM to use for Stagehand (provider is inferred)", - default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET, + default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET, advanced=False, ) model_credentials: AICredentials = AICredentialsField() @@ -282,10 +279,7 @@ class StagehandActBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"ACT: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( @@ -330,7 +324,7 @@ class StagehandExtractBlock(Block): model: StagehandRecommendedLlmModel = SchemaField( title="LLM Model", description="LLM to use for Stagehand (provider is inferred)", - default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET, + default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET, advanced=False, ) model_credentials: AICredentials = AICredentialsField() @@ -370,10 +364,7 @@ class StagehandExtractBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( diff --git a/autogpt_platform/backend/backend/blocks/talking_head.py b/autogpt_platform/backend/backend/blocks/talking_head.py index 7a466bec7e..e01e3d4023 100644 --- a/autogpt_platform/backend/backend/blocks/talking_head.py +++ b/autogpt_platform/backend/backend/blocks/talking_head.py @@ -10,6 +10,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import ( APIKeyCredentials, CredentialsField, @@ -17,7 +18,9 @@ from backend.data.model import ( SchemaField, ) from backend.integrations.providers import ProviderName +from backend.util.file import store_media_file from backend.util.request import Requests +from backend.util.type import MediaFileType TEST_CREDENTIALS = APIKeyCredentials( id="01234567-89ab-cdef-0123-456789abcdef", @@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block): test_output=[ ( "video_url", - "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video", + lambda x: x.startswith(("workspace://", "data:")), ), ], test_mock={ @@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block): "id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx", "status": "created", }, + # Use data URI to avoid HTTP requests during tests "get_clip_status": lambda *args, **kwargs: { "status": "done", - "result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video", + "result_url": "data:video/mp4;base64,AAAA", }, }, test_credentials=TEST_CREDENTIALS, @@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block): return response.json() async def run( - self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs + self, + input_data: Input, + *, + credentials: APIKeyCredentials, + execution_context: ExecutionContext, + **kwargs, ) -> BlockOutput: # Create the clip payload = { @@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block): for _ in range(input_data.max_polling_attempts): status_response = await self.get_clip_status(credentials.api_key, clip_id) if status_response["status"] == "done": - yield "video_url", status_response["result_url"] + # Store the generated video to the user's workspace for persistence + video_url = status_response["result_url"] + stored_url = await store_media_file( + file=MediaFileType(video_url), + execution_context=execution_context, + return_format="for_block_output", + ) + yield "video_url", stored_url return elif status_response["status"] == "error": raise RuntimeError( diff --git a/autogpt_platform/backend/backend/blocks/test/test_blocks_dos_vulnerability.py b/autogpt_platform/backend/backend/blocks/test/test_blocks_dos_vulnerability.py index 389bb5c636..e2e44b194c 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_blocks_dos_vulnerability.py +++ b/autogpt_platform/backend/backend/blocks/test/test_blocks_dos_vulnerability.py @@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock from backend.blocks.llm import AITextSummarizerBlock from backend.blocks.text import ExtractTextInformationBlock from backend.blocks.xml_parser import XMLParserBlock +from backend.data.execution import ExecutionContext from backend.util.file import store_media_file from backend.util.type import MediaFileType @@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity: with pytest.raises(ValueError, match="File too large"): await store_media_file( - graph_exec_id="test", file=MediaFileType(large_data_uri), - user_id="test_user", + execution_context=ExecutionContext( + user_id="test_user", + graph_exec_id="test", + ), + return_format="for_local_processing", ) @patch("backend.util.file.Path") @@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity: # Should raise an error when directory size exceeds limit with pytest.raises(ValueError, match="Disk usage limit exceeded"): await store_media_file( - graph_exec_id="test", file=MediaFileType( "data:text/plain;base64,dGVzdA==" ), # Small test file - user_id="test_user", + execution_context=ExecutionContext( + user_id="test_user", + graph_exec_id="test", + ), + return_format="for_local_processing", ) diff --git a/autogpt_platform/backend/backend/blocks/test/test_http.py b/autogpt_platform/backend/backend/blocks/test/test_http.py index bdc30f3ecf..e01b8e2c5b 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_http.py +++ b/autogpt_platform/backend/backend/blocks/test/test_http.py @@ -11,10 +11,22 @@ from backend.blocks.http import ( HttpMethod, SendAuthenticatedWebRequestBlock, ) +from backend.data.execution import ExecutionContext from backend.data.model import HostScopedCredentials from backend.util.request import Response +def make_test_context( + graph_exec_id: str = "test-exec-id", + user_id: str = "test-user-id", +) -> ExecutionContext: + """Helper to create test ExecutionContext.""" + return ExecutionContext( + user_id=user_id, + graph_exec_id=graph_exec_id, + ) + + class TestHttpBlockWithHostScopedCredentials: """Test suite for HTTP block integration with HostScopedCredentials.""" @@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=exact_match_credentials, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=wildcard_credentials, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=non_matching_credentials, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=exact_match_credentials, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=auto_discovered_creds, # Execution manager found these - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=multi_header_creds, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) @@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials: async for output_name, output_data in http_block.run( input_data, credentials=test_creds, - graph_exec_id="test-exec-id", - user_id="test-user-id", + execution_context=make_test_context(), ): result.append((output_name, output_data)) diff --git a/autogpt_platform/backend/backend/blocks/text.py b/autogpt_platform/backend/backend/blocks/text.py index 5e58e27101..359e22a84f 100644 --- a/autogpt_platform/backend/backend/blocks/text.py +++ b/autogpt_platform/backend/backend/blocks/text.py @@ -11,6 +11,7 @@ from backend.data.block import ( BlockSchemaInput, BlockSchemaOutput, ) +from backend.data.execution import ExecutionContext from backend.data.model import SchemaField from backend.util import json, text from backend.util.file import get_exec_file_path, store_media_file @@ -444,18 +445,21 @@ class FileReadBlock(Block): ) async def run( - self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs + self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs ) -> BlockOutput: # Store the media file properly (handles URLs, data URIs, etc.) stored_file_path = await store_media_file( - user_id=user_id, - graph_exec_id=graph_exec_id, file=input_data.file_input, - return_content=False, + execution_context=execution_context, + return_format="for_local_processing", ) - # Get full file path - file_path = get_exec_file_path(graph_exec_id, stored_file_path) + # Get full file path (graph_exec_id validated by store_media_file above) + if not execution_context.graph_exec_id: + raise ValueError("execution_context.graph_exec_id is required") + file_path = get_exec_file_path( + execution_context.graph_exec_id, stored_file_path + ) if not Path(file_path).exists(): raise ValueError(f"File does not exist: {file_path}") diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 8d9ecfff4c..eb9360b037 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -873,14 +873,13 @@ def is_block_auth_configured( async def initialize_blocks() -> None: - # First, sync all provider costs to blocks - # Imported here to avoid circular import from backend.sdk.cost_integration import sync_all_provider_costs + from backend.util.retry import func_retry sync_all_provider_costs() - for cls in get_blocks().values(): - block = cls() + @func_retry + async def sync_block_to_db(block: Block) -> None: existing_block = await AgentBlock.prisma().find_first( where={"OR": [{"id": block.id}, {"name": block.name}]} ) @@ -893,7 +892,7 @@ async def initialize_blocks() -> None: outputSchema=json.dumps(block.output_schema.jsonschema()), ) ) - continue + return input_schema = json.dumps(block.input_schema.jsonschema()) output_schema = json.dumps(block.output_schema.jsonschema()) @@ -913,6 +912,25 @@ async def initialize_blocks() -> None: }, ) + failed_blocks: list[str] = [] + for cls in get_blocks().values(): + block = cls() + try: + await sync_block_to_db(block) + except Exception as e: + logger.warning( + f"Failed to sync block {block.name} to database: {e}. " + "Block is still available in memory.", + exc_info=True, + ) + failed_blocks.append(block.name) + + if failed_blocks: + logger.error( + f"Failed to sync {len(failed_blocks)} block(s) to database: " + f"{', '.join(failed_blocks)}. These blocks are still available in memory." + ) + # Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281 def get_block(block_id: str) -> AnyBlockSchema | None: diff --git a/autogpt_platform/backend/backend/data/block_cost_config.py b/autogpt_platform/backend/backend/data/block_cost_config.py index 1b54ae0942..f46cc726f0 100644 --- a/autogpt_platform/backend/backend/data/block_cost_config.py +++ b/autogpt_platform/backend/backend/data/block_cost_config.py @@ -81,7 +81,6 @@ MODEL_COST: dict[LlmModel, int] = { LlmModel.CLAUDE_4_5_HAIKU: 4, LlmModel.CLAUDE_4_5_OPUS: 14, LlmModel.CLAUDE_4_5_SONNET: 9, - LlmModel.CLAUDE_3_7_SONNET: 5, LlmModel.CLAUDE_3_HAIKU: 1, LlmModel.AIML_API_QWEN2_5_72B: 1, LlmModel.AIML_API_LLAMA3_1_70B: 1, diff --git a/autogpt_platform/backend/backend/data/event_bus.py b/autogpt_platform/backend/backend/data/event_bus.py index d8a1c5b729..614fb158b2 100644 --- a/autogpt_platform/backend/backend/data/event_bus.py +++ b/autogpt_platform/backend/backend/data/event_bus.py @@ -133,10 +133,23 @@ class RedisEventBus(BaseRedisEventBus[M], ABC): class AsyncRedisEventBus(BaseRedisEventBus[M], ABC): + def __init__(self): + self._pubsub: AsyncPubSub | None = None + @property async def connection(self) -> redis.AsyncRedis: return await redis.get_redis_async() + async def close(self) -> None: + """Close the PubSub connection if it exists.""" + if self._pubsub is not None: + try: + await self._pubsub.close() + except Exception: + logger.warning("Failed to close PubSub connection", exc_info=True) + finally: + self._pubsub = None + async def publish_event(self, event: M, channel_key: str): """ Publish an event to Redis. Gracefully handles connection failures @@ -157,6 +170,7 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC): await self.connection, channel_key ) assert isinstance(pubsub, AsyncPubSub) + self._pubsub = pubsub if "*" in channel_key: await pubsub.psubscribe(full_channel_name) diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index 3c1fd25c51..afb8c70538 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -83,12 +83,29 @@ class ExecutionContext(BaseModel): model_config = {"extra": "ignore"} + # Execution identity + user_id: Optional[str] = None + graph_id: Optional[str] = None + graph_exec_id: Optional[str] = None + graph_version: Optional[int] = None + node_id: Optional[str] = None + node_exec_id: Optional[str] = None + + # Safety settings human_in_the_loop_safe_mode: bool = True sensitive_action_safe_mode: bool = False + + # User settings user_timezone: str = "UTC" + + # Execution hierarchy root_execution_id: Optional[str] = None parent_execution_id: Optional[str] = None + # Workspace + workspace_id: Optional[str] = None + session_id: Optional[str] = None + # -------------------------- Models -------------------------- # diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index c1f38f81d5..ee6cd2e4b0 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -1028,6 +1028,39 @@ async def get_graph( return GraphModel.from_db(graph, for_export) +async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]: + """Batch-fetch multiple store-listed graphs by their IDs. + + Only returns graphs that have approved store listings (publicly available). + Does not require permission checks since store-listed graphs are public. + + Args: + *graph_ids: Variable number of graph IDs to fetch + + Returns: + Dict mapping graph_id to GraphModel for graphs with approved store listings + """ + if not graph_ids: + return {} + + store_listings = await StoreListingVersion.prisma().find_many( + where={ + "agentGraphId": {"in": list(graph_ids)}, + "submissionStatus": SubmissionStatus.APPROVED, + "isDeleted": False, + }, + include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}}, + distinct=["agentGraphId"], + order={"agentGraphVersion": "desc"}, + ) + + return { + listing.agentGraphId: GraphModel.from_db(listing.AgentGraph) + for listing in store_listings + if listing.AgentGraph + } + + async def get_graph_as_admin( graph_id: str, version: int | None = None, diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index 2cc73f6b7b..5a09c591c9 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -19,7 +19,6 @@ from typing import ( cast, get_args, ) -from urllib.parse import urlparse from uuid import uuid4 from prisma.enums import CreditTransactionType, OnboardingStep @@ -42,6 +41,7 @@ from typing_extensions import TypedDict from backend.integrations.providers import ProviderName from backend.util.json import loads as json_loads +from backend.util.request import parse_url from backend.util.settings import Secrets # Type alias for any provider name (including custom ones) @@ -397,19 +397,25 @@ class HostScopedCredentials(_BaseCredentials): def matches_url(self, url: str) -> bool: """Check if this credential should be applied to the given URL.""" - parsed_url = urlparse(url) - # Extract hostname without port - request_host = parsed_url.hostname + request_host, request_port = _extract_host_from_url(url) + cred_scope_host, cred_scope_port = _extract_host_from_url(self.host) if not request_host: return False - # Simple host matching - exact match or wildcard subdomain match - if self.host == request_host: + # If a port is specified in credential host, the request host port must match + if cred_scope_port is not None and request_port != cred_scope_port: + return False + # Non-standard ports are only allowed if explicitly specified in credential host + elif cred_scope_port is None and request_port not in (80, 443, None): + return False + + # Simple host matching + if cred_scope_host == request_host: return True # Support wildcard matching (e.g., "*.example.com" matches "api.example.com") - if self.host.startswith("*."): - domain = self.host[2:] # Remove "*." + if cred_scope_host.startswith("*."): + domain = cred_scope_host[2:] # Remove "*." return request_host.endswith(f".{domain}") or request_host == domain return False @@ -551,13 +557,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): ) -def _extract_host_from_url(url: str) -> str: - """Extract host from URL for grouping host-scoped credentials.""" +def _extract_host_from_url(url: str) -> tuple[str, int | None]: + """Extract host and port from URL for grouping host-scoped credentials.""" try: - parsed = urlparse(url) - return parsed.hostname or url + parsed = parse_url(url) + return parsed.hostname or url, parsed.port except Exception: - return "" + return "", None class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): @@ -606,7 +612,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): providers = frozenset( [cast(CP, "http")] + [ - cast(CP, _extract_host_from_url(str(value))) + cast(CP, parse_url(str(value)).netloc) for value in field.discriminator_values ] ) @@ -666,10 +672,16 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): if not (self.discriminator and self.discriminator_mapping): return self + try: + provider = self.discriminator_mapping[discriminator_value] + except KeyError: + raise ValueError( + f"Model '{discriminator_value}' is not supported. " + "It may have been deprecated. Please update your agent configuration." + ) + return CredentialsFieldInfo( - credentials_provider=frozenset( - [self.discriminator_mapping[discriminator_value]] - ), + credentials_provider=frozenset([provider]), credentials_types=self.supported_types, credentials_scopes=self.required_scopes, discriminator=self.discriminator, diff --git a/autogpt_platform/backend/backend/data/model_test.py b/autogpt_platform/backend/backend/data/model_test.py index 37ec6be82f..e8e2ddfa35 100644 --- a/autogpt_platform/backend/backend/data/model_test.py +++ b/autogpt_platform/backend/backend/data/model_test.py @@ -79,10 +79,23 @@ class TestHostScopedCredentials: headers={"Authorization": SecretStr("Bearer token")}, ) - assert creds.matches_url("http://localhost:8080/api/v1") + # Non-standard ports require explicit port in credential host + assert not creds.matches_url("http://localhost:8080/api/v1") assert creds.matches_url("https://localhost:443/secure/endpoint") assert creds.matches_url("http://localhost/simple") + def test_matches_url_with_explicit_port(self): + """Test URL matching with explicit port in credential host.""" + creds = HostScopedCredentials( + provider="custom", + host="localhost:8080", + headers={"Authorization": SecretStr("Bearer token")}, + ) + + assert creds.matches_url("http://localhost:8080/api/v1") + assert not creds.matches_url("http://localhost:3000/api/v1") + assert not creds.matches_url("http://localhost/simple") + def test_empty_headers_dict(self): """Test HostScopedCredentials with empty headers.""" creds = HostScopedCredentials( @@ -128,8 +141,20 @@ class TestHostScopedCredentials: ("*.example.com", "https://sub.api.example.com/test", True), ("*.example.com", "https://example.com/test", True), ("*.example.com", "https://example.org/test", False), - ("localhost", "http://localhost:3000/test", True), + # Non-standard ports require explicit port in credential host + ("localhost", "http://localhost:3000/test", False), + ("localhost:3000", "http://localhost:3000/test", True), ("localhost", "http://127.0.0.1:3000/test", False), + # IPv6 addresses (frontend stores with brackets via URL.hostname) + ("[::1]", "http://[::1]/test", True), + ("[::1]", "http://[::1]:80/test", True), + ("[::1]", "https://[::1]:443/test", True), + ("[::1]", "http://[::1]:8080/test", False), # Non-standard port + ("[::1]:8080", "http://[::1]:8080/test", True), + ("[::1]:8080", "http://[::1]:9090/test", False), + ("[2001:db8::1]", "http://[2001:db8::1]/path", True), + ("[2001:db8::1]", "https://[2001:db8::1]:443/path", True), + ("[2001:db8::1]", "http://[2001:db8::ff]/path", False), ], ) def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool): diff --git a/autogpt_platform/backend/backend/data/workspace.py b/autogpt_platform/backend/backend/data/workspace.py new file mode 100644 index 0000000000..f3dba0a294 --- /dev/null +++ b/autogpt_platform/backend/backend/data/workspace.py @@ -0,0 +1,276 @@ +""" +Database CRUD operations for User Workspace. + +This module provides functions for managing user workspaces and workspace files. +""" + +import logging +from datetime import datetime, timezone +from typing import Optional + +from prisma.models import UserWorkspace, UserWorkspaceFile +from prisma.types import UserWorkspaceFileWhereInput + +from backend.util.json import SafeJson + +logger = logging.getLogger(__name__) + + +async def get_or_create_workspace(user_id: str) -> UserWorkspace: + """ + Get user's workspace, creating one if it doesn't exist. + + Uses upsert to handle race conditions when multiple concurrent requests + attempt to create a workspace for the same user. + + Args: + user_id: The user's ID + + Returns: + UserWorkspace instance + """ + workspace = await UserWorkspace.prisma().upsert( + where={"userId": user_id}, + data={ + "create": {"userId": user_id}, + "update": {}, # No updates needed if exists + }, + ) + + return workspace + + +async def get_workspace(user_id: str) -> Optional[UserWorkspace]: + """ + Get user's workspace if it exists. + + Args: + user_id: The user's ID + + Returns: + UserWorkspace instance or None + """ + return await UserWorkspace.prisma().find_unique(where={"userId": user_id}) + + +async def create_workspace_file( + workspace_id: str, + file_id: str, + name: str, + path: str, + storage_path: str, + mime_type: str, + size_bytes: int, + checksum: Optional[str] = None, + metadata: Optional[dict] = None, +) -> UserWorkspaceFile: + """ + Create a new workspace file record. + + Args: + workspace_id: The workspace ID + file_id: The file ID (same as used in storage path for consistency) + name: User-visible filename + path: Virtual path (e.g., "/documents/report.pdf") + storage_path: Actual storage path (GCS or local) + mime_type: MIME type of the file + size_bytes: File size in bytes + checksum: Optional SHA256 checksum + metadata: Optional additional metadata + + Returns: + Created UserWorkspaceFile instance + """ + # Normalize path to start with / + if not path.startswith("/"): + path = f"/{path}" + + file = await UserWorkspaceFile.prisma().create( + data={ + "id": file_id, + "workspaceId": workspace_id, + "name": name, + "path": path, + "storagePath": storage_path, + "mimeType": mime_type, + "sizeBytes": size_bytes, + "checksum": checksum, + "metadata": SafeJson(metadata or {}), + } + ) + + logger.info( + f"Created workspace file {file.id} at path {path} " + f"in workspace {workspace_id}" + ) + return file + + +async def get_workspace_file( + file_id: str, + workspace_id: Optional[str] = None, +) -> Optional[UserWorkspaceFile]: + """ + Get a workspace file by ID. + + Args: + file_id: The file ID + workspace_id: Optional workspace ID for validation + + Returns: + UserWorkspaceFile instance or None + """ + where_clause: dict = {"id": file_id, "isDeleted": False} + if workspace_id: + where_clause["workspaceId"] = workspace_id + + return await UserWorkspaceFile.prisma().find_first(where=where_clause) + + +async def get_workspace_file_by_path( + workspace_id: str, + path: str, +) -> Optional[UserWorkspaceFile]: + """ + Get a workspace file by its virtual path. + + Args: + workspace_id: The workspace ID + path: Virtual path + + Returns: + UserWorkspaceFile instance or None + """ + # Normalize path + if not path.startswith("/"): + path = f"/{path}" + + return await UserWorkspaceFile.prisma().find_first( + where={ + "workspaceId": workspace_id, + "path": path, + "isDeleted": False, + } + ) + + +async def list_workspace_files( + workspace_id: str, + path_prefix: Optional[str] = None, + include_deleted: bool = False, + limit: Optional[int] = None, + offset: int = 0, +) -> list[UserWorkspaceFile]: + """ + List files in a workspace. + + Args: + workspace_id: The workspace ID + path_prefix: Optional path prefix to filter (e.g., "/documents/") + include_deleted: Whether to include soft-deleted files + limit: Maximum number of files to return + offset: Number of files to skip + + Returns: + List of UserWorkspaceFile instances + """ + where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id} + + if not include_deleted: + where_clause["isDeleted"] = False + + if path_prefix: + # Normalize prefix + if not path_prefix.startswith("/"): + path_prefix = f"/{path_prefix}" + where_clause["path"] = {"startswith": path_prefix} + + return await UserWorkspaceFile.prisma().find_many( + where=where_clause, + order={"createdAt": "desc"}, + take=limit, + skip=offset, + ) + + +async def count_workspace_files( + workspace_id: str, + path_prefix: Optional[str] = None, + include_deleted: bool = False, +) -> int: + """ + Count files in a workspace. + + Args: + workspace_id: The workspace ID + path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/") + include_deleted: Whether to include soft-deleted files + + Returns: + Number of files + """ + where_clause: dict = {"workspaceId": workspace_id} + if not include_deleted: + where_clause["isDeleted"] = False + + if path_prefix: + # Normalize prefix + if not path_prefix.startswith("/"): + path_prefix = f"/{path_prefix}" + where_clause["path"] = {"startswith": path_prefix} + + return await UserWorkspaceFile.prisma().count(where=where_clause) + + +async def soft_delete_workspace_file( + file_id: str, + workspace_id: Optional[str] = None, +) -> Optional[UserWorkspaceFile]: + """ + Soft-delete a workspace file. + + The path is modified to include a deletion timestamp to free up the original + path for new files while preserving the record for potential recovery. + + Args: + file_id: The file ID + workspace_id: Optional workspace ID for validation + + Returns: + Updated UserWorkspaceFile instance or None if not found + """ + # First verify the file exists and belongs to workspace + file = await get_workspace_file(file_id, workspace_id) + if file is None: + return None + + deleted_at = datetime.now(timezone.utc) + # Modify path to free up the unique constraint for new files at original path + # Format: {original_path}__deleted__{timestamp} + deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}" + + updated = await UserWorkspaceFile.prisma().update( + where={"id": file_id}, + data={ + "isDeleted": True, + "deletedAt": deleted_at, + "path": deleted_path, + }, + ) + + logger.info(f"Soft-deleted workspace file {file_id}") + return updated + + +async def get_workspace_total_size(workspace_id: str) -> int: + """ + Get the total size of all files in a workspace. + + Args: + workspace_id: The workspace ID + + Returns: + Total size in bytes + """ + files = await list_workspace_files(workspace_id) + return sum(file.sizeBytes for file in files) diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/executor/database.py index ae7474fc1d..d44439d51c 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/executor/database.py @@ -17,6 +17,7 @@ from backend.data.analytics import ( get_accuracy_trends_and_alerts, get_marketplace_graphs_for_monitoring, ) +from backend.data.auth.oauth import cleanup_expired_oauth_tokens from backend.data.credit import UsageTransactionMetadata, get_user_credit_model from backend.data.execution import ( create_graph_execution, @@ -219,6 +220,9 @@ class DatabaseManager(AppService): # Onboarding increment_onboarding_runs = _(increment_onboarding_runs) + # OAuth + cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens) + # Store get_store_agents = _(get_store_agents) get_store_agent_details = _(get_store_agent_details) @@ -349,6 +353,9 @@ class DatabaseManagerAsyncClient(AppServiceClient): # Onboarding increment_onboarding_runs = d.increment_onboarding_runs + # OAuth + cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens + # Store get_store_agents = d.get_store_agents get_store_agent_details = d.get_store_agent_details diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 39d4f984eb..8362dae828 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -236,7 +236,14 @@ async def execute_node( input_size = len(input_data_str) log_metadata.debug("Executed node with input", input=input_data_str) + # Create node-specific execution context to avoid race conditions + # (multiple nodes can execute concurrently and would otherwise mutate shared state) + execution_context = execution_context.model_copy( + update={"node_id": node_id, "node_exec_id": node_exec_id} + ) + # Inject extra execution arguments for the blocks via kwargs + # Keep individual kwargs for backwards compatibility with existing blocks extra_exec_kwargs: dict = { "graph_id": graph_id, "graph_version": graph_version, diff --git a/autogpt_platform/backend/backend/executor/scheduler.py b/autogpt_platform/backend/backend/executor/scheduler.py index 44b77fc018..cbdc441718 100644 --- a/autogpt_platform/backend/backend/executor/scheduler.py +++ b/autogpt_platform/backend/backend/executor/scheduler.py @@ -24,11 +24,9 @@ from dotenv import load_dotenv from pydantic import BaseModel, Field, ValidationError from sqlalchemy import MetaData, create_engine -from backend.data.auth.oauth import cleanup_expired_oauth_tokens from backend.data.block import BlockInput from backend.data.execution import GraphExecutionWithNodes from backend.data.model import CredentialsMetaInput -from backend.data.onboarding import increment_onboarding_runs from backend.executor import utils as execution_utils from backend.monitoring import ( NotificationJobArgs, @@ -38,7 +36,11 @@ from backend.monitoring import ( report_execution_accuracy_alerts, report_late_executions, ) -from backend.util.clients import get_database_manager_client, get_scheduler_client +from backend.util.clients import ( + get_database_manager_async_client, + get_database_manager_client, + get_scheduler_client, +) from backend.util.cloud_storage import cleanup_expired_files_async from backend.util.exceptions import ( GraphNotFoundError, @@ -148,6 +150,7 @@ def execute_graph(**kwargs): async def _execute_graph(**kwargs): args = GraphExecutionJobArgs(**kwargs) start_time = asyncio.get_event_loop().time() + db = get_database_manager_async_client() try: logger.info(f"Executing recurring job for graph #{args.graph_id}") graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution( @@ -157,7 +160,7 @@ async def _execute_graph(**kwargs): inputs=args.input_data, graph_credentials_inputs=args.input_credentials, ) - await increment_onboarding_runs(args.user_id) + await db.increment_onboarding_runs(args.user_id) elapsed = asyncio.get_event_loop().time() - start_time logger.info( f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} " @@ -246,8 +249,13 @@ def cleanup_expired_files(): def cleanup_oauth_tokens(): """Clean up expired OAuth tokens from the database.""" + # Wait for completion - run_async(cleanup_expired_oauth_tokens()) + async def _cleanup(): + db = get_database_manager_async_client() + return await db.cleanup_expired_oauth_tokens() + + run_async(_cleanup()) def execution_accuracy_alerts(): diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index f35bebb125..fa264c30a7 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -892,11 +892,19 @@ async def add_graph_execution( settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id) execution_context = ExecutionContext( + # Execution identity + user_id=user_id, + graph_id=graph_id, + graph_exec_id=graph_exec.id, + graph_version=graph_exec.graph_version, + # Safety settings human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode, sensitive_action_safe_mode=settings.sensitive_action_safe_mode, + # User settings user_timezone=( user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC" ), + # Execution hierarchy root_execution_id=graph_exec.id, ) diff --git a/autogpt_platform/backend/backend/executor/utils_test.py b/autogpt_platform/backend/backend/executor/utils_test.py index 4761a18c63..db33249583 100644 --- a/autogpt_platform/backend/backend/executor/utils_test.py +++ b/autogpt_platform/backend/backend/executor/utils_test.py @@ -348,6 +348,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture): mock_graph_exec.id = "execution-id-123" mock_graph_exec.node_executions = [] # Add this to avoid AttributeError mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check + mock_graph_exec.graph_version = graph_version mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock() # Mock the queue and event bus @@ -434,6 +435,9 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture): # Create a second mock execution for the sanity check mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes) mock_graph_exec_2.id = "execution-id-456" + mock_graph_exec_2.node_executions = [] + mock_graph_exec_2.status = ExecutionStatus.QUEUED + mock_graph_exec_2.graph_version = graph_version mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock() # Reset mocks and set up for second call @@ -614,6 +618,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture): mock_graph_exec.id = "execution-id-123" mock_graph_exec.node_executions = [] mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check + mock_graph_exec.graph_version = graph_version # Track what's passed to to_graph_execution_entry captured_kwargs = {} diff --git a/autogpt_platform/backend/backend/integrations/webhooks/utils_test.py b/autogpt_platform/backend/backend/integrations/webhooks/utils_test.py new file mode 100644 index 0000000000..bc502a8e44 --- /dev/null +++ b/autogpt_platform/backend/backend/integrations/webhooks/utils_test.py @@ -0,0 +1,39 @@ +from urllib.parse import urlparse + +import fastapi +from fastapi.routing import APIRoute + +from backend.api.features.integrations.router import router as integrations_router +from backend.integrations.providers import ProviderName +from backend.integrations.webhooks import utils as webhooks_utils + + +def test_webhook_ingress_url_matches_route(monkeypatch) -> None: + app = fastapi.FastAPI() + app.include_router(integrations_router, prefix="/api/integrations") + + provider = ProviderName.GITHUB + webhook_id = "webhook_123" + base_url = "https://example.com" + + monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url) + + route = next( + route + for route in integrations_router.routes + if isinstance(route, APIRoute) + and route.path == "/{provider}/webhooks/{webhook_id}/ingress" + and "POST" in route.methods + ) + expected_path = f"/api/integrations{route.path}".format( + provider=provider.value, + webhook_id=webhook_id, + ) + actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id)) + expected_base = urlparse(base_url) + + assert (actual_url.scheme, actual_url.netloc) == ( + expected_base.scheme, + expected_base.netloc, + ) + assert actual_url.path == expected_path diff --git a/autogpt_platform/backend/backend/util/cloud_storage.py b/autogpt_platform/backend/backend/util/cloud_storage.py index 93fb9039ec..28423d003d 100644 --- a/autogpt_platform/backend/backend/util/cloud_storage.py +++ b/autogpt_platform/backend/backend/util/cloud_storage.py @@ -13,6 +13,7 @@ import aiohttp from gcloud.aio import storage as async_gcs_storage from google.cloud import storage as gcs_storage +from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url from backend.util.settings import Config logger = logging.getLogger(__name__) @@ -251,7 +252,7 @@ class CloudStorageHandler: f"in_task: {current_task is not None}" ) - # Parse bucket and blob name from path + # Parse bucket and blob name from path (path already has gcs:// prefix removed) parts = path.split("/", 1) if len(parts) != 2: raise ValueError(f"Invalid GCS path: {path}") @@ -261,50 +262,19 @@ class CloudStorageHandler: # Authorization check self._validate_file_access(blob_name, user_id, graph_exec_id) - # Use a fresh client for each download to avoid session issues - # This is less efficient but more reliable with the executor's event loop - logger.info("[CloudStorage] Creating fresh GCS client for download") - - # Create a new session specifically for this download - session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(limit=10, force_close=True) + logger.info( + f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}" ) - async_client = None try: - # Create a new GCS client with the fresh session - async_client = async_gcs_storage.Storage(session=session) - - logger.info( - f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}" - ) - - # Download content using the fresh client - content = await async_client.download(bucket_name, blob_name) + content = await download_with_fresh_session(bucket_name, blob_name) logger.info( f"[CloudStorage] GCS download successful - size: {len(content)} bytes" ) - - # Clean up - await async_client.close() - await session.close() - return content - + except FileNotFoundError: + raise except Exception as e: - # Always try to clean up - if async_client is not None: - try: - await async_client.close() - except Exception as cleanup_error: - logger.warning( - f"[CloudStorage] Error closing GCS client: {cleanup_error}" - ) - try: - await session.close() - except Exception as cleanup_error: - logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}") - # Log the specific error for debugging logger.error( f"[CloudStorage] GCS download failed - error: {str(e)}, " @@ -319,10 +289,6 @@ class CloudStorageHandler: f"current_task: {current_task}, " f"bucket: {bucket_name}, blob: redacted for privacy" ) - - # Convert gcloud-aio exceptions to standard ones - if "404" in str(e) or "Not Found" in str(e): - raise FileNotFoundError(f"File not found: gcs://{path}") raise def _validate_file_access( @@ -445,8 +411,7 @@ class CloudStorageHandler: graph_exec_id: str | None = None, ) -> str: """Generate signed URL for GCS with authorization.""" - - # Parse bucket and blob name from path + # Parse bucket and blob name from path (path already has gcs:// prefix removed) parts = path.split("/", 1) if len(parts) != 2: raise ValueError(f"Invalid GCS path: {path}") @@ -456,21 +421,11 @@ class CloudStorageHandler: # Authorization check self._validate_file_access(blob_name, user_id, graph_exec_id) - # Use sync client for signed URLs since gcloud-aio doesn't support them sync_client = self._get_sync_gcs_client() - bucket = sync_client.bucket(bucket_name) - blob = bucket.blob(blob_name) - - # Generate signed URL asynchronously using sync client - url = await asyncio.to_thread( - blob.generate_signed_url, - version="v4", - expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours), - method="GET", + return await generate_signed_url( + sync_client, bucket_name, blob_name, expiration_hours * 3600 ) - return url - async def delete_expired_files(self, provider: str = "gcs") -> int: """ Delete files that have passed their expiration time. diff --git a/autogpt_platform/backend/backend/util/file.py b/autogpt_platform/backend/backend/util/file.py index dc8f86ea41..baa9225629 100644 --- a/autogpt_platform/backend/backend/util/file.py +++ b/autogpt_platform/backend/backend/util/file.py @@ -5,13 +5,26 @@ import shutil import tempfile import uuid from pathlib import Path +from typing import TYPE_CHECKING, Literal from urllib.parse import urlparse from backend.util.cloud_storage import get_cloud_storage_handler from backend.util.request import Requests +from backend.util.settings import Config from backend.util.type import MediaFileType from backend.util.virus_scanner import scan_content_safe +if TYPE_CHECKING: + from backend.data.execution import ExecutionContext + +# Return format options for store_media_file +# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc. +# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs +# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs +MediaReturnFormat = Literal[ + "for_local_processing", "for_external_api", "for_block_output" +] + TEMP_DIR = Path(tempfile.gettempdir()).resolve() # Maximum filename length (conservative limit for most filesystems) @@ -67,42 +80,56 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None: async def store_media_file( - graph_exec_id: str, file: MediaFileType, - user_id: str, - return_content: bool = False, + execution_context: "ExecutionContext", + *, + return_format: MediaReturnFormat, ) -> MediaFileType: """ - Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}), - placing or verifying it under: + Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path + relative to {temp}/exec_file/{exec_id}), placing or verifying it under: {tempdir}/exec_file/{exec_id}/... - If 'return_content=True', return a data URI (data:;base64,). - Otherwise, returns the file media path relative to the exec_id folder. + For each MediaFileType input: + - Data URI: decode and store locally + - URL: download and store locally + - workspace:// reference: read from workspace, store locally + - Local path: verify it exists in exec_file directory - For each MediaFileType type: - - Data URI: - -> decode and store in a new random file in that folder - - URL: - -> download and store in that folder - - Local path: - -> interpret as relative to that folder; verify it exists - (no copying, as it's presumably already there). - We realpath-check so no symlink or '..' can escape the folder. + Return format options: + - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc. + - "for_external_api": Returns data URI (base64) - use when sending to external APIs + - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs - - :param graph_exec_id: The unique ID of the graph execution. - :param file: Data URI, URL, or local (relative) path. - :param return_content: If True, return a data URI of the file content. - If False, return the *relative* path inside the exec_id folder. - :return: The requested result: data URI or relative path of the media. + :param file: Data URI, URL, workspace://, or local (relative) path. + :param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id. + :param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output". + :return: The requested result based on return_format. """ + # Extract values from execution_context + graph_exec_id = execution_context.graph_exec_id + user_id = execution_context.user_id + + if not graph_exec_id: + raise ValueError("execution_context.graph_exec_id is required") + if not user_id: + raise ValueError("execution_context.user_id is required") + + # Create workspace_manager if we have workspace_id (with session scoping) + # Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py) + from backend.util.workspace import WorkspaceManager + + workspace_manager: WorkspaceManager | None = None + if execution_context.workspace_id: + workspace_manager = WorkspaceManager( + user_id, execution_context.workspace_id, execution_context.session_id + ) # Build base path base_path = Path(get_exec_file_path(graph_exec_id, "")) base_path.mkdir(parents=True, exist_ok=True) # Security fix: Add disk space limits to prevent DoS - MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB per file + MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024 MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory # Check total disk usage in base_path @@ -142,9 +169,57 @@ async def store_media_file( """ return str(absolute_path.relative_to(base)) - # Check if this is a cloud storage path + # Get cloud storage handler for checking cloud paths cloud_storage = await get_cloud_storage_handler() - if cloud_storage.is_cloud_path(file): + + # Track if the input came from workspace (don't re-save it) + is_from_workspace = file.startswith("workspace://") + + # Check if this is a workspace file reference + if is_from_workspace: + if workspace_manager is None: + raise ValueError( + "Workspace file reference requires workspace context. " + "This file type is only available in CoPilot sessions." + ) + + # Parse workspace reference + # workspace://abc123 - by file ID + # workspace:///path/to/file.txt - by virtual path + file_ref = file[12:] # Remove "workspace://" + + if file_ref.startswith("/"): + # Path reference + workspace_content = await workspace_manager.read_file(file_ref) + file_info = await workspace_manager.get_file_info_by_path(file_ref) + filename = sanitize_filename( + file_info.name if file_info else f"{uuid.uuid4()}.bin" + ) + else: + # ID reference + workspace_content = await workspace_manager.read_file_by_id(file_ref) + file_info = await workspace_manager.get_file_info(file_ref) + filename = sanitize_filename( + file_info.name if file_info else f"{uuid.uuid4()}.bin" + ) + + try: + target_path = _ensure_inside_base(base_path / filename, base_path) + except OSError as e: + raise ValueError(f"Invalid file path '{filename}': {e}") from e + + # Check file size limit + if len(workspace_content) > MAX_FILE_SIZE_BYTES: + raise ValueError( + f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes" + ) + + # Virus scan the workspace content before writing locally + await scan_content_safe(workspace_content, filename=filename) + target_path.write_bytes(workspace_content) + + # Check if this is a cloud storage path + elif cloud_storage.is_cloud_path(file): # Download from cloud storage and store locally cloud_content = await cloud_storage.retrieve_file( file, user_id=user_id, graph_exec_id=graph_exec_id @@ -159,9 +234,9 @@ async def store_media_file( raise ValueError(f"Invalid file path '{filename}': {e}") from e # Check file size limit - if len(cloud_content) > MAX_FILE_SIZE: + if len(cloud_content) > MAX_FILE_SIZE_BYTES: raise ValueError( - f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes" + f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes" ) # Virus scan the cloud content before writing locally @@ -189,9 +264,9 @@ async def store_media_file( content = base64.b64decode(b64_content) # Check file size limit - if len(content) > MAX_FILE_SIZE: + if len(content) > MAX_FILE_SIZE_BYTES: raise ValueError( - f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes" + f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes" ) # Virus scan the base64 content before writing @@ -199,23 +274,31 @@ async def store_media_file( target_path.write_bytes(content) elif file.startswith(("http://", "https://")): - # URL + # URL - download first to get Content-Type header + resp = await Requests().get(file) + + # Check file size limit + if len(resp.content) > MAX_FILE_SIZE_BYTES: + raise ValueError( + f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes" + ) + + # Extract filename from URL path parsed_url = urlparse(file) filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}") + + # If filename lacks extension, add one from Content-Type header + if "." not in filename: + content_type = resp.headers.get("Content-Type", "").split(";")[0].strip() + if content_type: + ext = _extension_from_mime(content_type) + filename = f"{filename}{ext}" + try: target_path = _ensure_inside_base(base_path / filename, base_path) except OSError as e: raise ValueError(f"Invalid file path '{filename}': {e}") from e - # Download and save - resp = await Requests().get(file) - - # Check file size limit - if len(resp.content) > MAX_FILE_SIZE: - raise ValueError( - f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes" - ) - # Virus scan the downloaded content before writing await scan_content_safe(resp.content, filename=filename) target_path.write_bytes(resp.content) @@ -230,12 +313,44 @@ async def store_media_file( if not target_path.is_file(): raise ValueError(f"Local file does not exist: {target_path}") - # Return result - if return_content: - return MediaFileType(_file_to_data_uri(target_path)) - else: + # Return based on requested format + if return_format == "for_local_processing": + # Use when processing files locally with tools like ffmpeg, MoviePy, PIL + # Returns: relative path in exec_file directory (e.g., "image.png") return MediaFileType(_strip_base_prefix(target_path, base_path)) + elif return_format == "for_external_api": + # Use when sending content to external APIs that need base64 + # Returns: data URI (e.g., "...") + return MediaFileType(_file_to_data_uri(target_path)) + + elif return_format == "for_block_output": + # Use when returning output from a block to user/next block + # Returns: workspace:// ref (CoPilot) or data URI (graph execution) + if workspace_manager is None: + # No workspace available (graph execution without CoPilot) + # Fallback to data URI so the content can still be used/displayed + return MediaFileType(_file_to_data_uri(target_path)) + + # Don't re-save if input was already from workspace + if is_from_workspace: + # Return original workspace reference + return MediaFileType(file) + + # Save new content to workspace + content = target_path.read_bytes() + filename = target_path.name + + file_record = await workspace_manager.write_file( + content=content, + filename=filename, + overwrite=True, + ) + return MediaFileType(f"workspace://{file_record.id}") + + else: + raise ValueError(f"Invalid return_format: {return_format}") + def get_dir_size(path: Path) -> int: """Get total size of directory.""" diff --git a/autogpt_platform/backend/backend/util/file_test.py b/autogpt_platform/backend/backend/util/file_test.py index cd4fc69706..9fe672d155 100644 --- a/autogpt_platform/backend/backend/util/file_test.py +++ b/autogpt_platform/backend/backend/util/file_test.py @@ -7,10 +7,22 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from backend.data.execution import ExecutionContext from backend.util.file import store_media_file from backend.util.type import MediaFileType +def make_test_context( + graph_exec_id: str = "test-exec-123", + user_id: str = "test-user-123", +) -> ExecutionContext: + """Helper to create test ExecutionContext.""" + return ExecutionContext( + user_id=user_id, + graph_exec_id=graph_exec_id, + ) + + class TestFileCloudIntegration: """Test cases for cloud storage integration in file utilities.""" @@ -70,10 +82,9 @@ class TestFileCloudIntegration: mock_path_class.side_effect = path_constructor result = await store_media_file( - graph_exec_id, - MediaFileType(cloud_path), - "test-user-123", - return_content=False, + file=MediaFileType(cloud_path), + execution_context=make_test_context(graph_exec_id=graph_exec_id), + return_format="for_local_processing", ) # Verify cloud storage operations @@ -144,10 +155,9 @@ class TestFileCloudIntegration: mock_path_obj.name = "image.png" with patch("backend.util.file.Path", return_value=mock_path_obj): result = await store_media_file( - graph_exec_id, - MediaFileType(cloud_path), - "test-user-123", - return_content=True, + file=MediaFileType(cloud_path), + execution_context=make_test_context(graph_exec_id=graph_exec_id), + return_format="for_external_api", ) # Verify result is a data URI @@ -198,10 +208,9 @@ class TestFileCloudIntegration: mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt") await store_media_file( - graph_exec_id, - MediaFileType(data_uri), - "test-user-123", - return_content=False, + file=MediaFileType(data_uri), + execution_context=make_test_context(graph_exec_id=graph_exec_id), + return_format="for_local_processing", ) # Verify cloud handler was checked but not used for retrieval @@ -234,5 +243,7 @@ class TestFileCloudIntegration: FileNotFoundError, match="File not found in cloud storage" ): await store_media_file( - graph_exec_id, MediaFileType(cloud_path), "test-user-123" + file=MediaFileType(cloud_path), + execution_context=make_test_context(graph_exec_id=graph_exec_id), + return_format="for_local_processing", ) diff --git a/autogpt_platform/backend/backend/util/gcs_utils.py b/autogpt_platform/backend/backend/util/gcs_utils.py new file mode 100644 index 0000000000..3f91f21897 --- /dev/null +++ b/autogpt_platform/backend/backend/util/gcs_utils.py @@ -0,0 +1,108 @@ +""" +Shared GCS utilities for workspace and cloud storage backends. + +This module provides common functionality for working with Google Cloud Storage, +including path parsing, client management, and signed URL generation. +""" + +import asyncio +import logging +from datetime import datetime, timedelta, timezone + +import aiohttp +from gcloud.aio import storage as async_gcs_storage +from google.cloud import storage as gcs_storage + +logger = logging.getLogger(__name__) + + +def parse_gcs_path(path: str) -> tuple[str, str]: + """ + Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob). + + Args: + path: GCS path string (e.g., "gcs://my-bucket/path/to/file") + + Returns: + Tuple of (bucket_name, blob_name) + + Raises: + ValueError: If the path format is invalid + """ + if not path.startswith("gcs://"): + raise ValueError(f"Invalid GCS path: {path}") + + path_without_prefix = path[6:] # Remove "gcs://" + parts = path_without_prefix.split("/", 1) + if len(parts) != 2: + raise ValueError(f"Invalid GCS path format: {path}") + + return parts[0], parts[1] + + +async def download_with_fresh_session(bucket: str, blob: str) -> bytes: + """ + Download file content using a fresh session. + + This approach avoids event loop issues that can occur when reusing + sessions across different async contexts (e.g., in executors). + + Args: + bucket: GCS bucket name + blob: Blob path within the bucket + + Returns: + File content as bytes + + Raises: + FileNotFoundError: If the file doesn't exist + """ + session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=10, force_close=True) + ) + client: async_gcs_storage.Storage | None = None + try: + client = async_gcs_storage.Storage(session=session) + content = await client.download(bucket, blob) + return content + except Exception as e: + if "404" in str(e) or "Not Found" in str(e): + raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}") + raise + finally: + if client: + try: + await client.close() + except Exception: + pass # Best-effort cleanup + await session.close() + + +async def generate_signed_url( + sync_client: gcs_storage.Client, + bucket_name: str, + blob_name: str, + expires_in: int, +) -> str: + """ + Generate a signed URL for temporary access to a GCS file. + + Uses asyncio.to_thread() to run the sync operation without blocking. + + Args: + sync_client: Sync GCS client with service account credentials + bucket_name: GCS bucket name + blob_name: Blob path within the bucket + expires_in: URL expiration time in seconds + + Returns: + Signed URL string + """ + bucket = sync_client.bucket(bucket_name) + blob = bucket.blob(blob_name) + return await asyncio.to_thread( + blob.generate_signed_url, + version="v4", + expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in), + method="GET", + ) diff --git a/autogpt_platform/backend/backend/util/prompt.py b/autogpt_platform/backend/backend/util/prompt.py index 775d1c932b..5f904bbc8a 100644 --- a/autogpt_platform/backend/backend/util/prompt.py +++ b/autogpt_platform/backend/backend/util/prompt.py @@ -1,10 +1,19 @@ +from __future__ import annotations + +import logging from copy import deepcopy -from typing import Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any from tiktoken import encoding_for_model from backend.util import json +if TYPE_CHECKING: + from openai import AsyncOpenAI + +logger = logging.getLogger(__name__) + # ---------------------------------------------------------------------------# # CONSTANTS # # ---------------------------------------------------------------------------# @@ -100,9 +109,17 @@ def _is_objective_message(msg: dict) -> bool: def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None: """ Carefully truncate tool message content while preserving tool structure. - Only truncates tool_result content, leaves tool_use intact. + Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages. """ content = msg.get("content") + + # OpenAI-style tool message: role="tool" with string content + if msg.get("role") == "tool" and isinstance(content, str): + if _tok_len(content, enc) > max_tokens: + msg["content"] = _truncate_middle_tokens(content, enc, max_tokens) + return + + # Anthropic-style: list content with tool_result items if not isinstance(content, list): return @@ -140,141 +157,6 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str: # ---------------------------------------------------------------------------# -def compress_prompt( - messages: list[dict], - target_tokens: int, - *, - model: str = "gpt-4o", - reserve: int = 2_048, - start_cap: int = 8_192, - floor_cap: int = 128, - lossy_ok: bool = True, -) -> list[dict]: - """ - Shrink *messages* so that:: - - token_count(prompt) + reserve ≤ target_tokens - - Strategy - -------- - 1. **Token-aware truncation** – progressively halve a per-message cap - (`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the - *content* of every message except the first and last. Tool shells - are included: we keep the envelope but shorten huge payloads. - 2. **Middle-out deletion** – if still over the limit, delete whole - messages working outward from the centre, **skipping** any message - that contains ``tool_calls`` or has ``role == "tool"``. - 3. **Last-chance trim** – if still too big, truncate the *first* and - *last* message bodies down to `floor_cap` tokens. - 4. If the prompt is *still* too large: - • raise ``ValueError`` when ``lossy_ok == False`` (default) - • return the partially-trimmed prompt when ``lossy_ok == True`` - - Parameters - ---------- - messages Complete chat history (will be deep-copied). - model Model name; passed to tiktoken to pick the right - tokenizer (gpt-4o → 'o200k_base', others fallback). - target_tokens Hard ceiling for prompt size **excluding** the model's - forthcoming answer. - reserve How many tokens you want to leave available for that - answer (`max_tokens` in your subsequent completion call). - start_cap Initial per-message truncation ceiling (tokens). - floor_cap Lowest cap we'll accept before moving to deletions. - lossy_ok If *True* return best-effort prompt instead of raising - after all trim passes have been exhausted. - - Returns - ------- - list[dict] – A *new* messages list that abides by the rules above. - """ - enc = encoding_for_model(model) # best-match tokenizer - msgs = deepcopy(messages) # never mutate caller - - def total_tokens() -> int: - """Current size of *msgs* in tokens.""" - return sum(_msg_tokens(m, enc) for m in msgs) - - original_token_count = total_tokens() - - if original_token_count + reserve <= target_tokens: - return msgs - - # ---- STEP 0 : normalise content -------------------------------------- - # Convert non-string payloads to strings so token counting is coherent. - for i, m in enumerate(msgs): - if not isinstance(m.get("content"), str) and m.get("content") is not None: - if _is_tool_message(m): - continue - - # Keep first and last messages intact (unless they're tool messages) - if i == 0 or i == len(msgs) - 1: - continue - - # Reasonable 20k-char ceiling prevents pathological blobs - content_str = json.dumps(m["content"], separators=(",", ":")) - if len(content_str) > 20_000: - content_str = _truncate_middle_tokens(content_str, enc, 20_000) - m["content"] = content_str - - # ---- STEP 1 : token-aware truncation --------------------------------- - cap = start_cap - while total_tokens() + reserve > target_tokens and cap >= floor_cap: - for m in msgs[1:-1]: # keep first & last intact - if _is_tool_message(m): - # For tool messages, only truncate tool result content, preserve structure - _truncate_tool_message_content(m, enc, cap) - continue - - if _is_objective_message(m): - # Never truncate objective messages - they contain the core task - continue - - content = m.get("content") or "" - if _tok_len(content, enc) > cap: - m["content"] = _truncate_middle_tokens(content, enc, cap) - cap //= 2 # tighten the screw - - # ---- STEP 2 : middle-out deletion ----------------------------------- - while total_tokens() + reserve > target_tokens and len(msgs) > 2: - # Identify all deletable messages (not first/last, not tool messages, not objective messages) - deletable_indices = [] - for i in range(1, len(msgs) - 1): # Skip first and last - if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]): - deletable_indices.append(i) - - if not deletable_indices: - break # nothing more we can drop - - # Delete from center outward - find the index closest to center - centre = len(msgs) // 2 - to_delete = min(deletable_indices, key=lambda i: abs(i - centre)) - del msgs[to_delete] - - # ---- STEP 3 : final safety-net trim on first & last ------------------ - cap = start_cap - while total_tokens() + reserve > target_tokens and cap >= floor_cap: - for idx in (0, -1): # first and last - if _is_tool_message(msgs[idx]): - # For tool messages at first/last position, truncate tool result content only - _truncate_tool_message_content(msgs[idx], enc, cap) - continue - - text = msgs[idx].get("content") or "" - if _tok_len(text, enc) > cap: - msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap) - cap //= 2 # tighten the screw - - # ---- STEP 4 : success or fail-gracefully ----------------------------- - if total_tokens() + reserve > target_tokens and not lossy_ok: - raise ValueError( - "compress_prompt: prompt still exceeds budget " - f"({total_tokens() + reserve} > {target_tokens})." - ) - - return msgs - - def estimate_token_count( messages: list[dict], *, @@ -293,7 +175,8 @@ def estimate_token_count( ------- int – Token count. """ - enc = encoding_for_model(model) # best-match tokenizer + token_model = _normalize_model_for_tokenizer(model) + enc = encoding_for_model(token_model) return sum(_msg_tokens(m, enc) for m in messages) @@ -315,6 +198,543 @@ def estimate_token_count_str( ------- int – Token count. """ - enc = encoding_for_model(model) # best-match tokenizer + token_model = _normalize_model_for_tokenizer(model) + enc = encoding_for_model(token_model) text = json.dumps(text) if not isinstance(text, str) else text return _tok_len(text, enc) + + +# ---------------------------------------------------------------------------# +# UNIFIED CONTEXT COMPRESSION # +# ---------------------------------------------------------------------------# + +# Default thresholds +DEFAULT_TOKEN_THRESHOLD = 120_000 +DEFAULT_KEEP_RECENT = 15 + + +@dataclass +class CompressResult: + """Result of context compression.""" + + messages: list[dict] + token_count: int + was_compacted: bool + error: str | None = None + original_token_count: int = 0 + messages_summarized: int = 0 + messages_dropped: int = 0 + + +def _normalize_model_for_tokenizer(model: str) -> str: + """Normalize model name for tiktoken tokenizer selection.""" + if "/" in model: + model = model.split("/")[-1] + if "claude" in model.lower() or not any( + known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"] + ): + return "gpt-4o" + return model + + +def _extract_tool_call_ids_from_message(msg: dict) -> set[str]: + """ + Extract tool_call IDs from an assistant message. + + Supports both formats: + - OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]} + - Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]} + + Returns: + Set of tool_call IDs found in the message. + """ + ids: set[str] = set() + if msg.get("role") != "assistant": + return ids + + # OpenAI format: tool_calls array + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + tc_id = tc.get("id") + if tc_id: + ids.add(tc_id) + + # Anthropic format: content list with tool_use blocks + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_use": + tc_id = block.get("id") + if tc_id: + ids.add(tc_id) + + return ids + + +def _extract_tool_response_ids_from_message(msg: dict) -> set[str]: + """ + Extract tool_call IDs that this message is responding to. + + Supports both formats: + - OpenAI: {"role": "tool", "tool_call_id": "..."} + - Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]} + + Returns: + Set of tool_call IDs this message responds to. + """ + ids: set[str] = set() + + # OpenAI format: role=tool with tool_call_id + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id") + if tc_id: + ids.add(tc_id) + + # Anthropic format: content list with tool_result blocks + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_result": + tc_id = block.get("tool_use_id") + if tc_id: + ids.add(tc_id) + + return ids + + +def _is_tool_response_message(msg: dict) -> bool: + """Check if message is a tool response (OpenAI or Anthropic format).""" + # OpenAI format + if msg.get("role") == "tool": + return True + # Anthropic format + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_result": + return True + return False + + +def _remove_orphan_tool_responses( + messages: list[dict], orphan_ids: set[str] +) -> list[dict]: + """ + Remove tool response messages/blocks that reference orphan tool_call IDs. + + Supports both OpenAI and Anthropic formats. + For Anthropic messages with mixed valid/orphan tool_result blocks, + filters out only the orphan blocks instead of dropping the entire message. + """ + result = [] + for msg in messages: + # OpenAI format: role=tool - drop entire message if orphan + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id") + if tc_id and tc_id in orphan_ids: + continue + result.append(msg) + continue + + # Anthropic format: content list may have mixed tool_result blocks + content = msg.get("content") + if isinstance(content, list): + has_tool_results = any( + isinstance(b, dict) and b.get("type") == "tool_result" for b in content + ) + if has_tool_results: + # Filter out orphan tool_result blocks, keep valid ones + filtered_content = [ + block + for block in content + if not ( + isinstance(block, dict) + and block.get("type") == "tool_result" + and block.get("tool_use_id") in orphan_ids + ) + ] + # Only keep message if it has remaining content + if filtered_content: + msg = msg.copy() + msg["content"] = filtered_content + result.append(msg) + continue + + result.append(msg) + return result + + +def _ensure_tool_pairs_intact( + recent_messages: list[dict], + all_messages: list[dict], + start_index: int, +) -> list[dict]: + """ + Ensure tool_call/tool_response pairs stay together after slicing. + + When slicing messages for context compaction, a naive slice can separate + an assistant message containing tool_calls from its corresponding tool + response messages. This causes API validation errors (e.g., Anthropic's + "unexpected tool_use_id found in tool_result blocks"). + + This function checks for orphan tool responses in the slice and extends + backwards to include their corresponding assistant messages. + + Supports both formats: + - OpenAI: tool_calls array + role="tool" responses + - Anthropic: tool_use blocks + tool_result blocks + + Args: + recent_messages: The sliced messages to validate + all_messages: The complete message list (for looking up missing assistants) + start_index: The index in all_messages where recent_messages begins + + Returns: + A potentially extended list of messages with tool pairs intact + """ + if not recent_messages: + return recent_messages + + # Collect all tool_call_ids from assistant messages in the slice + available_tool_call_ids: set[str] = set() + for msg in recent_messages: + available_tool_call_ids |= _extract_tool_call_ids_from_message(msg) + + # Find orphan tool responses (responses whose tool_call_id is missing) + orphan_tool_call_ids: set[str] = set() + for msg in recent_messages: + response_ids = _extract_tool_response_ids_from_message(msg) + for tc_id in response_ids: + if tc_id not in available_tool_call_ids: + orphan_tool_call_ids.add(tc_id) + + if not orphan_tool_call_ids: + # No orphans, slice is valid + return recent_messages + + # Find the assistant messages that contain the orphan tool_call_ids + # Search backwards from start_index in all_messages + messages_to_prepend: list[dict] = [] + for i in range(start_index - 1, -1, -1): + msg = all_messages[i] + msg_tool_ids = _extract_tool_call_ids_from_message(msg) + if msg_tool_ids & orphan_tool_call_ids: + # This assistant message has tool_calls we need + # Also collect its contiguous tool responses that follow it + assistant_and_responses: list[dict] = [msg] + + # Scan forward from this assistant to collect tool responses + for j in range(i + 1, start_index): + following_msg = all_messages[j] + following_response_ids = _extract_tool_response_ids_from_message( + following_msg + ) + if following_response_ids and following_response_ids & msg_tool_ids: + assistant_and_responses.append(following_msg) + elif not _is_tool_response_message(following_msg): + # Stop at first non-tool-response message + break + + # Prepend the assistant and its tool responses (maintain order) + messages_to_prepend = assistant_and_responses + messages_to_prepend + # Mark these as found + orphan_tool_call_ids -= msg_tool_ids + # Also add this assistant's tool_call_ids to available set + available_tool_call_ids |= msg_tool_ids + + if not orphan_tool_call_ids: + # Found all missing assistants + break + + if orphan_tool_call_ids: + # Some tool_call_ids couldn't be resolved - remove those tool responses + # This shouldn't happen in normal operation but handles edge cases + logger.warning( + f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. " + "Removing orphan tool responses." + ) + recent_messages = _remove_orphan_tool_responses( + recent_messages, orphan_tool_call_ids + ) + + if messages_to_prepend: + logger.info( + f"Extended recent messages by {len(messages_to_prepend)} to preserve " + f"tool_call/tool_response pairs" + ) + return messages_to_prepend + recent_messages + + return recent_messages + + +async def _summarize_messages_llm( + messages: list[dict], + client: AsyncOpenAI, + model: str, + timeout: float = 30.0, +) -> str: + """Summarize messages using an LLM.""" + conversation = [] + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + if content and role in ("user", "assistant", "tool"): + conversation.append(f"{role.upper()}: {content}") + + conversation_text = "\n\n".join(conversation) + + if not conversation_text: + return "No conversation history available." + + # Limit to ~100k chars for safety + MAX_CHARS = 100_000 + if len(conversation_text) > MAX_CHARS: + conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]" + + response = await client.with_options(timeout=timeout).chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": ( + "Create a detailed summary of the conversation so far. " + "This summary will be used as context when continuing the conversation.\n\n" + "Before writing the summary, analyze each message chronologically to identify:\n" + "- User requests and their explicit goals\n" + "- Your approach and key decisions made\n" + "- Technical specifics (file names, tool outputs, function signatures)\n" + "- Errors encountered and resolutions applied\n\n" + "You MUST include ALL of the following sections:\n\n" + "## 1. Primary Request and Intent\n" + "The user's explicit goals and what they are trying to accomplish.\n\n" + "## 2. Key Technical Concepts\n" + "Technologies, frameworks, tools, and patterns being used or discussed.\n\n" + "## 3. Files and Resources Involved\n" + "Specific files examined or modified, with relevant snippets and identifiers.\n\n" + "## 4. Errors and Fixes\n" + "Problems encountered, error messages, and their resolutions. " + "Include any user feedback on fixes.\n\n" + "## 5. Problem Solving\n" + "Issues that have been resolved and how they were addressed.\n\n" + "## 6. All User Messages\n" + "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n" + "## 7. Pending Tasks\n" + "Work items the user explicitly requested that have not yet been completed.\n\n" + "## 8. Current Work\n" + "Precise description of what was being worked on most recently, including relevant context.\n\n" + "## 9. Next Steps\n" + "What should happen next, aligned with the user's most recent requests. " + "Include verbatim quotes of recent instructions if relevant." + ), + }, + {"role": "user", "content": f"Summarize:\n\n{conversation_text}"}, + ], + max_tokens=1500, + temperature=0.3, + ) + + return response.choices[0].message.content or "No summary available." + + +async def compress_context( + messages: list[dict], + target_tokens: int = DEFAULT_TOKEN_THRESHOLD, + *, + model: str = "gpt-4o", + client: AsyncOpenAI | None = None, + keep_recent: int = DEFAULT_KEEP_RECENT, + reserve: int = 2_048, + start_cap: int = 8_192, + floor_cap: int = 128, +) -> CompressResult: + """ + Unified context compression that combines summarization and truncation strategies. + + Strategy (in order): + 1. **LLM summarization** – If client provided, summarize old messages into a + single context message while keeping recent messages intact. This is the + primary strategy for chat service. + 2. **Content truncation** – Progressively halve a per-message cap and truncate + bloated message content (tool outputs, large pastes). Preserves all messages + but shortens their content. Primary strategy when client=None (LLM blocks). + 3. **Middle-out deletion** – Delete whole messages one at a time from the center + outward, skipping tool messages and objective messages. + 4. **First/last trim** – Truncate first and last message content as last resort. + + Parameters + ---------- + messages Complete chat history (will be deep-copied). + target_tokens Hard ceiling for prompt size. + model Model name for tokenization and summarization. + client AsyncOpenAI client. If provided, enables LLM summarization + as the first strategy. If None, skips to truncation strategies. + keep_recent Number of recent messages to preserve during summarization. + reserve Tokens to reserve for model response. + start_cap Initial per-message truncation ceiling (tokens). + floor_cap Lowest cap before moving to deletions. + + Returns + ------- + CompressResult with compressed messages and metadata. + """ + # Guard clause for empty messages + if not messages: + return CompressResult( + messages=[], + token_count=0, + was_compacted=False, + original_token_count=0, + ) + + token_model = _normalize_model_for_tokenizer(model) + enc = encoding_for_model(token_model) + msgs = deepcopy(messages) + + def total_tokens() -> int: + return sum(_msg_tokens(m, enc) for m in msgs) + + original_count = total_tokens() + + # Already under limit + if original_count + reserve <= target_tokens: + return CompressResult( + messages=msgs, + token_count=original_count, + was_compacted=False, + original_token_count=original_count, + ) + + messages_summarized = 0 + messages_dropped = 0 + + # ---- STEP 1: LLM summarization (if client provided) ------------------- + # This is the primary compression strategy for chat service. + # Summarize old messages while keeping recent ones intact. + if client is not None: + has_system = len(msgs) > 0 and msgs[0].get("role") == "system" + system_msg = msgs[0] if has_system else None + + # Calculate old vs recent messages + if has_system: + if len(msgs) > keep_recent + 1: + old_msgs = msgs[1:-keep_recent] + recent_msgs = msgs[-keep_recent:] + else: + old_msgs = [] + recent_msgs = msgs[1:] if len(msgs) > 1 else [] + else: + if len(msgs) > keep_recent: + old_msgs = msgs[:-keep_recent] + recent_msgs = msgs[-keep_recent:] + else: + old_msgs = [] + recent_msgs = msgs + + # Ensure tool pairs stay intact + slice_start = max(0, len(msgs) - keep_recent) + recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start) + + if old_msgs: + try: + summary_text = await _summarize_messages_llm(old_msgs, client, model) + summary_msg = { + "role": "assistant", + "content": f"[Previous conversation summary — for context only]: {summary_text}", + } + messages_summarized = len(old_msgs) + + if has_system: + msgs = [system_msg, summary_msg] + recent_msgs + else: + msgs = [summary_msg] + recent_msgs + + logger.info( + f"Context summarized: {original_count} -> {total_tokens()} tokens, " + f"summarized {messages_summarized} messages" + ) + except Exception as e: + logger.warning(f"Summarization failed, continuing with truncation: {e}") + # Fall through to content truncation + + # ---- STEP 2: Normalize content ---------------------------------------- + # Convert non-string payloads to strings so token counting is coherent. + # Always run this before truncation to ensure consistent token counting. + for i, m in enumerate(msgs): + if not isinstance(m.get("content"), str) and m.get("content") is not None: + if _is_tool_message(m): + continue + if i == 0 or i == len(msgs) - 1: + continue + content_str = json.dumps(m["content"], separators=(",", ":")) + if len(content_str) > 20_000: + content_str = _truncate_middle_tokens(content_str, enc, 20_000) + m["content"] = content_str + + # ---- STEP 3: Token-aware content truncation --------------------------- + # Progressively halve per-message cap and truncate bloated content. + # This preserves all messages but shortens their content. + cap = start_cap + while total_tokens() + reserve > target_tokens and cap >= floor_cap: + for m in msgs[1:-1]: + if _is_tool_message(m): + _truncate_tool_message_content(m, enc, cap) + continue + if _is_objective_message(m): + continue + content = m.get("content") or "" + if _tok_len(content, enc) > cap: + m["content"] = _truncate_middle_tokens(content, enc, cap) + cap //= 2 + + # ---- STEP 4: Middle-out deletion -------------------------------------- + # Delete messages one at a time from the center outward. + # This is more granular than dropping all old messages at once. + while total_tokens() + reserve > target_tokens and len(msgs) > 2: + deletable: list[int] = [] + for i in range(1, len(msgs) - 1): + msg = msgs[i] + if ( + msg is not None + and not _is_tool_message(msg) + and not _is_objective_message(msg) + ): + deletable.append(i) + if not deletable: + break + centre = len(msgs) // 2 + to_delete = min(deletable, key=lambda i: abs(i - centre)) + del msgs[to_delete] + messages_dropped += 1 + + # ---- STEP 5: Final trim on first/last --------------------------------- + cap = start_cap + while total_tokens() + reserve > target_tokens and cap >= floor_cap: + for idx in (0, -1): + msg = msgs[idx] + if msg is None: + continue + if _is_tool_message(msg): + _truncate_tool_message_content(msg, enc, cap) + continue + text = msg.get("content") or "" + if _tok_len(text, enc) > cap: + msg["content"] = _truncate_middle_tokens(text, enc, cap) + cap //= 2 + + # Filter out any None values that may have been introduced + final_msgs: list[dict] = [m for m in msgs if m is not None] + final_count = sum(_msg_tokens(m, enc) for m in final_msgs) + error = None + if final_count + reserve > target_tokens: + error = f"Could not compress below target ({final_count + reserve} > {target_tokens})" + logger.warning(error) + + return CompressResult( + messages=final_msgs, + token_count=final_count, + was_compacted=True, + error=error, + original_token_count=original_count, + messages_summarized=messages_summarized, + messages_dropped=messages_dropped, + ) diff --git a/autogpt_platform/backend/backend/util/prompt_test.py b/autogpt_platform/backend/backend/util/prompt_test.py index af6b230f8f..2d4bf090b3 100644 --- a/autogpt_platform/backend/backend/util/prompt_test.py +++ b/autogpt_platform/backend/backend/util/prompt_test.py @@ -1,10 +1,21 @@ """Tests for prompt utility functions, especially tool call token counting.""" +from unittest.mock import AsyncMock, MagicMock + import pytest from tiktoken import encoding_for_model from backend.util import json -from backend.util.prompt import _msg_tokens, estimate_token_count +from backend.util.prompt import ( + CompressResult, + _ensure_tool_pairs_intact, + _msg_tokens, + _normalize_model_for_tokenizer, + _truncate_middle_tokens, + _truncate_tool_message_content, + compress_context, + estimate_token_count, +) class TestMsgTokens: @@ -276,3 +287,690 @@ class TestEstimateTokenCount: assert total_tokens == expected_total assert total_tokens > 20 # Should be substantial + + +class TestNormalizeModelForTokenizer: + """Test model name normalization for tiktoken.""" + + def test_openai_models_unchanged(self): + """Test that OpenAI models are returned as-is.""" + assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o" + assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4" + assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo" + + def test_claude_models_normalized(self): + """Test that Claude models are normalized to gpt-4o.""" + assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o" + assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o" + assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o" + + def test_openrouter_paths_extracted(self): + """Test that OpenRouter model paths are handled.""" + assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o" + assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o" + + def test_unknown_models_default_to_gpt4o(self): + """Test that unknown models default to gpt-4o.""" + assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o" + assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o" + + +class TestTruncateToolMessageContent: + """Test tool message content truncation.""" + + @pytest.fixture + def enc(self): + return encoding_for_model("gpt-4o") + + def test_truncate_openai_tool_message(self, enc): + """Test truncation of OpenAI-style tool message with string content.""" + long_content = "x" * 10000 + msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content} + + _truncate_tool_message_content(msg, enc, max_tokens=100) + + # Content should be truncated + assert len(msg["content"]) < len(long_content) + assert "…" in msg["content"] # Has ellipsis marker + + def test_truncate_anthropic_tool_result(self, enc): + """Test truncation of Anthropic-style tool_result.""" + long_content = "y" * 10000 + msg = { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_123", + "content": long_content, + } + ], + } + + _truncate_tool_message_content(msg, enc, max_tokens=100) + + # Content should be truncated + result_content = msg["content"][0]["content"] + assert len(result_content) < len(long_content) + assert "…" in result_content + + def test_preserve_tool_use_blocks(self, enc): + """Test that tool_use blocks are not truncated.""" + msg = { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_123", + "name": "some_function", + "input": {"key": "value" * 1000}, # Large input + } + ], + } + + original = json.dumps(msg["content"][0]["input"]) + _truncate_tool_message_content(msg, enc, max_tokens=10) + + # tool_use should be unchanged + assert json.dumps(msg["content"][0]["input"]) == original + + def test_no_truncation_when_under_limit(self, enc): + """Test that short content is not modified.""" + msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"} + + original = msg["content"] + _truncate_tool_message_content(msg, enc, max_tokens=1000) + + assert msg["content"] == original + + +class TestTruncateMiddleTokens: + """Test middle truncation of text.""" + + @pytest.fixture + def enc(self): + return encoding_for_model("gpt-4o") + + def test_truncates_long_text(self, enc): + """Test that long text is truncated with ellipsis in middle.""" + long_text = "word " * 1000 + result = _truncate_middle_tokens(long_text, enc, max_tok=50) + + assert len(enc.encode(result)) <= 52 # Allow some slack for ellipsis + assert "…" in result + assert result.startswith("word") # Head preserved + assert result.endswith("word ") # Tail preserved + + def test_preserves_short_text(self, enc): + """Test that short text is not modified.""" + short_text = "Hello world" + result = _truncate_middle_tokens(short_text, enc, max_tok=100) + + assert result == short_text + + +class TestEnsureToolPairsIntact: + """Test tool call/response pair preservation for both OpenAI and Anthropic formats.""" + + # ---- OpenAI Format Tests ---- + + def test_openai_adds_missing_tool_call(self): + """Test that orphaned OpenAI tool_response gets its tool_call prepended.""" + all_msgs = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "f1"}} + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (the tool response) + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the tool_call message + assert len(result) == 3 + assert result[0]["role"] == "assistant" + assert "tool_calls" in result[0] + + def test_openai_keeps_complete_pairs(self): + """Test that complete OpenAI pairs are unchanged.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "f1"}} + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + ] + recent = all_msgs[1:] # Include both tool_call and response + start_index = 1 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + assert len(result) == 2 # No messages added + + def test_openai_multiple_tool_calls(self): + """Test multiple OpenAI tool calls in one assistant message.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "f1"}}, + {"id": "call_2", "type": "function", "function": {"name": "f2"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result1"}, + {"role": "tool", "tool_call_id": "call_2", "content": "result2"}, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (first tool response) + recent = [all_msgs[2], all_msgs[3], all_msgs[4]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the assistant message with both tool_calls + assert len(result) == 4 + assert result[0]["role"] == "assistant" + assert len(result[0]["tool_calls"]) == 2 + + # ---- Anthropic Format Tests ---- + + def test_anthropic_adds_missing_tool_use(self): + """Test that orphaned Anthropic tool_result gets its tool_use prepended.""" + all_msgs = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_123", + "name": "get_weather", + "input": {"location": "SF"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_123", + "content": "22°C and sunny", + } + ], + }, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (the tool_result) + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the tool_use message + assert len(result) == 3 + assert result[0]["role"] == "assistant" + assert result[0]["content"][0]["type"] == "tool_use" + + def test_anthropic_keeps_complete_pairs(self): + """Test that complete Anthropic pairs are unchanged.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_456", + "name": "calculator", + "input": {"expr": "2+2"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_456", + "content": "4", + } + ], + }, + ] + recent = all_msgs[1:] # Include both tool_use and result + start_index = 1 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + assert len(result) == 2 # No messages added + + def test_anthropic_multiple_tool_uses(self): + """Test multiple Anthropic tool_use blocks in one message.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me check both..."}, + { + "type": "tool_use", + "id": "toolu_1", + "name": "get_weather", + "input": {"city": "NYC"}, + }, + { + "type": "tool_use", + "id": "toolu_2", + "name": "get_weather", + "input": {"city": "LA"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_1", + "content": "Cold", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_2", + "content": "Warm", + }, + ], + }, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (tool_result) + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the assistant message with both tool_uses + assert len(result) == 3 + assert result[0]["role"] == "assistant" + tool_use_count = sum( + 1 for b in result[0]["content"] if b.get("type") == "tool_use" + ) + assert tool_use_count == 2 + + # ---- Mixed/Edge Case Tests ---- + + def test_anthropic_with_type_message_field(self): + """Test Anthropic format with 'type': 'message' field (smart_decision_maker style).""" + all_msgs = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_abc", + "name": "search", + "input": {"q": "test"}, + } + ], + }, + { + "role": "user", + "type": "message", # Extra field from smart_decision_maker + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_abc", + "content": "Found results", + } + ], + }, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (the tool_result with 'type': 'message') + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the tool_use message + assert len(result) == 3 + assert result[0]["role"] == "assistant" + assert result[0]["content"][0]["type"] == "tool_use" + + def test_handles_no_tool_messages(self): + """Test messages without tool calls.""" + all_msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + recent = all_msgs + start_index = 0 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + assert result == all_msgs + + def test_handles_empty_messages(self): + """Test empty message list.""" + result = _ensure_tool_pairs_intact([], [], 0) + assert result == [] + + def test_mixed_text_and_tool_content(self): + """Test Anthropic message with mixed text and tool_use content.""" + all_msgs = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll help you with that."}, + { + "type": "tool_use", + "id": "toolu_mixed", + "name": "search", + "input": {"q": "test"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_mixed", + "content": "Found results", + } + ], + }, + {"role": "assistant", "content": "Here are the results..."}, + ] + # Start from tool_result + recent = [all_msgs[1], all_msgs[2]] + start_index = 1 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the assistant message with tool_use + assert len(result) == 3 + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][1]["type"] == "tool_use" + + +class TestCompressContext: + """Test the async compress_context function.""" + + @pytest.mark.asyncio + async def test_no_compression_needed(self): + """Test messages under limit return without compression.""" + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello!"}, + ] + + result = await compress_context(messages, target_tokens=100000) + + assert isinstance(result, CompressResult) + assert result.was_compacted is False + assert len(result.messages) == 2 + assert result.error is None + + @pytest.mark.asyncio + async def test_truncation_without_client(self): + """Test that truncation works without LLM client.""" + long_content = "x" * 50000 + messages = [ + {"role": "system", "content": "System"}, + {"role": "user", "content": long_content}, + {"role": "assistant", "content": "Response"}, + ] + + result = await compress_context( + messages, target_tokens=1000, client=None, reserve=100 + ) + + assert result.was_compacted is True + # Should have truncated without summarization + assert result.messages_summarized == 0 + + @pytest.mark.asyncio + async def test_with_mocked_llm_client(self): + """Test summarization with mocked LLM client.""" + # Create many messages to trigger summarization + messages = [{"role": "system", "content": "System prompt"}] + for i in range(30): + messages.append({"role": "user", "content": f"User message {i} " * 100}) + messages.append( + {"role": "assistant", "content": f"Assistant response {i} " * 100} + ) + + # Mock the AsyncOpenAI client + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Summary of conversation" + mock_client.with_options.return_value.chat.completions.create = AsyncMock( + return_value=mock_response + ) + + result = await compress_context( + messages, + target_tokens=5000, + client=mock_client, + keep_recent=5, + reserve=500, + ) + + assert result.was_compacted is True + # Should have attempted summarization + assert mock_client.with_options.called or result.messages_summarized > 0 + + @pytest.mark.asyncio + async def test_preserves_tool_pairs(self): + """Test that tool call/response pairs stay together.""" + messages = [ + {"role": "system", "content": "System"}, + {"role": "user", "content": "Do something"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "func"}} + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000}, + {"role": "assistant", "content": "Done!"}, + ] + + result = await compress_context( + messages, target_tokens=500, client=None, reserve=50 + ) + + # Check that if tool response exists, its call exists too + tool_call_ids = set() + tool_response_ids = set() + for msg in result.messages: + if "tool_calls" in msg: + for tc in msg["tool_calls"]: + tool_call_ids.add(tc["id"]) + if msg.get("role") == "tool": + tool_response_ids.add(msg.get("tool_call_id")) + + # All tool responses should have their calls + assert tool_response_ids <= tool_call_ids + + @pytest.mark.asyncio + async def test_returns_error_when_cannot_compress(self): + """Test that error is returned when compression fails.""" + # Single huge message that can't be compressed enough + messages = [ + {"role": "user", "content": "x" * 100000}, + ] + + result = await compress_context( + messages, target_tokens=100, client=None, reserve=50 + ) + + # Should have an error since we can't get below 100 tokens + assert result.error is not None + assert result.was_compacted is True + + @pytest.mark.asyncio + async def test_empty_messages(self): + """Test that empty messages list returns early without error.""" + result = await compress_context([], target_tokens=1000) + + assert result.messages == [] + assert result.token_count == 0 + assert result.was_compacted is False + assert result.error is None + + +class TestRemoveOrphanToolResponses: + """Test _remove_orphan_tool_responses helper function.""" + + def test_removes_openai_orphan(self): + """Test removal of orphan OpenAI tool response.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + {"role": "tool", "tool_call_id": "call_orphan", "content": "result"}, + {"role": "user", "content": "Hello"}, + ] + orphan_ids = {"call_orphan"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert len(result) == 1 + assert result[0]["role"] == "user" + + def test_keeps_valid_openai_tool(self): + """Test that valid OpenAI tool responses are kept.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + {"role": "tool", "tool_call_id": "call_valid", "content": "result"}, + ] + orphan_ids = {"call_other"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert len(result) == 1 + assert result[0]["tool_call_id"] == "call_valid" + + def test_filters_anthropic_mixed_blocks(self): + """Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid", + "content": "valid result", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_orphan", + "content": "orphan result", + }, + ], + }, + ] + orphan_ids = {"toolu_orphan"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert len(result) == 1 + # Should only have the valid tool_result, orphan filtered out + assert len(result[0]["content"]) == 1 + assert result[0]["content"][0]["tool_use_id"] == "toolu_valid" + + def test_removes_anthropic_all_orphan(self): + """Test removal of Anthropic message when all tool_results are orphans.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_orphan1", + "content": "result1", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_orphan2", + "content": "result2", + }, + ], + }, + ] + orphan_ids = {"toolu_orphan1", "toolu_orphan2"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + # Message should be completely removed since no content left + assert len(result) == 0 + + def test_preserves_non_tool_messages(self): + """Test that non-tool messages are preserved.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + orphan_ids = {"some_id"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert result == messages + + +class TestCompressResultDataclass: + """Test CompressResult dataclass.""" + + def test_default_values(self): + """Test default values are set correctly.""" + result = CompressResult( + messages=[{"role": "user", "content": "test"}], + token_count=10, + was_compacted=False, + ) + + assert result.error is None + assert result.original_token_count == 0 # Defaults to 0, not None + assert result.messages_summarized == 0 + assert result.messages_dropped == 0 + + def test_all_fields(self): + """Test all fields can be set.""" + result = CompressResult( + messages=[{"role": "user", "content": "test"}], + token_count=100, + was_compacted=True, + error="Some error", + original_token_count=500, + messages_summarized=10, + messages_dropped=5, + ) + + assert result.token_count == 100 + assert result.was_compacted is True + assert result.error == "Some error" + assert result.original_token_count == 500 + assert result.messages_summarized == 10 + assert result.messages_dropped == 5 diff --git a/autogpt_platform/backend/backend/util/request.py b/autogpt_platform/backend/backend/util/request.py index 9744372b15..95e5ee32f7 100644 --- a/autogpt_platform/backend/backend/util/request.py +++ b/autogpt_platform/backend/backend/util/request.py @@ -157,12 +157,7 @@ async def validate_url( is_trusted: Boolean indicating if the hostname is in trusted_origins ip_addresses: List of IP addresses for the host; empty if the host is trusted """ - # Canonicalize URL - url = url.strip("/ ").replace("\\", "/") - parsed = urlparse(url) - if not parsed.scheme: - url = f"http://{url}" - parsed = urlparse(url) + parsed = parse_url(url) # Check scheme if parsed.scheme not in ALLOWED_SCHEMES: @@ -220,6 +215,17 @@ async def validate_url( ) +def parse_url(url: str) -> URL: + """Canonicalizes and parses a URL string.""" + url = url.strip("/ ").replace("\\", "/") + + # Ensure scheme is present for proper parsing + if not re.match(r"[a-z0-9+.\-]+://", url): + url = f"http://{url}" + + return urlparse(url) + + def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL: """ Pins a URL to a specific IP address to prevent DNS rebinding attacks. diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index a42a4d29b4..aa28a4c9ac 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -263,6 +263,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): description="The name of the Google Cloud Storage bucket for media files", ) + workspace_storage_dir: str = Field( + default="", + description="Local directory for workspace file storage when GCS is not configured. " + "If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.", + ) + reddit_user_agent: str = Field( default="web:AutoGPT:v0.6.0 (by /u/autogpt)", description="The user agent for the Reddit API", @@ -389,6 +395,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): description="Maximum file size in MB for file uploads (1-1024 MB)", ) + max_file_size_mb: int = Field( + default=100, + ge=1, + le=1024, + description="Maximum file size in MB for workspace files (1-1024 MB)", + ) + # AutoMod configuration automod_enabled: bool = Field( default=False, diff --git a/autogpt_platform/backend/backend/util/test.py b/autogpt_platform/backend/backend/util/test.py index 0a539644ee..23d7c24147 100644 --- a/autogpt_platform/backend/backend/util/test.py +++ b/autogpt_platform/backend/backend/util/test.py @@ -140,14 +140,29 @@ async def execute_block_test(block: Block): setattr(block, mock_name, mock_obj) # Populate credentials argument(s) + # Generate IDs for execution context + graph_id = str(uuid.uuid4()) + node_id = str(uuid.uuid4()) + graph_exec_id = str(uuid.uuid4()) + node_exec_id = str(uuid.uuid4()) + user_id = str(uuid.uuid4()) + graph_version = 1 # Default version for tests + extra_exec_kwargs: dict = { - "graph_id": str(uuid.uuid4()), - "node_id": str(uuid.uuid4()), - "graph_exec_id": str(uuid.uuid4()), - "node_exec_id": str(uuid.uuid4()), - "user_id": str(uuid.uuid4()), - "graph_version": 1, # Default version for tests - "execution_context": ExecutionContext(), + "graph_id": graph_id, + "node_id": node_id, + "graph_exec_id": graph_exec_id, + "node_exec_id": node_exec_id, + "user_id": user_id, + "graph_version": graph_version, + "execution_context": ExecutionContext( + user_id=user_id, + graph_id=graph_id, + graph_exec_id=graph_exec_id, + graph_version=graph_version, + node_id=node_id, + node_exec_id=node_exec_id, + ), } input_model = cast(type[BlockSchema], block.input_schema) diff --git a/autogpt_platform/backend/backend/util/workspace.py b/autogpt_platform/backend/backend/util/workspace.py new file mode 100644 index 0000000000..a2f1a61b9e --- /dev/null +++ b/autogpt_platform/backend/backend/util/workspace.py @@ -0,0 +1,419 @@ +""" +WorkspaceManager for managing user workspace file operations. + +This module provides a high-level interface for workspace file operations, +combining the storage backend and database layer. +""" + +import logging +import mimetypes +import uuid +from typing import Optional + +from prisma.errors import UniqueViolationError +from prisma.models import UserWorkspaceFile + +from backend.data.workspace import ( + count_workspace_files, + create_workspace_file, + get_workspace_file, + get_workspace_file_by_path, + list_workspace_files, + soft_delete_workspace_file, +) +from backend.util.settings import Config +from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage + +logger = logging.getLogger(__name__) + + +class WorkspaceManager: + """ + Manages workspace file operations. + + Combines storage backend operations with database record management. + Supports session-scoped file segmentation where files are stored in + session-specific virtual paths: /sessions/{session_id}/{filename} + """ + + def __init__( + self, user_id: str, workspace_id: str, session_id: Optional[str] = None + ): + """ + Initialize WorkspaceManager. + + Args: + user_id: The user's ID + workspace_id: The workspace ID + session_id: Optional session ID for session-scoped file access + """ + self.user_id = user_id + self.workspace_id = workspace_id + self.session_id = session_id + # Session path prefix for file isolation + self.session_path = f"/sessions/{session_id}" if session_id else "" + + def _resolve_path(self, path: str) -> str: + """ + Resolve a path, defaulting to session folder if session_id is set. + + Cross-session access is allowed by explicitly using /sessions/other-session-id/... + + Args: + path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt") + + Returns: + Resolved path with session prefix if applicable + """ + # If path explicitly references a session folder, use it as-is + if path.startswith("/sessions/"): + return path + + # If we have a session context, prepend session path + if self.session_path: + # Normalize the path + if not path.startswith("/"): + path = f"/{path}" + return f"{self.session_path}{path}" + + # No session context, use path as-is + return path if path.startswith("/") else f"/{path}" + + def _get_effective_path( + self, path: Optional[str], include_all_sessions: bool + ) -> Optional[str]: + """ + Get effective path for list/count operations based on session context. + + Args: + path: Optional path prefix to filter + include_all_sessions: If True, don't apply session scoping + + Returns: + Effective path prefix for database query + """ + if include_all_sessions: + # Normalize path to ensure leading slash (stored paths are normalized) + if path is not None and not path.startswith("/"): + return f"/{path}" + return path + elif path is not None: + # Resolve the provided path with session scoping + return self._resolve_path(path) + elif self.session_path: + # Default to session folder with trailing slash to prevent prefix collisions + # e.g., "/sessions/abc" should not match "/sessions/abc123" + return self.session_path.rstrip("/") + "/" + else: + # No session context, use path as-is + return path + + async def read_file(self, path: str) -> bytes: + """ + Read file from workspace by virtual path. + + When session_id is set, paths are resolved relative to the session folder + unless they explicitly reference /sessions/... + + Args: + path: Virtual path (e.g., "/documents/report.pdf") + + Returns: + File content as bytes + + Raises: + FileNotFoundError: If file doesn't exist + """ + resolved_path = self._resolve_path(path) + file = await get_workspace_file_by_path(self.workspace_id, resolved_path) + if file is None: + raise FileNotFoundError(f"File not found at path: {resolved_path}") + + storage = await get_workspace_storage() + return await storage.retrieve(file.storagePath) + + async def read_file_by_id(self, file_id: str) -> bytes: + """ + Read file from workspace by file ID. + + Args: + file_id: The file's ID + + Returns: + File content as bytes + + Raises: + FileNotFoundError: If file doesn't exist + """ + file = await get_workspace_file(file_id, self.workspace_id) + if file is None: + raise FileNotFoundError(f"File not found: {file_id}") + + storage = await get_workspace_storage() + return await storage.retrieve(file.storagePath) + + async def write_file( + self, + content: bytes, + filename: str, + path: Optional[str] = None, + mime_type: Optional[str] = None, + overwrite: bool = False, + ) -> UserWorkspaceFile: + """ + Write file to workspace. + + When session_id is set, files are written to /sessions/{session_id}/... + by default. Use explicit /sessions/... paths for cross-session access. + + Args: + content: File content as bytes + filename: Filename for the file + path: Virtual path (defaults to "/{filename}", session-scoped if session_id set) + mime_type: MIME type (auto-detected if not provided) + overwrite: Whether to overwrite existing file at path + + Returns: + Created UserWorkspaceFile instance + + Raises: + ValueError: If file exceeds size limit or path already exists + """ + # Enforce file size limit + max_file_size = Config().max_file_size_mb * 1024 * 1024 + if len(content) > max_file_size: + raise ValueError( + f"File too large: {len(content)} bytes exceeds " + f"{Config().max_file_size_mb}MB limit" + ) + + # Determine path with session scoping + if path is None: + path = f"/{filename}" + elif not path.startswith("/"): + path = f"/{path}" + + # Resolve path with session prefix + path = self._resolve_path(path) + + # Check if file exists at path (only error for non-overwrite case) + # For overwrite=True, we let the write proceed and handle via UniqueViolationError + # This ensures the new file is written to storage BEFORE the old one is deleted, + # preventing data loss if the new write fails + if not overwrite: + existing = await get_workspace_file_by_path(self.workspace_id, path) + if existing is not None: + raise ValueError(f"File already exists at path: {path}") + + # Auto-detect MIME type if not provided + if mime_type is None: + mime_type, _ = mimetypes.guess_type(filename) + mime_type = mime_type or "application/octet-stream" + + # Compute checksum + checksum = compute_file_checksum(content) + + # Generate unique file ID for storage + file_id = str(uuid.uuid4()) + + # Store file in storage backend + storage = await get_workspace_storage() + storage_path = await storage.store( + workspace_id=self.workspace_id, + file_id=file_id, + filename=filename, + content=content, + ) + + # Create database record - handle race condition where another request + # created a file at the same path between our check and create + try: + file = await create_workspace_file( + workspace_id=self.workspace_id, + file_id=file_id, + name=filename, + path=path, + storage_path=storage_path, + mime_type=mime_type, + size_bytes=len(content), + checksum=checksum, + ) + except UniqueViolationError: + # Race condition: another request created a file at this path + if overwrite: + # Re-fetch and delete the conflicting file, then retry + existing = await get_workspace_file_by_path(self.workspace_id, path) + if existing: + await self.delete_file(existing.id) + # Retry the create - if this also fails, clean up storage file + try: + file = await create_workspace_file( + workspace_id=self.workspace_id, + file_id=file_id, + name=filename, + path=path, + storage_path=storage_path, + mime_type=mime_type, + size_bytes=len(content), + checksum=checksum, + ) + except Exception: + # Clean up orphaned storage file on retry failure + try: + await storage.delete(storage_path) + except Exception as e: + logger.warning(f"Failed to clean up orphaned storage file: {e}") + raise + else: + # Clean up the orphaned storage file before raising + try: + await storage.delete(storage_path) + except Exception as e: + logger.warning(f"Failed to clean up orphaned storage file: {e}") + raise ValueError(f"File already exists at path: {path}") + except Exception: + # Any other database error (connection, validation, etc.) - clean up storage + try: + await storage.delete(storage_path) + except Exception as e: + logger.warning(f"Failed to clean up orphaned storage file: {e}") + raise + + logger.info( + f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} " + f"at path {path}, size={len(content)} bytes" + ) + + return file + + async def list_files( + self, + path: Optional[str] = None, + limit: Optional[int] = None, + offset: int = 0, + include_all_sessions: bool = False, + ) -> list[UserWorkspaceFile]: + """ + List files in workspace. + + When session_id is set and include_all_sessions is False (default), + only files in the current session's folder are listed. + + Args: + path: Optional path prefix to filter (e.g., "/documents/") + limit: Maximum number of files to return + offset: Number of files to skip + include_all_sessions: If True, list files from all sessions. + If False (default), only list current session's files. + + Returns: + List of UserWorkspaceFile instances + """ + effective_path = self._get_effective_path(path, include_all_sessions) + + return await list_workspace_files( + workspace_id=self.workspace_id, + path_prefix=effective_path, + limit=limit, + offset=offset, + ) + + async def delete_file(self, file_id: str) -> bool: + """ + Delete a file (soft-delete). + + Args: + file_id: The file's ID + + Returns: + True if deleted, False if not found + """ + file = await get_workspace_file(file_id, self.workspace_id) + if file is None: + return False + + # Delete from storage + storage = await get_workspace_storage() + try: + await storage.delete(file.storagePath) + except Exception as e: + logger.warning(f"Failed to delete file from storage: {e}") + # Continue with database soft-delete even if storage delete fails + + # Soft-delete database record + result = await soft_delete_workspace_file(file_id, self.workspace_id) + return result is not None + + async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str: + """ + Get download URL for a file. + + Args: + file_id: The file's ID + expires_in: URL expiration in seconds (default 1 hour) + + Returns: + Download URL (signed URL for GCS, API endpoint for local) + + Raises: + FileNotFoundError: If file doesn't exist + """ + file = await get_workspace_file(file_id, self.workspace_id) + if file is None: + raise FileNotFoundError(f"File not found: {file_id}") + + storage = await get_workspace_storage() + return await storage.get_download_url(file.storagePath, expires_in) + + async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]: + """ + Get file metadata. + + Args: + file_id: The file's ID + + Returns: + UserWorkspaceFile instance or None + """ + return await get_workspace_file(file_id, self.workspace_id) + + async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]: + """ + Get file metadata by path. + + When session_id is set, paths are resolved relative to the session folder + unless they explicitly reference /sessions/... + + Args: + path: Virtual path + + Returns: + UserWorkspaceFile instance or None + """ + resolved_path = self._resolve_path(path) + return await get_workspace_file_by_path(self.workspace_id, resolved_path) + + async def get_file_count( + self, + path: Optional[str] = None, + include_all_sessions: bool = False, + ) -> int: + """ + Get number of files in workspace. + + When session_id is set and include_all_sessions is False (default), + only counts files in the current session's folder. + + Args: + path: Optional path prefix to filter (e.g., "/documents/") + include_all_sessions: If True, count all files in workspace. + If False (default), only count current session's files. + + Returns: + Number of files + """ + effective_path = self._get_effective_path(path, include_all_sessions) + + return await count_workspace_files( + self.workspace_id, path_prefix=effective_path + ) diff --git a/autogpt_platform/backend/backend/util/workspace_storage.py b/autogpt_platform/backend/backend/util/workspace_storage.py new file mode 100644 index 0000000000..2f4c8ae2b5 --- /dev/null +++ b/autogpt_platform/backend/backend/util/workspace_storage.py @@ -0,0 +1,398 @@ +""" +Workspace storage backend abstraction for supporting both cloud and local deployments. + +This module provides a unified interface for storing workspace files, with implementations +for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments). +""" + +import asyncio +import hashlib +import logging +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +import aiofiles +import aiohttp +from gcloud.aio import storage as async_gcs_storage +from google.cloud import storage as gcs_storage + +from backend.util.data import get_data_path +from backend.util.gcs_utils import ( + download_with_fresh_session, + generate_signed_url, + parse_gcs_path, +) +from backend.util.settings import Config + +logger = logging.getLogger(__name__) + + +class WorkspaceStorageBackend(ABC): + """Abstract interface for workspace file storage.""" + + @abstractmethod + async def store( + self, + workspace_id: str, + file_id: str, + filename: str, + content: bytes, + ) -> str: + """ + Store file content, return storage path. + + Args: + workspace_id: The workspace ID + file_id: Unique file ID for storage + filename: Original filename + content: File content as bytes + + Returns: + Storage path string (cloud path or local path) + """ + pass + + @abstractmethod + async def retrieve(self, storage_path: str) -> bytes: + """ + Retrieve file content from storage. + + Args: + storage_path: The storage path returned from store() + + Returns: + File content as bytes + """ + pass + + @abstractmethod + async def delete(self, storage_path: str) -> None: + """ + Delete file from storage. + + Args: + storage_path: The storage path to delete + """ + pass + + @abstractmethod + async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str: + """ + Get URL for downloading the file. + + Args: + storage_path: The storage path + expires_in: URL expiration time in seconds (default 1 hour) + + Returns: + Download URL (signed URL for GCS, direct API path for local) + """ + pass + + +class GCSWorkspaceStorage(WorkspaceStorageBackend): + """Google Cloud Storage implementation for workspace storage.""" + + def __init__(self, bucket_name: str): + self.bucket_name = bucket_name + self._async_client: Optional[async_gcs_storage.Storage] = None + self._sync_client: Optional[gcs_storage.Client] = None + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_async_client(self) -> async_gcs_storage.Storage: + """Get or create async GCS client.""" + if self._async_client is None: + self._session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=100, force_close=False) + ) + self._async_client = async_gcs_storage.Storage(session=self._session) + return self._async_client + + def _get_sync_client(self) -> gcs_storage.Client: + """Get or create sync GCS client (for signed URLs).""" + if self._sync_client is None: + self._sync_client = gcs_storage.Client() + return self._sync_client + + async def close(self) -> None: + """Close all client connections.""" + if self._async_client is not None: + try: + await self._async_client.close() + except Exception as e: + logger.warning(f"Error closing GCS client: {e}") + self._async_client = None + + if self._session is not None: + try: + await self._session.close() + except Exception as e: + logger.warning(f"Error closing session: {e}") + self._session = None + + def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str: + """Build the blob path for workspace files.""" + return f"workspaces/{workspace_id}/{file_id}/{filename}" + + async def store( + self, + workspace_id: str, + file_id: str, + filename: str, + content: bytes, + ) -> str: + """Store file in GCS.""" + client = await self._get_async_client() + blob_name = self._build_blob_name(workspace_id, file_id, filename) + + # Upload with metadata + upload_time = datetime.now(timezone.utc) + await client.upload( + self.bucket_name, + blob_name, + content, + metadata={ + "uploaded_at": upload_time.isoformat(), + "workspace_id": workspace_id, + "file_id": file_id, + }, + ) + + return f"gcs://{self.bucket_name}/{blob_name}" + + async def retrieve(self, storage_path: str) -> bytes: + """Retrieve file from GCS.""" + bucket_name, blob_name = parse_gcs_path(storage_path) + return await download_with_fresh_session(bucket_name, blob_name) + + async def delete(self, storage_path: str) -> None: + """Delete file from GCS.""" + bucket_name, blob_name = parse_gcs_path(storage_path) + client = await self._get_async_client() + + try: + await client.delete(bucket_name, blob_name) + except Exception as e: + if "404" not in str(e) and "Not Found" not in str(e): + raise + # File already deleted, that's fine + + async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str: + """ + Generate download URL for GCS file. + + Attempts to generate a signed URL if running with service account credentials. + Falls back to an API proxy endpoint if signed URL generation fails + (e.g., when running locally with user OAuth credentials). + """ + bucket_name, blob_name = parse_gcs_path(storage_path) + + # Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename} + blob_parts = blob_name.split("/") + file_id = blob_parts[2] if len(blob_parts) >= 3 else None + + # Try to generate signed URL (requires service account credentials) + try: + sync_client = self._get_sync_client() + return await generate_signed_url( + sync_client, bucket_name, blob_name, expires_in + ) + except AttributeError as e: + # Signed URL generation requires service account with private key. + # When running with user OAuth credentials, fall back to API proxy. + if "private key" in str(e) and file_id: + logger.debug( + "Cannot generate signed URL (no service account credentials), " + "falling back to API proxy endpoint" + ) + return f"/api/workspace/files/{file_id}/download" + raise + + +class LocalWorkspaceStorage(WorkspaceStorageBackend): + """Local filesystem implementation for workspace storage (self-hosted deployments).""" + + def __init__(self, base_dir: Optional[str] = None): + """ + Initialize local storage backend. + + Args: + base_dir: Base directory for workspace storage. + If None, defaults to {app_data}/workspaces + """ + if base_dir: + self.base_dir = Path(base_dir) + else: + self.base_dir = Path(get_data_path()) / "workspaces" + + # Ensure base directory exists + self.base_dir.mkdir(parents=True, exist_ok=True) + + def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path: + """Build the local file path with path traversal protection.""" + # Import here to avoid circular import + # (file.py imports workspace.py which imports workspace_storage.py) + from backend.util.file import sanitize_filename + + # Sanitize filename to prevent path traversal (removes / and \ among others) + safe_filename = sanitize_filename(filename) + file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve() + + # Verify the resolved path is still under base_dir + if not file_path.is_relative_to(self.base_dir.resolve()): + raise ValueError("Invalid filename: path traversal detected") + + return file_path + + def _parse_storage_path(self, storage_path: str) -> Path: + """Parse local storage path to filesystem path.""" + if storage_path.startswith("local://"): + relative_path = storage_path[8:] # Remove "local://" + else: + relative_path = storage_path + + full_path = (self.base_dir / relative_path).resolve() + + # Security check: ensure path is under base_dir + # Use is_relative_to() for robust path containment check + # (handles case-insensitive filesystems and edge cases) + if not full_path.is_relative_to(self.base_dir.resolve()): + raise ValueError("Invalid storage path: path traversal detected") + + return full_path + + async def store( + self, + workspace_id: str, + file_id: str, + filename: str, + content: bytes, + ) -> str: + """Store file locally.""" + file_path = self._build_file_path(workspace_id, file_id, filename) + + # Create parent directories + file_path.parent.mkdir(parents=True, exist_ok=True) + + # Write file asynchronously + async with aiofiles.open(file_path, "wb") as f: + await f.write(content) + + # Return relative path as storage path + relative_path = file_path.relative_to(self.base_dir) + return f"local://{relative_path}" + + async def retrieve(self, storage_path: str) -> bytes: + """Retrieve file from local storage.""" + file_path = self._parse_storage_path(storage_path) + + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {storage_path}") + + async with aiofiles.open(file_path, "rb") as f: + return await f.read() + + async def delete(self, storage_path: str) -> None: + """Delete file from local storage.""" + file_path = self._parse_storage_path(storage_path) + + if file_path.exists(): + # Remove file + file_path.unlink() + + # Clean up empty parent directories + parent = file_path.parent + while parent != self.base_dir: + try: + if parent.exists() and not any(parent.iterdir()): + parent.rmdir() + else: + break + except OSError: + break + parent = parent.parent + + async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str: + """ + Get download URL for local file. + + For local storage, this returns an API endpoint path. + The actual serving is handled by the API layer. + """ + # Parse the storage path to get the components + if storage_path.startswith("local://"): + relative_path = storage_path[8:] + else: + relative_path = storage_path + + # Return the API endpoint for downloading + # The file_id is extracted from the path: {workspace_id}/{file_id}/{filename} + parts = relative_path.split("/") + if len(parts) >= 2: + file_id = parts[1] # Second component is file_id + return f"/api/workspace/files/{file_id}/download" + else: + raise ValueError(f"Invalid storage path format: {storage_path}") + + +# Global storage backend instance +_workspace_storage: Optional[WorkspaceStorageBackend] = None +_storage_lock = asyncio.Lock() + + +async def get_workspace_storage() -> WorkspaceStorageBackend: + """ + Get the workspace storage backend instance. + + Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage. + """ + global _workspace_storage + + if _workspace_storage is None: + async with _storage_lock: + if _workspace_storage is None: + config = Config() + + if config.media_gcs_bucket_name: + logger.info( + f"Using GCS workspace storage: {config.media_gcs_bucket_name}" + ) + _workspace_storage = GCSWorkspaceStorage( + config.media_gcs_bucket_name + ) + else: + storage_dir = ( + config.workspace_storage_dir + if config.workspace_storage_dir + else None + ) + logger.info( + f"Using local workspace storage: {storage_dir or 'default'}" + ) + _workspace_storage = LocalWorkspaceStorage(storage_dir) + + return _workspace_storage + + +async def shutdown_workspace_storage() -> None: + """ + Properly shutdown the global workspace storage backend. + + Closes aiohttp sessions and other resources for GCS backend. + Should be called during application shutdown. + """ + global _workspace_storage + + if _workspace_storage is not None: + async with _storage_lock: + if _workspace_storage is not None: + if isinstance(_workspace_storage, GCSWorkspaceStorage): + await _workspace_storage.close() + _workspace_storage = None + + +def compute_file_checksum(content: bytes) -> str: + """Compute SHA256 checksum of file content.""" + return hashlib.sha256(content).hexdigest() diff --git a/autogpt_platform/backend/migrations/20260126120000_migrate_claude_3_7_to_4_5_sonnet/migration.sql b/autogpt_platform/backend/migrations/20260126120000_migrate_claude_3_7_to_4_5_sonnet/migration.sql new file mode 100644 index 0000000000..5746c80820 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260126120000_migrate_claude_3_7_to_4_5_sonnet/migration.sql @@ -0,0 +1,22 @@ +-- Migrate Claude 3.7 Sonnet to Claude 4.5 Sonnet +-- This updates all AgentNode blocks that use the deprecated Claude 3.7 Sonnet model +-- Anthropic is retiring claude-3-7-sonnet-20250219 on February 19, 2026 + +-- Update AgentNode constant inputs +UPDATE "AgentNode" +SET "constantInput" = JSONB_SET( + "constantInput"::jsonb, + '{model}', + '"claude-sonnet-4-5-20250929"'::jsonb + ) +WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219'; + +-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput) +UPDATE "AgentNodeExecutionInputOutput" +SET "data" = JSONB_SET( + "data"::jsonb, + '{model}', + '"claude-sonnet-4-5-20250929"'::jsonb + ) +WHERE "agentPresetId" IS NOT NULL + AND "data"::jsonb->>'model' = 'claude-3-7-sonnet-20250219'; diff --git a/autogpt_platform/backend/migrations/20260127230419_add_user_workspace/migration.sql b/autogpt_platform/backend/migrations/20260127230419_add_user_workspace/migration.sql new file mode 100644 index 0000000000..bb63dccb33 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260127230419_add_user_workspace/migration.sql @@ -0,0 +1,52 @@ +-- CreateEnum +CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT'); + +-- CreateTable +CREATE TABLE "UserWorkspace" ( + "id" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + "userId" TEXT NOT NULL, + + CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "UserWorkspaceFile" ( + "id" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + "workspaceId" TEXT NOT NULL, + "name" TEXT NOT NULL, + "path" TEXT NOT NULL, + "storagePath" TEXT NOT NULL, + "mimeType" TEXT NOT NULL, + "sizeBytes" BIGINT NOT NULL, + "checksum" TEXT, + "isDeleted" BOOLEAN NOT NULL DEFAULT false, + "deletedAt" TIMESTAMP(3), + "source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD', + "sourceExecId" TEXT, + "sourceSessionId" TEXT, + "metadata" JSONB NOT NULL DEFAULT '{}', + + CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId"); + +-- CreateIndex +CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId"); + +-- CreateIndex +CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted"); + +-- CreateIndex +CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path"); + +-- AddForeignKey +ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/autogpt_platform/backend/migrations/20260129011611_remove_workspace_file_source/migration.sql b/autogpt_platform/backend/migrations/20260129011611_remove_workspace_file_source/migration.sql new file mode 100644 index 0000000000..2709bc8484 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260129011611_remove_workspace_file_source/migration.sql @@ -0,0 +1,16 @@ +/* + Warnings: + + - You are about to drop the column `source` on the `UserWorkspaceFile` table. All the data in the column will be lost. + - You are about to drop the column `sourceExecId` on the `UserWorkspaceFile` table. All the data in the column will be lost. + - You are about to drop the column `sourceSessionId` on the `UserWorkspaceFile` table. All the data in the column will be lost. + +*/ + +-- AlterTable +ALTER TABLE "UserWorkspaceFile" DROP COLUMN "source", +DROP COLUMN "sourceExecId", +DROP COLUMN "sourceSessionId"; + +-- DropEnum +DROP TYPE "WorkspaceFileSource"; diff --git a/autogpt_platform/backend/schema.prisma b/autogpt_platform/backend/schema.prisma index 2c52528e3f..2da898a7ce 100644 --- a/autogpt_platform/backend/schema.prisma +++ b/autogpt_platform/backend/schema.prisma @@ -63,6 +63,7 @@ model User { IntegrationWebhooks IntegrationWebhook[] NotificationBatches UserNotificationBatch[] PendingHumanReviews PendingHumanReview[] + Workspace UserWorkspace? // OAuth Provider relations OAuthApplications OAuthApplication[] @@ -137,6 +138,53 @@ model CoPilotUnderstanding { @@index([userId]) } +//////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////// +//////////////// USER WORKSPACE TABLES ///////////////// +//////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////// + +// User's persistent file storage workspace +model UserWorkspace { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + userId String @unique + User User @relation(fields: [userId], references: [id], onDelete: Cascade) + + Files UserWorkspaceFile[] + + @@index([userId]) +} + +// Individual files in a user's workspace +model UserWorkspaceFile { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + workspaceId String + Workspace UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade) + + // File metadata + name String // User-visible filename + path String // Virtual path (e.g., "/documents/report.pdf") + storagePath String // Actual GCS or local storage path + mimeType String + sizeBytes BigInt + checksum String? // SHA256 for integrity + + // File state + isDeleted Boolean @default(false) + deletedAt DateTime? + + metadata Json @default("{}") + + @@unique([workspaceId, path]) + @@index([workspaceId, isDeleted]) +} + model BuilderSearchHistory { id String @id @default(uuid()) createdAt DateTime @default(now()) diff --git a/autogpt_platform/backend/snapshots/agts_by_creator b/autogpt_platform/backend/snapshots/agts_by_creator index 4d6dd12920..3f2e128a0d 100644 --- a/autogpt_platform/backend/snapshots/agts_by_creator +++ b/autogpt_platform/backend/snapshots/agts_by_creator @@ -9,7 +9,8 @@ "sub_heading": "Creator agent subheading", "description": "Creator agent description", "runs": 50, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/agts_category b/autogpt_platform/backend/snapshots/agts_category index f65925ead3..4d0531763c 100644 --- a/autogpt_platform/backend/snapshots/agts_category +++ b/autogpt_platform/backend/snapshots/agts_category @@ -9,7 +9,8 @@ "sub_heading": "Category agent subheading", "description": "Category agent description", "runs": 60, - "rating": 4.1 + "rating": 4.1, + "agent_graph_id": "test-graph-category" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/agts_pagination b/autogpt_platform/backend/snapshots/agts_pagination index 82e7f5f9bf..7b946157fb 100644 --- a/autogpt_platform/backend/snapshots/agts_pagination +++ b/autogpt_platform/backend/snapshots/agts_pagination @@ -9,7 +9,8 @@ "sub_heading": "Agent 0 subheading", "description": "Agent 0 description", "runs": 0, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" }, { "slug": "agent-1", @@ -20,7 +21,8 @@ "sub_heading": "Agent 1 subheading", "description": "Agent 1 description", "runs": 10, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" }, { "slug": "agent-2", @@ -31,7 +33,8 @@ "sub_heading": "Agent 2 subheading", "description": "Agent 2 description", "runs": 20, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" }, { "slug": "agent-3", @@ -42,7 +45,8 @@ "sub_heading": "Agent 3 subheading", "description": "Agent 3 description", "runs": 30, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" }, { "slug": "agent-4", @@ -53,7 +57,8 @@ "sub_heading": "Agent 4 subheading", "description": "Agent 4 description", "runs": 40, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/agts_search b/autogpt_platform/backend/snapshots/agts_search index ca3f504584..ae9cc116bc 100644 --- a/autogpt_platform/backend/snapshots/agts_search +++ b/autogpt_platform/backend/snapshots/agts_search @@ -9,7 +9,8 @@ "sub_heading": "Search agent subheading", "description": "Specific search term description", "runs": 75, - "rating": 4.2 + "rating": 4.2, + "agent_graph_id": "test-graph-search" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/agts_sorted b/autogpt_platform/backend/snapshots/agts_sorted index cddead76a5..b182256b2c 100644 --- a/autogpt_platform/backend/snapshots/agts_sorted +++ b/autogpt_platform/backend/snapshots/agts_sorted @@ -9,7 +9,8 @@ "sub_heading": "Top agent subheading", "description": "Top agent description", "runs": 1000, - "rating": 5.0 + "rating": 5.0, + "agent_graph_id": "test-graph-3" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/feat_agts b/autogpt_platform/backend/snapshots/feat_agts index d57996a768..4f85786434 100644 --- a/autogpt_platform/backend/snapshots/feat_agts +++ b/autogpt_platform/backend/snapshots/feat_agts @@ -9,7 +9,8 @@ "sub_heading": "Featured agent subheading", "description": "Featured agent description", "runs": 100, - "rating": 4.5 + "rating": 4.5, + "agent_graph_id": "test-graph-1" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/lib_agts_search b/autogpt_platform/backend/snapshots/lib_agts_search index 67c307b09e..3ce8402b63 100644 --- a/autogpt_platform/backend/snapshots/lib_agts_search +++ b/autogpt_platform/backend/snapshots/lib_agts_search @@ -31,6 +31,10 @@ "has_sensitive_action": false, "trigger_setup_info": null, "new_output": false, + "execution_count": 0, + "success_rate": null, + "avg_correctness_score": null, + "recent_executions": [], "can_access_graph": true, "is_latest_version": true, "is_favorite": false, @@ -72,6 +76,10 @@ "has_sensitive_action": false, "trigger_setup_info": null, "new_output": false, + "execution_count": 0, + "success_rate": null, + "avg_correctness_score": null, + "recent_executions": [], "can_access_graph": false, "is_latest_version": true, "is_favorite": false, diff --git a/autogpt_platform/backend/test/agent_generator/test_core_integration.py b/autogpt_platform/backend/test/agent_generator/test_core_integration.py index bdcc24ba79..528763e751 100644 --- a/autogpt_platform/backend/test/agent_generator/test_core_integration.py +++ b/autogpt_platform/backend/test/agent_generator/test_core_integration.py @@ -57,7 +57,8 @@ class TestDecomposeGoal: result = await core.decompose_goal("Build a chatbot") - mock_external.assert_called_once_with("Build a chatbot", "") + # library_agents defaults to None + mock_external.assert_called_once_with("Build a chatbot", "", None) assert result == expected_result @pytest.mark.asyncio @@ -74,7 +75,8 @@ class TestDecomposeGoal: await core.decompose_goal("Build a chatbot", "Use Python") - mock_external.assert_called_once_with("Build a chatbot", "Use Python") + # library_agents defaults to None + mock_external.assert_called_once_with("Build a chatbot", "Use Python", None) @pytest.mark.asyncio async def test_returns_none_on_service_failure(self): @@ -109,8 +111,7 @@ class TestGenerateAgent: instructions = {"type": "instructions", "steps": ["Step 1"]} result = await core.generate_agent(instructions) - mock_external.assert_called_once_with(instructions) - # Result should have id, version, is_active added if not present + mock_external.assert_called_once_with(instructions, None, None, None) assert result is not None assert result["name"] == "Test Agent" assert "id" in result @@ -174,7 +175,9 @@ class TestGenerateAgentPatch: current_agent = {"nodes": [], "links": []} result = await core.generate_agent_patch("Add a node", current_agent) - mock_external.assert_called_once_with("Add a node", current_agent) + mock_external.assert_called_once_with( + "Add a node", current_agent, None, None, None + ) assert result == expected_result @pytest.mark.asyncio diff --git a/autogpt_platform/backend/test/agent_generator/test_library_agents.py b/autogpt_platform/backend/test/agent_generator/test_library_agents.py new file mode 100644 index 0000000000..8387339582 --- /dev/null +++ b/autogpt_platform/backend/test/agent_generator/test_library_agents.py @@ -0,0 +1,857 @@ +""" +Tests for library agent fetching functionality in agent generator. + +This test suite verifies the search-based library agent fetching, +including the combination of library and marketplace agents. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.api.features.chat.tools.agent_generator import core + + +class TestGetLibraryAgentsForGeneration: + """Test get_library_agents_for_generation function.""" + + @pytest.mark.asyncio + async def test_fetches_agents_with_search_term(self): + """Test that search_term is passed to the library db.""" + # Create a mock agent with proper attribute values + mock_agent = MagicMock() + mock_agent.graph_id = "agent-123" + mock_agent.graph_version = 1 + mock_agent.name = "Email Agent" + mock_agent.description = "Sends emails" + mock_agent.input_schema = {"properties": {}} + mock_agent.output_schema = {"properties": {}} + mock_agent.recent_executions = [] + + mock_response = MagicMock() + mock_response.agents = [mock_agent] + + with patch.object( + core.library_db, + "list_library_agents", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_list: + result = await core.get_library_agents_for_generation( + user_id="user-123", + search_query="send email", + ) + + mock_list.assert_called_once_with( + user_id="user-123", + search_term="send email", + page=1, + page_size=15, + include_executions=True, + ) + + # Verify result format + assert len(result) == 1 + assert result[0]["graph_id"] == "agent-123" + assert result[0]["name"] == "Email Agent" + + @pytest.mark.asyncio + async def test_excludes_specified_graph_id(self): + """Test that agents with excluded graph_id are filtered out.""" + mock_response = MagicMock() + mock_response.agents = [ + MagicMock( + graph_id="agent-123", + graph_version=1, + name="Agent 1", + description="First agent", + input_schema={}, + output_schema={}, + recent_executions=[], + ), + MagicMock( + graph_id="agent-456", + graph_version=1, + name="Agent 2", + description="Second agent", + input_schema={}, + output_schema={}, + recent_executions=[], + ), + ] + + with patch.object( + core.library_db, + "list_library_agents", + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await core.get_library_agents_for_generation( + user_id="user-123", + exclude_graph_id="agent-123", + ) + + # Verify the excluded agent is not in results + assert len(result) == 1 + assert result[0]["graph_id"] == "agent-456" + + @pytest.mark.asyncio + async def test_respects_max_results(self): + """Test that max_results parameter limits the page_size.""" + mock_response = MagicMock() + mock_response.agents = [] + + with patch.object( + core.library_db, + "list_library_agents", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_list: + await core.get_library_agents_for_generation( + user_id="user-123", + max_results=5, + ) + + mock_list.assert_called_once_with( + user_id="user-123", + search_term=None, + page=1, + page_size=5, + include_executions=True, + ) + + +class TestSearchMarketplaceAgentsForGeneration: + """Test search_marketplace_agents_for_generation function.""" + + @pytest.mark.asyncio + async def test_searches_marketplace_with_query(self): + """Test that marketplace is searched with the query.""" + mock_response = MagicMock() + mock_response.agents = [ + MagicMock( + agent_name="Public Agent", + description="A public agent", + sub_heading="Does something useful", + creator="creator-1", + agent_graph_id="graph-123", + ) + ] + + mock_graph = MagicMock() + mock_graph.id = "graph-123" + mock_graph.version = 1 + mock_graph.input_schema = {"type": "object"} + mock_graph.output_schema = {"type": "object"} + + with ( + patch( + "backend.api.features.store.db.get_store_agents", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_search, + patch( + "backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs", + new_callable=AsyncMock, + return_value={"graph-123": mock_graph}, + ), + ): + result = await core.search_marketplace_agents_for_generation( + search_query="automation", + max_results=10, + ) + + mock_search.assert_called_once_with( + search_query="automation", + page=1, + page_size=10, + ) + + assert len(result) == 1 + assert result[0]["name"] == "Public Agent" + assert result[0]["graph_id"] == "graph-123" + + @pytest.mark.asyncio + async def test_handles_marketplace_error_gracefully(self): + """Test that marketplace errors don't crash the function.""" + with patch( + "backend.api.features.store.db.get_store_agents", + new_callable=AsyncMock, + side_effect=Exception("Marketplace unavailable"), + ): + result = await core.search_marketplace_agents_for_generation( + search_query="test" + ) + + # Should return empty list, not raise exception + assert result == [] + + +class TestGetAllRelevantAgentsForGeneration: + """Test get_all_relevant_agents_for_generation function.""" + + @pytest.mark.asyncio + async def test_combines_library_and_marketplace_agents(self): + """Test that agents from both sources are combined.""" + library_agents = [ + { + "graph_id": "lib-123", + "graph_version": 1, + "name": "Library Agent", + "description": "From library", + "input_schema": {}, + "output_schema": {}, + } + ] + + marketplace_agents = [ + { + "graph_id": "market-456", + "graph_version": 1, + "name": "Market Agent", + "description": "From marketplace", + "input_schema": {}, + "output_schema": {}, + } + ] + + with patch.object( + core, + "get_library_agents_for_generation", + new_callable=AsyncMock, + return_value=library_agents, + ): + with patch.object( + core, + "search_marketplace_agents_for_generation", + new_callable=AsyncMock, + return_value=marketplace_agents, + ): + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query="test query", + include_marketplace=True, + ) + + # Library agents should come first + assert len(result) == 2 + assert result[0]["name"] == "Library Agent" + assert result[1]["name"] == "Market Agent" + + @pytest.mark.asyncio + async def test_deduplicates_by_graph_id(self): + """Test that marketplace agents with same graph_id as library are excluded.""" + library_agents = [ + { + "graph_id": "shared-123", + "graph_version": 1, + "name": "Shared Agent", + "description": "From library", + "input_schema": {}, + "output_schema": {}, + } + ] + + marketplace_agents = [ + { + "graph_id": "shared-123", # Same graph_id, should be deduplicated + "graph_version": 1, + "name": "Shared Agent", + "description": "From marketplace", + "input_schema": {}, + "output_schema": {}, + }, + { + "graph_id": "unique-456", + "graph_version": 1, + "name": "Unique Agent", + "description": "Only in marketplace", + "input_schema": {}, + "output_schema": {}, + }, + ] + + with patch.object( + core, + "get_library_agents_for_generation", + new_callable=AsyncMock, + return_value=library_agents, + ): + with patch.object( + core, + "search_marketplace_agents_for_generation", + new_callable=AsyncMock, + return_value=marketplace_agents, + ): + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query="test", + include_marketplace=True, + ) + + # Shared Agent from marketplace should be excluded by graph_id + assert len(result) == 2 + names = [a["name"] for a in result] + assert "Shared Agent" in names + assert "Unique Agent" in names + + @pytest.mark.asyncio + async def test_skips_marketplace_when_disabled(self): + """Test that marketplace is not searched when include_marketplace=False.""" + library_agents = [ + { + "graph_id": "lib-123", + "graph_version": 1, + "name": "Library Agent", + "description": "From library", + "input_schema": {}, + "output_schema": {}, + } + ] + + with patch.object( + core, + "get_library_agents_for_generation", + new_callable=AsyncMock, + return_value=library_agents, + ): + with patch.object( + core, + "search_marketplace_agents_for_generation", + new_callable=AsyncMock, + ) as mock_marketplace: + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query="test", + include_marketplace=False, + ) + + # Marketplace should not be called + mock_marketplace.assert_not_called() + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_skips_marketplace_when_no_search_query(self): + """Test that marketplace is not searched without a search query.""" + library_agents = [ + { + "graph_id": "lib-123", + "graph_version": 1, + "name": "Library Agent", + "description": "From library", + "input_schema": {}, + "output_schema": {}, + } + ] + + with patch.object( + core, + "get_library_agents_for_generation", + new_callable=AsyncMock, + return_value=library_agents, + ): + with patch.object( + core, + "search_marketplace_agents_for_generation", + new_callable=AsyncMock, + ) as mock_marketplace: + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query=None, # No search query + include_marketplace=True, + ) + + # Marketplace should not be called without search query + mock_marketplace.assert_not_called() + assert len(result) == 1 + + +class TestExtractSearchTermsFromSteps: + """Test extract_search_terms_from_steps function.""" + + def test_extracts_terms_from_instructions_type(self): + """Test extraction from valid instructions decomposition result.""" + decomposition_result = { + "type": "instructions", + "steps": [ + { + "description": "Send an email notification", + "block_name": "GmailSendBlock", + }, + {"description": "Fetch weather data", "action": "Get weather API"}, + ], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + assert "Send an email notification" in result + assert "GmailSendBlock" in result + assert "Fetch weather data" in result + assert "Get weather API" in result + + def test_returns_empty_for_non_instructions_type(self): + """Test that non-instructions types return empty list.""" + decomposition_result = { + "type": "clarifying_questions", + "questions": [{"question": "What email?"}], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + assert result == [] + + def test_deduplicates_terms_case_insensitively(self): + """Test that duplicate terms are removed (case-insensitive).""" + decomposition_result = { + "type": "instructions", + "steps": [ + {"description": "Send Email", "name": "send email"}, + {"description": "Other task"}, + ], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + # Should only have one "send email" variant + email_terms = [t for t in result if "email" in t.lower()] + assert len(email_terms) == 1 + + def test_filters_short_terms(self): + """Test that terms with 3 or fewer characters are filtered out.""" + decomposition_result = { + "type": "instructions", + "steps": [ + {"description": "ab", "action": "xyz"}, # Both too short + {"description": "Valid term here"}, + ], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + assert "ab" not in result + assert "xyz" not in result + assert "Valid term here" in result + + def test_handles_empty_steps(self): + """Test handling of empty steps list.""" + decomposition_result = { + "type": "instructions", + "steps": [], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + assert result == [] + + +class TestEnrichLibraryAgentsFromSteps: + """Test enrich_library_agents_from_steps function.""" + + @pytest.mark.asyncio + async def test_enriches_with_additional_agents(self): + """Test that additional agents are found based on steps.""" + existing_agents = [ + { + "graph_id": "existing-123", + "graph_version": 1, + "name": "Existing Agent", + "description": "Already fetched", + "input_schema": {}, + "output_schema": {}, + } + ] + + additional_agents = [ + { + "graph_id": "new-456", + "graph_version": 1, + "name": "Email Agent", + "description": "For sending emails", + "input_schema": {}, + "output_schema": {}, + } + ] + + decomposition_result = { + "type": "instructions", + "steps": [ + {"description": "Send email notification"}, + ], + } + + with patch.object( + core, + "get_all_relevant_agents_for_generation", + new_callable=AsyncMock, + return_value=additional_agents, + ): + result = await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + existing_agents=existing_agents, + ) + + # Should have both existing and new agents + assert len(result) == 2 + names = [a["name"] for a in result] + assert "Existing Agent" in names + assert "Email Agent" in names + + @pytest.mark.asyncio + async def test_deduplicates_by_graph_id(self): + """Test that agents with same graph_id are not duplicated.""" + existing_agents = [ + { + "graph_id": "agent-123", + "graph_version": 1, + "name": "Existing Agent", + "description": "Already fetched", + "input_schema": {}, + "output_schema": {}, + } + ] + + # Additional search returns same agent + additional_agents = [ + { + "graph_id": "agent-123", # Same ID + "graph_version": 1, + "name": "Existing Agent Copy", + "description": "Same agent different name", + "input_schema": {}, + "output_schema": {}, + } + ] + + decomposition_result = { + "type": "instructions", + "steps": [{"description": "Some action"}], + } + + with patch.object( + core, + "get_all_relevant_agents_for_generation", + new_callable=AsyncMock, + return_value=additional_agents, + ): + result = await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + existing_agents=existing_agents, + ) + + # Should not duplicate + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_deduplicates_by_name(self): + """Test that agents with same name are not duplicated.""" + existing_agents = [ + { + "graph_id": "agent-123", + "graph_version": 1, + "name": "Email Agent", + "description": "Already fetched", + "input_schema": {}, + "output_schema": {}, + } + ] + + # Additional search returns agent with same name but different ID + additional_agents = [ + { + "graph_id": "agent-456", # Different ID + "graph_version": 1, + "name": "Email Agent", # Same name + "description": "Different agent same name", + "input_schema": {}, + "output_schema": {}, + } + ] + + decomposition_result = { + "type": "instructions", + "steps": [{"description": "Send email"}], + } + + with patch.object( + core, + "get_all_relevant_agents_for_generation", + new_callable=AsyncMock, + return_value=additional_agents, + ): + result = await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + existing_agents=existing_agents, + ) + + # Should not duplicate by name + assert len(result) == 1 + assert result[0].get("graph_id") == "agent-123" # Original kept + + @pytest.mark.asyncio + async def test_returns_existing_when_no_steps(self): + """Test that existing agents are returned when no search terms extracted.""" + existing_agents = [ + { + "graph_id": "existing-123", + "graph_version": 1, + "name": "Existing Agent", + "description": "Already fetched", + "input_schema": {}, + "output_schema": {}, + } + ] + + decomposition_result = { + "type": "clarifying_questions", # Not instructions type + "questions": [], + } + + result = await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + existing_agents=existing_agents, + ) + + # Should return existing unchanged + assert result == existing_agents + + @pytest.mark.asyncio + async def test_limits_search_terms_to_three(self): + """Test that only first 3 search terms are used.""" + existing_agents = [] + + decomposition_result = { + "type": "instructions", + "steps": [ + {"description": "First action"}, + {"description": "Second action"}, + {"description": "Third action"}, + {"description": "Fourth action"}, + {"description": "Fifth action"}, + ], + } + + call_count = 0 + + async def mock_get_agents(*args, **kwargs): + nonlocal call_count + call_count += 1 + return [] + + with patch.object( + core, + "get_all_relevant_agents_for_generation", + side_effect=mock_get_agents, + ): + await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + existing_agents=existing_agents, + ) + + # Should only make 3 calls (limited to first 3 terms) + assert call_count == 3 + + +class TestExtractUuidsFromText: + """Test extract_uuids_from_text function.""" + + def test_extracts_single_uuid(self): + """Test extraction of a single UUID from text.""" + text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task" + result = core.extract_uuids_from_text(text) + assert len(result) == 1 + assert "46631191-e8a8-486f-ad90-84f89738321d" in result + + def test_extracts_multiple_uuids(self): + """Test extraction of multiple UUIDs from text.""" + text = ( + "Combine agents 11111111-1111-4111-8111-111111111111 " + "and 22222222-2222-4222-9222-222222222222" + ) + result = core.extract_uuids_from_text(text) + assert len(result) == 2 + assert "11111111-1111-4111-8111-111111111111" in result + assert "22222222-2222-4222-9222-222222222222" in result + + def test_deduplicates_uuids(self): + """Test that duplicate UUIDs are deduplicated.""" + text = ( + "Use 46631191-e8a8-486f-ad90-84f89738321d twice: " + "46631191-e8a8-486f-ad90-84f89738321d" + ) + result = core.extract_uuids_from_text(text) + assert len(result) == 1 + + def test_normalizes_to_lowercase(self): + """Test that UUIDs are normalized to lowercase.""" + text = "Use 46631191-E8A8-486F-AD90-84F89738321D" + result = core.extract_uuids_from_text(text) + assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d" + + def test_returns_empty_for_no_uuids(self): + """Test that empty list is returned when no UUIDs found.""" + text = "Create an email agent that sends notifications" + result = core.extract_uuids_from_text(text) + assert result == [] + + def test_ignores_invalid_uuids(self): + """Test that invalid UUID-like strings are ignored.""" + text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc" + result = core.extract_uuids_from_text(text) + # UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth) + assert len(result) == 0 + + +class TestGetLibraryAgentById: + """Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id).""" + + @pytest.mark.asyncio + async def test_returns_agent_when_found_by_graph_id(self): + """Test that agent is returned when found by graph_id.""" + mock_agent = MagicMock() + mock_agent.graph_id = "agent-123" + mock_agent.graph_version = 1 + mock_agent.name = "Test Agent" + mock_agent.description = "Test description" + mock_agent.input_schema = {"properties": {}} + mock_agent.output_schema = {"properties": {}} + + with patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=mock_agent, + ): + result = await core.get_library_agent_by_id("user-123", "agent-123") + + assert result is not None + assert result["graph_id"] == "agent-123" + assert result["name"] == "Test Agent" + + @pytest.mark.asyncio + async def test_falls_back_to_library_agent_id(self): + """Test that lookup falls back to library agent ID when graph_id not found.""" + mock_agent = MagicMock() + mock_agent.graph_id = "graph-456" # Different from the lookup ID + mock_agent.graph_version = 1 + mock_agent.name = "Library Agent" + mock_agent.description = "Found by library ID" + mock_agent.input_schema = {"properties": {}} + mock_agent.output_schema = {"properties": {}} + + with ( + patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=None, # Not found by graph_id + ), + patch.object( + core.library_db, + "get_library_agent", + new_callable=AsyncMock, + return_value=mock_agent, # Found by library ID + ), + ): + result = await core.get_library_agent_by_id("user-123", "library-id-123") + + assert result is not None + assert result["graph_id"] == "graph-456" + assert result["name"] == "Library Agent" + + @pytest.mark.asyncio + async def test_returns_none_when_not_found_by_either_method(self): + """Test that None is returned when agent not found by either method.""" + with ( + patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=None, + ), + patch.object( + core.library_db, + "get_library_agent", + new_callable=AsyncMock, + side_effect=core.NotFoundError("Not found"), + ), + ): + result = await core.get_library_agent_by_id("user-123", "nonexistent") + + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_on_exception(self): + """Test that None is returned when exception occurs in both lookups.""" + with ( + patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + side_effect=Exception("Database error"), + ), + patch.object( + core.library_db, + "get_library_agent", + new_callable=AsyncMock, + side_effect=Exception("Database error"), + ), + ): + result = await core.get_library_agent_by_id("user-123", "agent-123") + + assert result is None + + @pytest.mark.asyncio + async def test_alias_works(self): + """Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id.""" + assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id + + +class TestGetAllRelevantAgentsWithUuids: + """Test UUID extraction in get_all_relevant_agents_for_generation.""" + + @pytest.mark.asyncio + async def test_fetches_explicitly_mentioned_agents(self): + """Test that agents mentioned by UUID are fetched directly.""" + mock_agent = MagicMock() + mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d" + mock_agent.graph_version = 1 + mock_agent.name = "Mentioned Agent" + mock_agent.description = "Explicitly mentioned" + mock_agent.input_schema = {} + mock_agent.output_schema = {} + + mock_response = MagicMock() + mock_response.agents = [] + + with ( + patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=mock_agent, + ), + patch.object( + core.library_db, + "list_library_agents", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d", + include_marketplace=False, + ) + + assert len(result) == 1 + assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d" + assert result[0].get("name") == "Mentioned Agent" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/autogpt_platform/backend/test/agent_generator/test_service.py b/autogpt_platform/backend/test/agent_generator/test_service.py index 81ff794532..cc37c428c0 100644 --- a/autogpt_platform/backend/test/agent_generator/test_service.py +++ b/autogpt_platform/backend/test/agent_generator/test_service.py @@ -102,7 +102,7 @@ class TestDecomposeGoalExternal: @pytest.mark.asyncio async def test_decompose_goal_with_context(self): - """Test decomposition with additional context.""" + """Test decomposition with additional context enriched into description.""" mock_response = MagicMock() mock_response.json.return_value = { "success": True, @@ -119,9 +119,12 @@ class TestDecomposeGoalExternal: "Build a chatbot", context="Use Python" ) + expected_description = ( + "Build a chatbot\n\nAdditional context from user:\nUse Python" + ) mock_client.post.assert_called_once_with( "/api/decompose-description", - json={"description": "Build a chatbot", "user_instruction": "Use Python"}, + json={"description": expected_description}, ) @pytest.mark.asyncio @@ -151,15 +154,20 @@ class TestDecomposeGoalExternal: @pytest.mark.asyncio async def test_decompose_goal_handles_http_error(self): """Test decomposition handles HTTP errors gracefully.""" + mock_response = MagicMock() + mock_response.status_code = 500 mock_client = AsyncMock() mock_client.post.side_effect = httpx.HTTPStatusError( - "Server error", request=MagicMock(), response=MagicMock() + "Server error", request=MagicMock(), response=mock_response ) with patch.object(service, "_get_client", return_value=mock_client): result = await service.decompose_goal_external("Build a chatbot") - assert result is None + assert result is not None + assert result.get("type") == "error" + assert result.get("error_type") == "http_error" + assert "Server error" in result.get("error", "") @pytest.mark.asyncio async def test_decompose_goal_handles_request_error(self): @@ -170,7 +178,10 @@ class TestDecomposeGoalExternal: with patch.object(service, "_get_client", return_value=mock_client): result = await service.decompose_goal_external("Build a chatbot") - assert result is None + assert result is not None + assert result.get("type") == "error" + assert result.get("error_type") == "connection_error" + assert "Connection failed" in result.get("error", "") @pytest.mark.asyncio async def test_decompose_goal_handles_service_error(self): @@ -179,6 +190,7 @@ class TestDecomposeGoalExternal: mock_response.json.return_value = { "success": False, "error": "Internal error", + "error_type": "internal_error", } mock_response.raise_for_status = MagicMock() @@ -188,7 +200,10 @@ class TestDecomposeGoalExternal: with patch.object(service, "_get_client", return_value=mock_client): result = await service.decompose_goal_external("Build a chatbot") - assert result is None + assert result is not None + assert result.get("type") == "error" + assert result.get("error") == "Internal error" + assert result.get("error_type") == "internal_error" class TestGenerateAgentExternal: @@ -236,7 +251,10 @@ class TestGenerateAgentExternal: with patch.object(service, "_get_client", return_value=mock_client): result = await service.generate_agent_external({"steps": []}) - assert result is None + assert result is not None + assert result.get("type") == "error" + assert result.get("error_type") == "connection_error" + assert "Connection failed" in result.get("error", "") class TestGenerateAgentPatchExternal: @@ -418,5 +436,139 @@ class TestGetBlocksExternal: assert result is None +class TestLibraryAgentsPassthrough: + """Test that library_agents are passed correctly in all requests.""" + + def setup_method(self): + """Reset client singleton before each test.""" + service._settings = None + service._client = None + + @pytest.mark.asyncio + async def test_decompose_goal_passes_library_agents(self): + """Test that library_agents are included in decompose goal payload.""" + library_agents = [ + { + "graph_id": "agent-123", + "graph_version": 1, + "name": "Email Sender", + "description": "Sends emails", + "input_schema": {"properties": {"to": {"type": "string"}}}, + "output_schema": {"properties": {"sent": {"type": "boolean"}}}, + }, + ] + + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "type": "instructions", + "steps": ["Step 1"], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + await service.decompose_goal_external( + "Send an email", + library_agents=library_agents, + ) + + # Verify library_agents was passed in the payload + call_args = mock_client.post.call_args + assert call_args[1]["json"]["library_agents"] == library_agents + + @pytest.mark.asyncio + async def test_generate_agent_passes_library_agents(self): + """Test that library_agents are included in generate agent payload.""" + library_agents = [ + { + "graph_id": "agent-456", + "graph_version": 2, + "name": "Data Fetcher", + "description": "Fetches data from API", + "input_schema": {"properties": {"url": {"type": "string"}}}, + "output_schema": {"properties": {"data": {"type": "object"}}}, + }, + ] + + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "agent_json": {"name": "Test Agent", "nodes": []}, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + await service.generate_agent_external( + {"steps": ["Step 1"]}, + library_agents=library_agents, + ) + + # Verify library_agents was passed in the payload + call_args = mock_client.post.call_args + assert call_args[1]["json"]["library_agents"] == library_agents + + @pytest.mark.asyncio + async def test_generate_agent_patch_passes_library_agents(self): + """Test that library_agents are included in patch generation payload.""" + library_agents = [ + { + "graph_id": "agent-789", + "graph_version": 1, + "name": "Slack Notifier", + "description": "Sends Slack messages", + "input_schema": {"properties": {"message": {"type": "string"}}}, + "output_schema": {"properties": {"success": {"type": "boolean"}}}, + }, + ] + + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "agent_json": {"name": "Updated Agent", "nodes": []}, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + await service.generate_agent_patch_external( + "Add error handling", + {"name": "Original Agent", "nodes": []}, + library_agents=library_agents, + ) + + # Verify library_agents was passed in the payload + call_args = mock_client.post.call_args + assert call_args[1]["json"]["library_agents"] == library_agents + + @pytest.mark.asyncio + async def test_decompose_goal_without_library_agents(self): + """Test that decompose goal works without library_agents.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "type": "instructions", + "steps": ["Step 1"], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + await service.decompose_goal_external("Build a workflow") + + # Verify library_agents was NOT passed when not provided + call_args = mock_client.post.call_args + assert "library_agents" not in call_args[1]["json"] + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/autogpt_platform/backend/test/e2e_test_data.py b/autogpt_platform/backend/test/e2e_test_data.py index d7576cdad3..7288197a90 100644 --- a/autogpt_platform/backend/test/e2e_test_data.py +++ b/autogpt_platform/backend/test/e2e_test_data.py @@ -43,19 +43,24 @@ faker = Faker() # Constants for data generation limits (reduced for E2E tests) NUM_USERS = 15 NUM_AGENT_BLOCKS = 30 -MIN_GRAPHS_PER_USER = 15 -MAX_GRAPHS_PER_USER = 15 +MIN_GRAPHS_PER_USER = 25 +MAX_GRAPHS_PER_USER = 25 MIN_NODES_PER_GRAPH = 3 MAX_NODES_PER_GRAPH = 6 MIN_PRESETS_PER_USER = 2 MAX_PRESETS_PER_USER = 3 -MIN_AGENTS_PER_USER = 15 -MAX_AGENTS_PER_USER = 15 +MIN_AGENTS_PER_USER = 25 +MAX_AGENTS_PER_USER = 25 MIN_EXECUTIONS_PER_GRAPH = 2 MAX_EXECUTIONS_PER_GRAPH = 8 MIN_REVIEWS_PER_VERSION = 2 MAX_REVIEWS_PER_VERSION = 5 +# Guaranteed minimums for marketplace tests (deterministic) +GUARANTEED_FEATURED_AGENTS = 8 +GUARANTEED_FEATURED_CREATORS = 5 +GUARANTEED_TOP_AGENTS = 10 + def get_image(): """Generate a consistent image URL using picsum.photos service.""" @@ -385,7 +390,7 @@ class TestDataCreator: library_agents = [] for user in self.users: - num_agents = 10 # Create exactly 10 agents per user + num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER) # Get available graphs for this user user_graphs = [ @@ -507,14 +512,17 @@ class TestDataCreator: existing_profiles, min(num_creators, len(existing_profiles)) ) - # Mark about 50% of creators as featured (more for testing) - num_featured = max(2, int(num_creators * 0.5)) + # Guarantee at least GUARANTEED_FEATURED_CREATORS featured creators + num_featured = max(GUARANTEED_FEATURED_CREATORS, int(num_creators * 0.5)) num_featured = min( num_featured, len(selected_profiles) ) # Don't exceed available profiles featured_profile_ids = set( random.sample([p.id for p in selected_profiles], num_featured) ) + print( + f"🎯 Creating {num_featured} featured creators (min: {GUARANTEED_FEATURED_CREATORS})" + ) for profile in selected_profiles: try: @@ -545,21 +553,25 @@ class TestDataCreator: return profiles async def create_test_store_submissions(self) -> List[Dict[str, Any]]: - """Create test store submissions using the API function.""" + """Create test store submissions using the API function. + + DETERMINISTIC: Guarantees minimum featured agents for E2E tests. + """ print("Creating test store submissions...") submissions = [] approved_submissions = [] + featured_count = 0 + submission_counter = 0 - # Create a special test submission for test123@gmail.com + # Create a special test submission for test123@gmail.com (ALWAYS approved + featured) test_user = next( (user for user in self.users if user["email"] == "test123@gmail.com"), None ) - if test_user: - # Special test data for consistent testing + if test_user and self.agent_graphs: test_submission_data = { "user_id": test_user["id"], - "agent_id": self.agent_graphs[0]["id"], # Use first available graph + "agent_id": self.agent_graphs[0]["id"], "agent_version": 1, "slug": "test-agent-submission", "name": "Test Agent Submission", @@ -580,37 +592,24 @@ class TestDataCreator: submissions.append(test_submission.model_dump()) print("✅ Created special test store submission for test123@gmail.com") - # Randomly approve, reject, or leave pending the test submission + # ALWAYS approve and feature the test submission if test_submission.store_listing_version_id: - random_value = random.random() - if random_value < 0.4: # 40% chance to approve - approved_submission = await review_store_submission( - store_listing_version_id=test_submission.store_listing_version_id, - is_approved=True, - external_comments="Test submission approved", - internal_comments="Auto-approved test submission", - reviewer_id=test_user["id"], - ) - approved_submissions.append(approved_submission.model_dump()) - print("✅ Approved test store submission") + approved_submission = await review_store_submission( + store_listing_version_id=test_submission.store_listing_version_id, + is_approved=True, + external_comments="Test submission approved", + internal_comments="Auto-approved test submission", + reviewer_id=test_user["id"], + ) + approved_submissions.append(approved_submission.model_dump()) + print("✅ Approved test store submission") - # Mark approved submission as featured - await prisma.storelistingversion.update( - where={"id": test_submission.store_listing_version_id}, - data={"isFeatured": True}, - ) - print("🌟 Marked test agent as FEATURED") - elif random_value < 0.7: # 30% chance to reject (40% to 70%) - await review_store_submission( - store_listing_version_id=test_submission.store_listing_version_id, - is_approved=False, - external_comments="Test submission rejected - needs improvements", - internal_comments="Auto-rejected test submission for E2E testing", - reviewer_id=test_user["id"], - ) - print("❌ Rejected test store submission") - else: # 30% chance to leave pending (70% to 100%) - print("⏳ Left test submission pending for review") + await prisma.storelistingversion.update( + where={"id": test_submission.store_listing_version_id}, + data={"isFeatured": True}, + ) + featured_count += 1 + print("🌟 Marked test agent as FEATURED") except Exception as e: print(f"Error creating test store submission: {e}") @@ -620,7 +619,6 @@ class TestDataCreator: # Create regular submissions for all users for user in self.users: - # Get available graphs for this specific user user_graphs = [ g for g in self.agent_graphs if g.get("userId") == user["id"] ] @@ -631,18 +629,17 @@ class TestDataCreator: ) continue - # Create exactly 4 store submissions per user for submission_index in range(4): graph = random.choice(user_graphs) + submission_counter += 1 try: print( - f"Creating store submission for user {user['id']} with graph {graph['id']} (owner: {graph.get('userId')})" + f"Creating store submission for user {user['id']} with graph {graph['id']}" ) - # Use the API function to create store submission with correct parameters submission = await create_store_submission( - user_id=user["id"], # Must match graph's userId + user_id=user["id"], agent_id=graph["id"], agent_version=graph.get("version", 1), slug=faker.slug(), @@ -651,22 +648,24 @@ class TestDataCreator: video_url=get_video_url() if random.random() < 0.3 else None, image_urls=[get_image() for _ in range(3)], description=faker.text(), - categories=[ - get_category() - ], # Single category from predefined list + categories=[get_category()], changes_summary="Initial E2E test submission", ) submissions.append(submission.model_dump()) print(f"✅ Created store submission: {submission.name}") - # Randomly approve, reject, or leave pending the submission if submission.store_listing_version_id: - random_value = random.random() - if random_value < 0.4: # 40% chance to approve - try: - # Pick a random user as the reviewer (admin) - reviewer_id = random.choice(self.users)["id"] + # DETERMINISTIC: First N submissions are always approved + # First GUARANTEED_FEATURED_AGENTS of those are always featured + should_approve = ( + submission_counter <= GUARANTEED_TOP_AGENTS + or random.random() < 0.4 + ) + should_feature = featured_count < GUARANTEED_FEATURED_AGENTS + if should_approve: + try: + reviewer_id = random.choice(self.users)["id"] approved_submission = await review_store_submission( store_listing_version_id=submission.store_listing_version_id, is_approved=True, @@ -681,16 +680,7 @@ class TestDataCreator: f"✅ Approved store submission: {submission.name}" ) - # Mark some agents as featured during creation (30% chance) - # More likely for creators and first submissions - is_creator = user["id"] in [ - p.get("userId") for p in self.profiles - ] - feature_chance = ( - 0.5 if is_creator else 0.2 - ) # 50% for creators, 20% for others - - if random.random() < feature_chance: + if should_feature: try: await prisma.storelistingversion.update( where={ @@ -698,8 +688,25 @@ class TestDataCreator: }, data={"isFeatured": True}, ) + featured_count += 1 print( - f"🌟 Marked agent as FEATURED: {submission.name}" + f"🌟 Marked agent as FEATURED ({featured_count}/{GUARANTEED_FEATURED_AGENTS}): {submission.name}" + ) + except Exception as e: + print( + f"Warning: Could not mark submission as featured: {e}" + ) + elif random.random() < 0.2: + try: + await prisma.storelistingversion.update( + where={ + "id": submission.store_listing_version_id + }, + data={"isFeatured": True}, + ) + featured_count += 1 + print( + f"🌟 Marked agent as FEATURED (bonus): {submission.name}" ) except Exception as e: print( @@ -710,11 +717,9 @@ class TestDataCreator: print( f"Warning: Could not approve submission {submission.name}: {e}" ) - elif random_value < 0.7: # 30% chance to reject (40% to 70%) + elif random.random() < 0.5: try: - # Pick a random user as the reviewer (admin) reviewer_id = random.choice(self.users)["id"] - await review_store_submission( store_listing_version_id=submission.store_listing_version_id, is_approved=False, @@ -729,7 +734,7 @@ class TestDataCreator: print( f"Warning: Could not reject submission {submission.name}: {e}" ) - else: # 30% chance to leave pending (70% to 100%) + else: print( f"⏳ Left submission pending for review: {submission.name}" ) @@ -743,9 +748,13 @@ class TestDataCreator: traceback.print_exc() continue + print("\n📊 Store Submissions Summary:") + print(f" Created: {len(submissions)}") + print(f" Approved: {len(approved_submissions)}") print( - f"Created {len(submissions)} store submissions, approved {len(approved_submissions)}" + f" Featured: {featured_count} (guaranteed min: {GUARANTEED_FEATURED_AGENTS})" ) + self.store_submissions = submissions return submissions @@ -825,12 +834,15 @@ class TestDataCreator: print(f"✅ Agent blocks available: {len(self.agent_blocks)}") print(f"✅ Agent graphs created: {len(self.agent_graphs)}") print(f"✅ Library agents created: {len(self.library_agents)}") - print(f"✅ Creator profiles updated: {len(self.profiles)} (some featured)") - print( - f"✅ Store submissions created: {len(self.store_submissions)} (some marked as featured during creation)" - ) + print(f"✅ Creator profiles updated: {len(self.profiles)}") + print(f"✅ Store submissions created: {len(self.store_submissions)}") print(f"✅ API keys created: {len(self.api_keys)}") print(f"✅ Presets created: {len(self.presets)}") + print("\n🎯 Deterministic Guarantees:") + print(f" • Featured agents: >= {GUARANTEED_FEATURED_AGENTS}") + print(f" • Featured creators: >= {GUARANTEED_FEATURED_CREATORS}") + print(f" • Top agents (approved): >= {GUARANTEED_TOP_AGENTS}") + print(f" • Library agents per user: >= {MIN_AGENTS_PER_USER}") print("\n🚀 Your E2E test database is ready to use!") diff --git a/autogpt_platform/frontend/.env.default b/autogpt_platform/frontend/.env.default index af250fb8bf..7a9d81e39e 100644 --- a/autogpt_platform/frontend/.env.default +++ b/autogpt_platform/frontend/.env.default @@ -34,3 +34,6 @@ NEXT_PUBLIC_PREVIEW_STEALING_DEV= # PostHog Analytics NEXT_PUBLIC_POSTHOG_KEY= NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com + +# OpenAI (for voice transcription) +OPENAI_API_KEY= diff --git a/autogpt_platform/frontend/CLAUDE.md b/autogpt_platform/frontend/CLAUDE.md new file mode 100644 index 0000000000..b58f1ad6aa --- /dev/null +++ b/autogpt_platform/frontend/CLAUDE.md @@ -0,0 +1,76 @@ +# CLAUDE.md - Frontend + +This file provides guidance to Claude Code when working with the frontend. + +## Essential Commands + +```bash +# Install dependencies +pnpm i + +# Generate API client from OpenAPI spec +pnpm generate:api + +# Start development server +pnpm dev + +# Run E2E tests +pnpm test + +# Run Storybook for component development +pnpm storybook + +# Build production +pnpm build + +# Format and lint +pnpm format + +# Type checking +pnpm types +``` + +### Code Style + +- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI` +- Use function declarations (not arrow functions) for components/handlers + +## Architecture + +- **Framework**: Next.js 15 App Router (client-first approach) +- **Data Fetching**: Type-safe generated API hooks via Orval + React Query +- **State Management**: React Query for server state, co-located UI state in components/hooks +- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks) +- **Workflow Builder**: Visual graph editor using @xyflow/react +- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling +- **Icons**: Phosphor Icons only +- **Feature Flags**: LaunchDarkly integration +- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions +- **Testing**: Playwright for E2E, Storybook for component development + +## Environment Configuration + +`.env.default` (defaults) → `.env` (user overrides) + +## Feature Development + +See @CONTRIBUTING.md for complete patterns. Quick reference: + +1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx` + - Extract component logic into custom hooks grouped by concern, not by component. Each hook should represent a cohesive domain of functionality (e.g., useSearch, useFilters, usePagination) rather than bundling all state into one useComponentState hook. + - Put each hook in its own `.ts` file + - Put sub-components in local `components/` folder + - Component props should be `type Props = { ... }` (not exported) unless it needs to be used outside the component +2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts` + - Use design system components from `src/components/` (atoms, molecules, organisms) + - Never use `src/components/__legacy__/*` +3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/` + - Regenerate with `pnpm generate:api` + - Pattern: `use{Method}{Version}{OperationName}` +4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only +5. **Testing**: Add Storybook stories for new components, Playwright for E2E +6. **Code conventions**: + - Use function declarations (not arrow functions) for components/handlers + - Do not use `useCallback` or `useMemo` unless asked to optimise a given function + - Do not type hook returns, let Typescript infer as much as possible + - Never type with `any` unless a variable/attribute can ACTUALLY be of any type diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx index 70d9783ccd..246fe52826 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx +++ b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx @@ -1,10 +1,9 @@ "use client"; +import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding"; +import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers"; import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { useRouter } from "next/navigation"; import { useEffect } from "react"; -import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers"; -import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding"; -import { getHomepageRoute } from "@/lib/constants"; export default function OnboardingPage() { const router = useRouter(); @@ -13,12 +12,10 @@ export default function OnboardingPage() { async function redirectToStep() { try { // Check if onboarding is enabled (also gets chat flag for redirect) - const { shouldShowOnboarding, isChatEnabled } = - await getOnboardingStatus(); - const homepageRoute = getHomepageRoute(isChatEnabled); + const { shouldShowOnboarding } = await getOnboardingStatus(); if (!shouldShowOnboarding) { - router.replace(homepageRoute); + router.replace("/"); return; } @@ -26,7 +23,7 @@ export default function OnboardingPage() { // Handle completed onboarding if (onboarding.completedSteps.includes("GET_RESULTS")) { - router.replace(homepageRoute); + router.replace("/"); return; } diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts index 15be137f63..e7e2997d0d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts +++ b/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts @@ -1,9 +1,8 @@ -import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase"; -import { getHomepageRoute } from "@/lib/constants"; -import BackendAPI from "@/lib/autogpt-server-api"; -import { NextResponse } from "next/server"; -import { revalidatePath } from "next/cache"; import { getOnboardingStatus } from "@/app/api/helpers"; +import BackendAPI from "@/lib/autogpt-server-api"; +import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase"; +import { revalidatePath } from "next/cache"; +import { NextResponse } from "next/server"; // Handle the callback to complete the user session login export async function GET(request: Request) { @@ -27,13 +26,12 @@ export async function GET(request: Request) { await api.createUser(); // Get onboarding status from backend (includes chat flag evaluated for this user) - const { shouldShowOnboarding, isChatEnabled } = - await getOnboardingStatus(); + const { shouldShowOnboarding } = await getOnboardingStatus(); if (shouldShowOnboarding) { next = "/onboarding"; revalidatePath("/onboarding", "layout"); } else { - next = getHomepageRoute(isChatEnabled); + next = "/"; revalidatePath(next, "layout"); } } catch (createUserError) { diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts index 41d05a9afb..fd67519957 100644 --- a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts +++ b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts @@ -1,6 +1,17 @@ import { OAuthPopupResultMessage } from "./types"; import { NextResponse } from "next/server"; +/** + * Safely encode a value as JSON for embedding in a script tag. + * Escapes characters that could break out of the script context to prevent XSS. + */ +function safeJsonStringify(value: unknown): string { + return JSON.stringify(value) + .replace(//g, "\\u003e") + .replace(/&/g, "\\u0026"); +} + // This route is intended to be used as the callback for integration OAuth flows, // controlled by the CredentialsInput component. The CredentialsInput opens the login // page in a pop-up window, which then redirects to this route to close the loop. @@ -23,12 +34,13 @@ export async function GET(request: Request) { console.debug("Sending message to opener:", message); // Return a response with the message as JSON and a script to close the window + // Use safeJsonStringify to prevent XSS by escaping <, >, and & characters return new NextResponse( ` diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode.tsx index 94e917a4ac..834603cc4a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode.tsx @@ -857,7 +857,7 @@ export const CustomNode = React.memo( })(); const hasAdvancedFields = - data.inputSchema && + data.inputSchema?.properties && Object.entries(data.inputSchema.properties).some(([key, value]) => { return ( value.advanced === true && !data.inputSchema.required?.includes(key) diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/SessionsList/useSessionsPagination.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/SessionsList/useSessionsPagination.ts index 11ddd937af..61e3e6f37f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/SessionsList/useSessionsPagination.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/components/SessionsList/useSessionsPagination.ts @@ -73,9 +73,9 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) { }; const reset = () => { + // Only reset the offset - keep existing sessions visible during refetch + // The effect will replace sessions when new data arrives at offset 0 setOffset(0); - setAccumulatedSessions([]); - setTotalCount(null); }; return { diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts index 74fd663ab2..913c4d7ded 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts @@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useQueryClient } from "@tanstack/react-query"; import { usePathname, useSearchParams } from "next/navigation"; -import { useRef } from "react"; import { useCopilotStore } from "../../copilot-page-store"; import { useCopilotSessionId } from "../../useCopilotSessionId"; import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer"; @@ -70,41 +69,16 @@ export function useCopilotShell() { }); const stopStream = useChatStore((s) => s.stopStream); - const onStreamComplete = useChatStore((s) => s.onStreamComplete); - const isStreaming = useCopilotStore((s) => s.isStreaming); const isCreatingSession = useCopilotStore((s) => s.isCreatingSession); - const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession); - const openInterruptModal = useCopilotStore((s) => s.openInterruptModal); - const pendingActionRef = useRef<(() => void) | null>(null); - - async function stopCurrentStream() { - if (!currentSessionId) return; - - setIsSwitchingSession(true); - await new Promise((resolve) => { - const unsubscribe = onStreamComplete((completedId) => { - if (completedId === currentSessionId) { - clearTimeout(timeout); - unsubscribe(); - resolve(); - } - }); - const timeout = setTimeout(() => { - unsubscribe(); - resolve(); - }, 3000); - stopStream(currentSessionId); - }); - - queryClient.invalidateQueries({ - queryKey: getGetV2GetSessionQueryKey(currentSessionId), - }); - setIsSwitchingSession(false); - } - - function selectSession(sessionId: string) { + function handleSessionClick(sessionId: string) { if (sessionId === currentSessionId) return; + + // Stop current stream - SSE reconnection allows resuming later + if (currentSessionId) { + stopStream(currentSessionId); + } + if (recentlyCreatedSessionsRef.current.has(sessionId)) { queryClient.invalidateQueries({ queryKey: getGetV2GetSessionQueryKey(sessionId), @@ -114,7 +88,12 @@ export function useCopilotShell() { if (isMobile) handleCloseDrawer(); } - function startNewChat() { + function handleNewChatClick() { + // Stop current stream - SSE reconnection allows resuming later + if (currentSessionId) { + stopStream(currentSessionId); + } + resetPagination(); queryClient.invalidateQueries({ queryKey: getGetV2ListSessionsQueryKey(), @@ -123,32 +102,6 @@ export function useCopilotShell() { if (isMobile) handleCloseDrawer(); } - function handleSessionClick(sessionId: string) { - if (sessionId === currentSessionId) return; - - if (isStreaming) { - pendingActionRef.current = async () => { - await stopCurrentStream(); - selectSession(sessionId); - }; - openInterruptModal(pendingActionRef.current); - } else { - selectSession(sessionId); - } - } - - function handleNewChatClick() { - if (isStreaming) { - pendingActionRef.current = async () => { - await stopCurrentStream(); - startNewChat(); - }; - openInterruptModal(pendingActionRef.current); - } else { - startNewChat(); - } - } - return { isMobile, isDrawerOpen, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts index 692a5741f4..c6e479f896 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts @@ -26,8 +26,20 @@ export function buildCopilotChatUrl(prompt: string): string { export function getQuickActions(): string[] { return [ - "Show me what I can automate", - "Design a custom workflow", - "Help me with content creation", + "I don't know where to start, just ask me stuff", + "I do the same thing every week and it's killing me", + "Help me find where I'm wasting my time", ]; } + +export function getInputPlaceholder(width?: number) { + if (!width) return "What's your role and what eats up most of your day?"; + + if (width < 500) { + return "I'm a chef and I hate..."; + } + if (width <= 1080) { + return "What's your role and what eats up most of your day?"; + } + return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'"; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx index 89cf72e2ba..876e5accfb 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx @@ -1,6 +1,13 @@ -import type { ReactNode } from "react"; +"use client"; +import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage"; +import { Flag } from "@/services/feature-flags/use-get-flag"; +import { type ReactNode } from "react"; import { CopilotShell } from "./components/CopilotShell/CopilotShell"; export default function CopilotLayout({ children }: { children: ReactNode }) { - return {children}; + return ( + + {children} + + ); } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx index 104b238895..542173a99c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx @@ -6,7 +6,9 @@ import { Text } from "@/components/atoms/Text/Text"; import { Chat } from "@/components/contextual/Chat/Chat"; import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; +import { useEffect, useState } from "react"; import { useCopilotStore } from "./copilot-page-store"; +import { getInputPlaceholder } from "./helpers"; import { useCopilotPage } from "./useCopilotPage"; export default function CopilotPage() { @@ -14,14 +16,25 @@ export default function CopilotPage() { const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen); const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt); const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt); - const { - greetingName, - quickActions, - isLoading, - hasSession, - initialPrompt, - isReady, - } = state; + + const [inputPlaceholder, setInputPlaceholder] = useState( + getInputPlaceholder(), + ); + + useEffect(() => { + const handleResize = () => { + setInputPlaceholder(getInputPlaceholder(window.innerWidth)); + }; + + handleResize(); + + window.addEventListener("resize", handleResize); + return () => window.removeEventListener("resize", handleResize); + }, []); + + const { greetingName, quickActions, isLoading, hasSession, initialPrompt } = + state; + const { handleQuickAction, startChatWithPrompt, @@ -29,8 +42,6 @@ export default function CopilotPage() { handleStreamingChange, } = handlers; - if (!isReady) return null; - if (hasSession) { return (
@@ -81,7 +92,7 @@ export default function CopilotPage() { } return ( -
+
{isLoading ? (
@@ -98,25 +109,25 @@ export default function CopilotPage() {
) : ( <> -
+
Hey, {greetingName} - What do you want to automate? + Tell me about your work — I'll find what to automate.
-
+
{quickActions.map((action) => ( diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts index e4713cd24a..9d99f8e7bd 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts @@ -3,18 +3,11 @@ import { postV2CreateSession, } from "@/app/api/__generated__/endpoints/chat/chat"; import { useToast } from "@/components/molecules/Toast/use-toast"; -import { getHomepageRoute } from "@/lib/constants"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useOnboarding } from "@/providers/onboarding/onboarding-provider"; -import { - Flag, - type FlagValues, - useGetFlag, -} from "@/services/feature-flags/use-get-flag"; import { SessionKey, sessionStorage } from "@/services/storage/session-storage"; import * as Sentry from "@sentry/nextjs"; import { useQueryClient } from "@tanstack/react-query"; -import { useFlags } from "launchdarkly-react-client-sdk"; import { useRouter } from "next/navigation"; import { useEffect } from "react"; import { useCopilotStore } from "./copilot-page-store"; @@ -33,22 +26,6 @@ export function useCopilotPage() { const isCreating = useCopilotStore((s) => s.isCreatingSession); const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession); - // Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus - useEffect(() => { - if (isLoggedIn) { - completeStep("VISIT_COPILOT"); - } - }, [completeStep, isLoggedIn]); - - const isChatEnabled = useGetFlag(Flag.CHAT); - const flags = useFlags(); - const homepageRoute = getHomepageRoute(isChatEnabled); - const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true"; - const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID; - const isLaunchDarklyConfigured = envEnabled && Boolean(clientId); - const isFlagReady = - !isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined; - const greetingName = getGreetingName(user); const quickActions = getQuickActions(); @@ -58,11 +35,8 @@ export function useCopilotPage() { : undefined; useEffect(() => { - if (!isFlagReady) return; - if (isChatEnabled === false) { - router.replace(homepageRoute); - } - }, [homepageRoute, isChatEnabled, isFlagReady, router]); + if (isLoggedIn) completeStep("VISIT_COPILOT"); + }, [completeStep, isLoggedIn]); async function startChatWithPrompt(prompt: string) { if (!prompt?.trim()) return; @@ -116,7 +90,6 @@ export function useCopilotPage() { isLoading: isUserLoading, hasSession, initialPrompt, - isReady: isFlagReady && isChatEnabled !== false && isLoggedIn, }, handlers: { handleQuickAction, diff --git a/autogpt_platform/frontend/src/app/(platform)/error/page.tsx b/autogpt_platform/frontend/src/app/(platform)/error/page.tsx index b26ca4559b..3cf68178ad 100644 --- a/autogpt_platform/frontend/src/app/(platform)/error/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/error/page.tsx @@ -1,8 +1,6 @@ "use client"; import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard"; -import { getHomepageRoute } from "@/lib/constants"; -import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; import { useSearchParams } from "next/navigation"; import { Suspense } from "react"; import { getErrorDetails } from "./helpers"; @@ -11,8 +9,6 @@ function ErrorPageContent() { const searchParams = useSearchParams(); const errorMessage = searchParams.get("message"); const errorDetails = getErrorDetails(errorMessage); - const isChatEnabled = useGetFlag(Flag.CHAT); - const homepageRoute = getHomepageRoute(isChatEnabled); function handleRetry() { // Auth-related errors should redirect to login @@ -30,7 +26,7 @@ function ErrorPageContent() { }, 2000); } else { // For server/network errors, go to home - window.location.href = homepageRoute; + window.location.href = "/"; } } diff --git a/autogpt_platform/frontend/src/app/(platform)/login/actions.ts b/autogpt_platform/frontend/src/app/(platform)/login/actions.ts index 447a25a41d..c4867dd123 100644 --- a/autogpt_platform/frontend/src/app/(platform)/login/actions.ts +++ b/autogpt_platform/frontend/src/app/(platform)/login/actions.ts @@ -1,6 +1,5 @@ "use server"; -import { getHomepageRoute } from "@/lib/constants"; import BackendAPI from "@/lib/autogpt-server-api"; import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase"; import { loginFormSchema } from "@/types/auth"; @@ -38,10 +37,8 @@ export async function login(email: string, password: string) { await api.createUser(); // Get onboarding status from backend (includes chat flag evaluated for this user) - const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus(); - const next = shouldShowOnboarding - ? "/onboarding" - : getHomepageRoute(isChatEnabled); + const { shouldShowOnboarding } = await getOnboardingStatus(); + const next = shouldShowOnboarding ? "/onboarding" : "/"; return { success: true, diff --git a/autogpt_platform/frontend/src/app/(platform)/login/useLoginPage.ts b/autogpt_platform/frontend/src/app/(platform)/login/useLoginPage.ts index e64cc1858d..9b81965c31 100644 --- a/autogpt_platform/frontend/src/app/(platform)/login/useLoginPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/login/useLoginPage.ts @@ -1,8 +1,6 @@ import { useToast } from "@/components/molecules/Toast/use-toast"; -import { getHomepageRoute } from "@/lib/constants"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { environment } from "@/services/environment"; -import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; import { loginFormSchema, LoginProvider } from "@/types/auth"; import { zodResolver } from "@hookform/resolvers/zod"; import { useRouter, useSearchParams } from "next/navigation"; @@ -22,17 +20,15 @@ export function useLoginPage() { const [isGoogleLoading, setIsGoogleLoading] = useState(false); const [showNotAllowedModal, setShowNotAllowedModal] = useState(false); const isCloudEnv = environment.isCloud(); - const isChatEnabled = useGetFlag(Flag.CHAT); - const homepageRoute = getHomepageRoute(isChatEnabled); // Get redirect destination from 'next' query parameter const nextUrl = searchParams.get("next"); useEffect(() => { if (isLoggedIn && !isLoggingIn) { - router.push(nextUrl || homepageRoute); + router.push(nextUrl || "/"); } - }, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]); + }, [isLoggedIn, isLoggingIn, nextUrl, router]); const form = useForm>({ resolver: zodResolver(loginFormSchema), @@ -98,7 +94,7 @@ export function useLoginPage() { } // Prefer URL's next parameter, then use backend-determined route - router.replace(nextUrl || result.next || homepageRoute); + router.replace(nextUrl || result.next || "/"); } catch (error) { toast({ title: diff --git a/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts b/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts index 0fbba54b8e..204482dbe9 100644 --- a/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts +++ b/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts @@ -1,6 +1,5 @@ "use server"; -import { getHomepageRoute } from "@/lib/constants"; import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase"; import { signupFormSchema } from "@/types/auth"; import * as Sentry from "@sentry/nextjs"; @@ -59,10 +58,8 @@ export async function signup( } // Get onboarding status from backend (includes chat flag evaluated for this user) - const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus(); - const next = shouldShowOnboarding - ? "/onboarding" - : getHomepageRoute(isChatEnabled); + const { shouldShowOnboarding } = await getOnboardingStatus(); + const next = shouldShowOnboarding ? "/onboarding" : "/"; return { success: true, next }; } catch (err) { diff --git a/autogpt_platform/frontend/src/app/(platform)/signup/useSignupPage.ts b/autogpt_platform/frontend/src/app/(platform)/signup/useSignupPage.ts index 5fa8c2c159..fd78b48735 100644 --- a/autogpt_platform/frontend/src/app/(platform)/signup/useSignupPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/signup/useSignupPage.ts @@ -1,8 +1,6 @@ import { useToast } from "@/components/molecules/Toast/use-toast"; -import { getHomepageRoute } from "@/lib/constants"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { environment } from "@/services/environment"; -import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; import { LoginProvider, signupFormSchema } from "@/types/auth"; import { zodResolver } from "@hookform/resolvers/zod"; import { useRouter, useSearchParams } from "next/navigation"; @@ -22,17 +20,15 @@ export function useSignupPage() { const [isGoogleLoading, setIsGoogleLoading] = useState(false); const [showNotAllowedModal, setShowNotAllowedModal] = useState(false); const isCloudEnv = environment.isCloud(); - const isChatEnabled = useGetFlag(Flag.CHAT); - const homepageRoute = getHomepageRoute(isChatEnabled); // Get redirect destination from 'next' query parameter const nextUrl = searchParams.get("next"); useEffect(() => { if (isLoggedIn && !isSigningUp) { - router.push(nextUrl || homepageRoute); + router.push(nextUrl || "/"); } - }, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]); + }, [isLoggedIn, isSigningUp, nextUrl, router]); const form = useForm>({ resolver: zodResolver(signupFormSchema), @@ -133,7 +129,7 @@ export function useSignupPage() { } // Prefer the URL's next parameter, then result.next (for onboarding), then default - const redirectTo = nextUrl || result.next || homepageRoute; + const redirectTo = nextUrl || result.next || "/"; router.replace(redirectTo); } catch (error) { setIsLoading(false); diff --git a/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts b/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts new file mode 100644 index 0000000000..336786bfdb --- /dev/null +++ b/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts @@ -0,0 +1,81 @@ +import { environment } from "@/services/environment"; +import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers"; +import { NextRequest } from "next/server"; + +/** + * SSE Proxy for task stream reconnection. + * + * This endpoint allows clients to reconnect to an ongoing or recently completed + * background task's stream. It replays missed messages from Redis Streams and + * subscribes to live updates if the task is still running. + * + * Client contract: + * 1. When receiving an operation_started event, store the task_id + * 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx} + * 3. Messages are replayed from the last_message_id position + * 4. Stream ends when "finish" event is received + */ +export async function GET( + request: NextRequest, + { params }: { params: Promise<{ taskId: string }> }, +) { + const { taskId } = await params; + const searchParams = request.nextUrl.searchParams; + const lastMessageId = searchParams.get("last_message_id") || "0-0"; + + try { + // Get auth token from server-side session + const token = await getServerAuthToken(); + + // Build backend URL + const backendUrl = environment.getAGPTServerBaseUrl(); + const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl); + streamUrl.searchParams.set("last_message_id", lastMessageId); + + // Forward request to backend with auth header + const headers: Record = { + Accept: "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }; + + if (token) { + headers["Authorization"] = `Bearer ${token}`; + } + + const response = await fetch(streamUrl.toString(), { + method: "GET", + headers, + }); + + if (!response.ok) { + const error = await response.text(); + return new Response(error, { + status: response.status, + headers: { "Content-Type": "application/json" }, + }); + } + + // Return the SSE stream directly + return new Response(response.body, { + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache, no-transform", + Connection: "keep-alive", + "X-Accel-Buffering": "no", + }, + }); + } catch (error) { + console.error("Task stream proxy error:", error); + return new Response( + JSON.stringify({ + error: "Failed to connect to task stream", + detail: error instanceof Error ? error.message : String(error), + }), + { + status: 500, + headers: { "Content-Type": "application/json" }, + }, + ); + } +} diff --git a/autogpt_platform/frontend/src/app/api/helpers.ts b/autogpt_platform/frontend/src/app/api/helpers.ts index c2104d231a..226f5fa786 100644 --- a/autogpt_platform/frontend/src/app/api/helpers.ts +++ b/autogpt_platform/frontend/src/app/api/helpers.ts @@ -181,6 +181,5 @@ export async function getOnboardingStatus() { const isCompleted = onboarding.completedSteps.includes("CONGRATS"); return { shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted, - isChatEnabled: status.is_chat_enabled, }; } diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 2a9db1990d..5ed449829d 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -917,6 +917,28 @@ "security": [{ "HTTPBearerJWT": [] }] } }, + "/api/chat/config/ttl": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Get Ttl Config", + "description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.", + "operationId": "getV2GetTtlConfig", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "type": "object", + "title": "Response Getv2Getttlconfig" + } + } + } + } + } + } + }, "/api/chat/health": { "get": { "tags": ["v2", "chat", "chat"], @@ -939,6 +961,63 @@ } } }, + "/api/chat/operations/{operation_id}/complete": { + "post": { + "tags": ["v2", "chat", "chat"], + "summary": "Complete Operation", + "description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.", + "operationId": "postV2CompleteOperation", + "parameters": [ + { + "name": "operation_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Operation Id" } + }, + { + "name": "x-api-key", + "in": "header", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "X-Api-Key" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OperationCompleteRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": true, + "title": "Response Postv2Completeoperation" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/api/chat/sessions": { "get": { "tags": ["v2", "chat", "chat"], @@ -1022,7 +1101,7 @@ "get": { "tags": ["v2", "chat", "chat"], "summary": "Get Session", - "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.", + "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.", "operationId": "getV2GetSession", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1157,7 +1236,7 @@ "post": { "tags": ["v2", "chat", "chat"], "summary": "Stream Chat Post", - "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.", + "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.", "operationId": "postV2StreamChatPost", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1195,6 +1274,94 @@ } } }, + "/api/chat/tasks/{task_id}": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Get Task Status", + "description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.", + "operationId": "getV2GetTaskStatus", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Task Id" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": true, + "title": "Response Getv2Gettaskstatus" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/chat/tasks/{task_id}/stream": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Stream Task", + "description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.", + "operationId": "getV2StreamTask", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Task Id" } + }, + { + "name": "last_message_id", + "in": "query", + "required": false, + "schema": { + "type": "string", + "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", + "default": "0-0", + "title": "Last Message Id" + }, + "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay." + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/api/credits": { "get": { "tags": ["v1", "credits"], @@ -5912,6 +6079,40 @@ } } }, + "/api/workspace/files/{file_id}/download": { + "get": { + "tags": ["workspace"], + "summary": "Download file by ID", + "description": "Download a file by its ID.\n\nReturns the file content directly or redirects to a signed URL for GCS.", + "operationId": "getWorkspaceDownload file by id", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "file_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "File Id" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/health": { "get": { "tags": ["health"], @@ -6134,6 +6335,18 @@ "title": "AccuracyTrendsResponse", "description": "Response model for accuracy trends and alerts." }, + "ActiveStreamInfo": { + "properties": { + "task_id": { "type": "string", "title": "Task Id" }, + "last_message_id": { "type": "string", "title": "Last Message Id" }, + "operation_id": { "type": "string", "title": "Operation Id" }, + "tool_name": { "type": "string", "title": "Tool Name" } + }, + "type": "object", + "required": ["task_id", "last_message_id", "operation_id", "tool_name"], + "title": "ActiveStreamInfo", + "description": "Information about an active stream for reconnection." + }, "AddUserCreditsResponse": { "properties": { "new_balance": { "type": "integer", "title": "New Balance" }, @@ -7947,6 +8160,25 @@ ] }, "new_output": { "type": "boolean", "title": "New Output" }, + "execution_count": { + "type": "integer", + "title": "Execution Count", + "default": 0 + }, + "success_rate": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Success Rate" + }, + "avg_correctness_score": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Avg Correctness Score" + }, + "recent_executions": { + "items": { "$ref": "#/components/schemas/RecentExecution" }, + "type": "array", + "title": "Recent Executions", + "description": "List of recent executions with status, score, and summary" + }, "can_access_graph": { "type": "boolean", "title": "Can Access Graph" @@ -8770,6 +9002,27 @@ ], "title": "OnboardingStep" }, + "OperationCompleteRequest": { + "properties": { + "success": { "type": "boolean", "title": "Success" }, + "result": { + "anyOf": [ + { "additionalProperties": true, "type": "object" }, + { "type": "string" }, + { "type": "null" } + ], + "title": "Result" + }, + "error": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Error" + } + }, + "type": "object", + "required": ["success"], + "title": "OperationCompleteRequest", + "description": "Request model for external completion webhook." + }, "Pagination": { "properties": { "total_items": { @@ -9340,6 +9593,23 @@ "required": ["providers", "pagination"], "title": "ProviderResponse" }, + "RecentExecution": { + "properties": { + "status": { "type": "string", "title": "Status" }, + "correctness_score": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Correctness Score" + }, + "activity_summary": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Activity Summary" + } + }, + "type": "object", + "required": ["status"], + "title": "RecentExecution", + "description": "Summary of a recent execution for quality assessment.\n\nUsed by the LLM to understand the agent's recent performance with specific examples\nrather than just aggregate statistics." + }, "RefundRequest": { "properties": { "id": { "type": "string", "title": "Id" }, @@ -9608,6 +9878,12 @@ "items": { "additionalProperties": true, "type": "object" }, "type": "array", "title": "Messages" + }, + "active_stream": { + "anyOf": [ + { "$ref": "#/components/schemas/ActiveStreamInfo" }, + { "type": "null" } + ] } }, "type": "object", @@ -9763,7 +10039,8 @@ "sub_heading": { "type": "string", "title": "Sub Heading" }, "description": { "type": "string", "title": "Description" }, "runs": { "type": "integer", "title": "Runs" }, - "rating": { "type": "number", "title": "Rating" } + "rating": { "type": "number", "title": "Rating" }, + "agent_graph_id": { "type": "string", "title": "Agent Graph Id" } }, "type": "object", "required": [ @@ -9775,7 +10052,8 @@ "sub_heading", "description", "runs", - "rating" + "rating", + "agent_graph_id" ], "title": "StoreAgent" }, diff --git a/autogpt_platform/frontend/src/app/api/proxy/[...path]/route.ts b/autogpt_platform/frontend/src/app/api/proxy/[...path]/route.ts index 293c406373..442bd77e0f 100644 --- a/autogpt_platform/frontend/src/app/api/proxy/[...path]/route.ts +++ b/autogpt_platform/frontend/src/app/api/proxy/[...path]/route.ts @@ -1,5 +1,6 @@ import { ApiError, + getServerAuthToken, makeAuthenticatedFileUpload, makeAuthenticatedRequest, } from "@/lib/autogpt-server-api/helpers"; @@ -15,6 +16,69 @@ function buildBackendUrl(path: string[], queryString: string): string { return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`; } +/** + * Check if this is a workspace file download request that needs binary response handling. + */ +function isWorkspaceDownloadRequest(path: string[]): boolean { + // Match pattern: api/workspace/files/{id}/download (5 segments) + return ( + path.length == 5 && + path[0] === "api" && + path[1] === "workspace" && + path[2] === "files" && + path[path.length - 1] === "download" + ); +} + +/** + * Handle workspace file download requests with proper binary response streaming. + */ +async function handleWorkspaceDownload( + req: NextRequest, + backendUrl: string, +): Promise { + const token = await getServerAuthToken(); + + const headers: Record = {}; + if (token && token !== "no-token-found") { + headers["Authorization"] = `Bearer ${token}`; + } + + const response = await fetch(backendUrl, { + method: "GET", + headers, + redirect: "follow", // Follow redirects to signed URLs + }); + + if (!response.ok) { + return NextResponse.json( + { error: `Failed to download file: ${response.statusText}` }, + { status: response.status }, + ); + } + + // Get the content type from the backend response + const contentType = + response.headers.get("Content-Type") || "application/octet-stream"; + const contentDisposition = response.headers.get("Content-Disposition"); + + // Stream the response body + const responseHeaders: Record = { + "Content-Type": contentType, + }; + + if (contentDisposition) { + responseHeaders["Content-Disposition"] = contentDisposition; + } + + // Return the binary content + const arrayBuffer = await response.arrayBuffer(); + return new NextResponse(arrayBuffer, { + status: 200, + headers: responseHeaders, + }); +} + async function handleJsonRequest( req: NextRequest, method: string, @@ -180,6 +244,11 @@ async function handler( }; try { + // Handle workspace file downloads separately (binary response) + if (method === "GET" && isWorkspaceDownloadRequest(path)) { + return await handleWorkspaceDownload(req, backendUrl); + } + if (method === "GET" || method === "DELETE") { responseBody = await handleGetDeleteRequest(method, backendUrl, req); } else if (contentType?.includes("application/json")) { diff --git a/autogpt_platform/frontend/src/app/api/transcribe/route.ts b/autogpt_platform/frontend/src/app/api/transcribe/route.ts new file mode 100644 index 0000000000..10c182cdfa --- /dev/null +++ b/autogpt_platform/frontend/src/app/api/transcribe/route.ts @@ -0,0 +1,77 @@ +import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers"; +import { NextRequest, NextResponse } from "next/server"; + +const WHISPER_API_URL = "https://api.openai.com/v1/audio/transcriptions"; +const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB - Whisper's limit + +function getExtensionFromMimeType(mimeType: string): string { + const subtype = mimeType.split("/")[1]?.split(";")[0]; + return subtype || "webm"; +} + +export async function POST(request: NextRequest) { + const token = await getServerAuthToken(); + + if (!token || token === "no-token-found") { + return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); + } + + const apiKey = process.env.OPENAI_API_KEY; + + if (!apiKey) { + return NextResponse.json( + { error: "OpenAI API key not configured" }, + { status: 401 }, + ); + } + + try { + const formData = await request.formData(); + const audioFile = formData.get("audio"); + + if (!audioFile || !(audioFile instanceof Blob)) { + return NextResponse.json( + { error: "No audio file provided" }, + { status: 400 }, + ); + } + + if (audioFile.size > MAX_FILE_SIZE) { + return NextResponse.json( + { error: "File too large. Maximum size is 25MB." }, + { status: 413 }, + ); + } + + const ext = getExtensionFromMimeType(audioFile.type); + const whisperFormData = new FormData(); + whisperFormData.append("file", audioFile, `recording.${ext}`); + whisperFormData.append("model", "whisper-1"); + + const response = await fetch(WHISPER_API_URL, { + method: "POST", + headers: { + Authorization: `Bearer ${apiKey}`, + }, + body: whisperFormData, + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + console.error("Whisper API error:", errorData); + return NextResponse.json( + { error: errorData.error?.message || "Transcription failed" }, + { status: response.status }, + ); + } + + const result = await response.json(); + return NextResponse.json({ text: result.text }); + } catch (error) { + console.error("Transcription error:", error); + return NextResponse.json( + { error: "Failed to process audio" }, + { status: 500 }, + ); + } +} diff --git a/autogpt_platform/frontend/src/app/page.tsx b/autogpt_platform/frontend/src/app/page.tsx index dbfab49469..ce67760eda 100644 --- a/autogpt_platform/frontend/src/app/page.tsx +++ b/autogpt_platform/frontend/src/app/page.tsx @@ -1,27 +1,15 @@ "use client"; -import { getHomepageRoute } from "@/lib/constants"; -import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; +import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { useRouter } from "next/navigation"; import { useEffect } from "react"; export default function Page() { - const isChatEnabled = useGetFlag(Flag.CHAT); const router = useRouter(); - const homepageRoute = getHomepageRoute(isChatEnabled); - const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true"; - const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID; - const isLaunchDarklyConfigured = envEnabled && Boolean(clientId); - const isFlagReady = - !isLaunchDarklyConfigured || typeof isChatEnabled === "boolean"; - useEffect( - function redirectToHomepage() { - if (!isFlagReady) return; - router.replace(homepageRoute); - }, - [homepageRoute, isFlagReady, router], - ); + useEffect(() => { + router.replace("/copilot"); + }, [router]); - return null; + return ; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx index ada8c26231..da454150bf 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx @@ -1,7 +1,6 @@ "use client"; import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId"; -import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store"; import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { Text } from "@/components/atoms/Text/Text"; import { cn } from "@/lib/utils"; @@ -25,8 +24,8 @@ export function Chat({ }: ChatProps) { const { urlSessionId } = useCopilotSessionId(); const hasHandledNotFoundRef = useRef(false); - const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession); const { + session, messages, isLoading, isCreating, @@ -38,6 +37,18 @@ export function Chat({ startPollingForOperation, } = useChat({ urlSessionId }); + // Extract active stream info for reconnection + const activeStream = ( + session as { + active_stream?: { + task_id: string; + last_message_id: string; + operation_id: string; + tool_name: string; + }; + } + )?.active_stream; + useEffect(() => { if (!onSessionNotFound) return; if (!urlSessionId) return; @@ -53,8 +64,7 @@ export function Chat({ isCreating, ]); - const shouldShowLoader = - (showLoader && (isLoading || isCreating)) || isSwitchingSession; + const shouldShowLoader = showLoader && (isLoading || isCreating); return (
@@ -66,21 +76,19 @@ export function Chat({
- {isSwitchingSession - ? "Switching chat..." - : "Loading your chat..."} + Loading your chat...
)} {/* Error State */} - {error && !isLoading && !isSwitchingSession && ( + {error && !isLoading && ( )} {/* Session Content */} - {sessionId && !isLoading && !error && !isSwitchingSession && ( + {sessionId && !isLoading && !error && ( )} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md b/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md new file mode 100644 index 0000000000..9e78679f4e --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md @@ -0,0 +1,159 @@ +# SSE Reconnection Contract for Long-Running Operations + +This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks. + +## Overview + +When a user triggers a long-running operation (like agent generation), the backend: + +1. Spawns a background task that survives SSE disconnections +2. Returns an `operation_started` response with a `task_id` +3. Stores stream messages in Redis Streams for replay + +Clients can reconnect to the task stream at any time to receive missed messages. + +## Client-Side Flow + +### 1. Receiving Operation Started + +When you receive an `operation_started` tool response: + +```typescript +// The response includes a task_id for reconnection +{ + type: "operation_started", + tool_name: "generate_agent", + operation_id: "uuid-...", + task_id: "task-uuid-...", // <-- Store this for reconnection + message: "Operation started. You can close this tab." +} +``` + +### 2. Storing Task Info + +Use the chat store to track the active task: + +```typescript +import { useChatStore } from "./chat-store"; + +// When operation_started is received: +useChatStore.getState().setActiveTask(sessionId, { + taskId: response.task_id, + operationId: response.operation_id, + toolName: response.tool_name, + lastMessageId: "0", +}); +``` + +### 3. Reconnecting to a Task + +To reconnect (e.g., after page refresh or tab reopen): + +```typescript +const { reconnectToTask, getActiveTask } = useChatStore.getState(); + +// Check if there's an active task for this session +const activeTask = getActiveTask(sessionId); + +if (activeTask) { + // Reconnect to the task stream + await reconnectToTask( + sessionId, + activeTask.taskId, + activeTask.lastMessageId, // Resume from last position + (chunk) => { + // Handle incoming chunks + console.log("Received chunk:", chunk); + }, + ); +} +``` + +### 4. Tracking Message Position + +To enable precise replay, update the last message ID as chunks arrive: + +```typescript +const { updateTaskLastMessageId } = useChatStore.getState(); + +function handleChunk(chunk: StreamChunk) { + // If chunk has an index/id, track it + if (chunk.idx !== undefined) { + updateTaskLastMessageId(sessionId, String(chunk.idx)); + } +} +``` + +## API Endpoints + +### Task Stream Reconnection + +``` +GET /api/chat/tasks/{taskId}/stream?last_message_id={idx} +``` + +- `taskId`: The task ID from `operation_started` +- `last_message_id`: Last received message index (default: "0" for full replay) + +Returns: SSE stream of missed messages + live updates + +## Chunk Types + +The reconnected stream follows the same Vercel AI SDK protocol: + +| Type | Description | +| ----------------------- | ----------------------- | +| `start` | Message lifecycle start | +| `text-delta` | Streaming text content | +| `text-end` | Text block completed | +| `tool-output-available` | Tool result available | +| `finish` | Stream completed | +| `error` | Error occurred | + +## Error Handling + +If reconnection fails: + +1. Check if task still exists (may have expired - default TTL: 1 hour) +2. Fall back to polling the session for final state +3. Show appropriate UI message to user + +## Persistence Considerations + +For robust reconnection across browser restarts: + +```typescript +// Store in localStorage/sessionStorage +const ACTIVE_TASKS_KEY = "chat_active_tasks"; + +function persistActiveTask(sessionId: string, task: ActiveTaskInfo) { + const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}"); + tasks[sessionId] = task; + localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks)); +} + +function loadPersistedTasks(): Record { + return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}"); +} +``` + +## Backend Configuration + +The following backend settings affect reconnection behavior: + +| Setting | Default | Description | +| ------------------- | ------- | ---------------------------------- | +| `stream_ttl` | 3600s | How long streams are kept in Redis | +| `stream_max_length` | 1000 | Max messages per stream | + +## Testing + +To test reconnection locally: + +1. Start a long-running operation (e.g., agent generation) +2. Note the `task_id` from the `operation_started` response +3. Close the browser tab +4. Reopen and call `reconnectToTask` with the saved `task_id` +5. Verify that missed messages are replayed + +See the main README for full local development setup. diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts new file mode 100644 index 0000000000..8802de2155 --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts @@ -0,0 +1,16 @@ +/** + * Constants for the chat system. + * + * Centralizes magic strings and values used across chat components. + */ + +// LocalStorage keys +export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks"; + +// Redis Stream IDs +export const INITIAL_MESSAGE_ID = "0"; +export const INITIAL_STREAM_ID = "0-0"; + +// TTL values (in milliseconds) +export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes +export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts index 8229630e5d..3083f65d2c 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts @@ -1,6 +1,12 @@ "use client"; import { create } from "zustand"; +import { + ACTIVE_TASK_TTL_MS, + COMPLETED_STREAM_TTL_MS, + INITIAL_STREAM_ID, + STORAGE_KEY_ACTIVE_TASKS, +} from "./chat-constants"; import type { ActiveStream, StreamChunk, @@ -8,15 +14,59 @@ import type { StreamResult, StreamStatus, } from "./chat-types"; -import { executeStream } from "./stream-executor"; +import { executeStream, executeTaskReconnect } from "./stream-executor"; -const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes +export interface ActiveTaskInfo { + taskId: string; + sessionId: string; + operationId: string; + toolName: string; + lastMessageId: string; + startedAt: number; +} + +/** Load active tasks from localStorage */ +function loadPersistedTasks(): Map { + if (typeof window === "undefined") return new Map(); + try { + const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS); + if (!stored) return new Map(); + const parsed = JSON.parse(stored) as Record; + const now = Date.now(); + const tasks = new Map(); + // Filter out expired tasks + for (const [sessionId, task] of Object.entries(parsed)) { + if (now - task.startedAt < ACTIVE_TASK_TTL_MS) { + tasks.set(sessionId, task); + } + } + return tasks; + } catch { + return new Map(); + } +} + +/** Save active tasks to localStorage */ +function persistTasks(tasks: Map): void { + if (typeof window === "undefined") return; + try { + const obj: Record = {}; + for (const [sessionId, task] of tasks) { + obj[sessionId] = task; + } + localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj)); + } catch { + // Ignore storage errors + } +} interface ChatStoreState { activeStreams: Map; completedStreams: Map; activeSessions: Set; streamCompleteCallbacks: Set; + /** Active tasks for SSE reconnection - keyed by sessionId */ + activeTasks: Map; } interface ChatStoreActions { @@ -41,6 +91,24 @@ interface ChatStoreActions { unregisterActiveSession: (sessionId: string) => void; isSessionActive: (sessionId: string) => boolean; onStreamComplete: (callback: StreamCompleteCallback) => () => void; + /** Track active task for SSE reconnection */ + setActiveTask: ( + sessionId: string, + taskInfo: Omit, + ) => void; + /** Get active task for a session */ + getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined; + /** Clear active task when operation completes */ + clearActiveTask: (sessionId: string) => void; + /** Reconnect to an existing task stream */ + reconnectToTask: ( + sessionId: string, + taskId: string, + lastMessageId?: string, + onChunk?: (chunk: StreamChunk) => void, + ) => Promise; + /** Update last message ID for a task (for tracking replay position) */ + updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void; } type ChatStore = ChatStoreState & ChatStoreActions; @@ -64,18 +132,126 @@ function cleanupExpiredStreams( const now = Date.now(); const cleaned = new Map(completedStreams); for (const [sessionId, result] of cleaned) { - if (now - result.completedAt > COMPLETED_STREAM_TTL) { + if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) { cleaned.delete(sessionId); } } return cleaned; } +/** + * Finalize a stream by moving it from activeStreams to completedStreams. + * Also handles cleanup and notifications. + */ +function finalizeStream( + sessionId: string, + stream: ActiveStream, + onChunk: ((chunk: StreamChunk) => void) | undefined, + get: () => ChatStoreState & ChatStoreActions, + set: (state: Partial) => void, +): void { + if (onChunk) stream.onChunkCallbacks.delete(onChunk); + + if (stream.status !== "streaming") { + const currentState = get(); + const finalActiveStreams = new Map(currentState.activeStreams); + let finalCompletedStreams = new Map(currentState.completedStreams); + + const storedStream = finalActiveStreams.get(sessionId); + if (storedStream === stream) { + const result: StreamResult = { + sessionId, + status: stream.status, + chunks: stream.chunks, + completedAt: Date.now(), + error: stream.error, + }; + finalCompletedStreams.set(sessionId, result); + finalActiveStreams.delete(sessionId); + finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams); + set({ + activeStreams: finalActiveStreams, + completedStreams: finalCompletedStreams, + }); + + if (stream.status === "completed" || stream.status === "error") { + notifyStreamComplete(currentState.streamCompleteCallbacks, sessionId); + } + } + } +} + +/** + * Clean up an existing stream for a session and move it to completed streams. + * Returns updated maps for both active and completed streams. + */ +function cleanupExistingStream( + sessionId: string, + activeStreams: Map, + completedStreams: Map, + callbacks: Set, +): { + activeStreams: Map; + completedStreams: Map; +} { + const newActiveStreams = new Map(activeStreams); + let newCompletedStreams = new Map(completedStreams); + + const existingStream = newActiveStreams.get(sessionId); + if (existingStream) { + existingStream.abortController.abort(); + const normalizedStatus = + existingStream.status === "streaming" + ? "completed" + : existingStream.status; + const result: StreamResult = { + sessionId, + status: normalizedStatus, + chunks: existingStream.chunks, + completedAt: Date.now(), + error: existingStream.error, + }; + newCompletedStreams.set(sessionId, result); + newActiveStreams.delete(sessionId); + newCompletedStreams = cleanupExpiredStreams(newCompletedStreams); + if (normalizedStatus === "completed" || normalizedStatus === "error") { + notifyStreamComplete(callbacks, sessionId); + } + } + + return { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + }; +} + +/** + * Create a new active stream with initial state. + */ +function createActiveStream( + sessionId: string, + onChunk?: (chunk: StreamChunk) => void, +): ActiveStream { + const abortController = new AbortController(); + const initialCallbacks = new Set<(chunk: StreamChunk) => void>(); + if (onChunk) initialCallbacks.add(onChunk); + + return { + sessionId, + abortController, + status: "streaming", + startedAt: Date.now(), + chunks: [], + onChunkCallbacks: initialCallbacks, + }; +} + export const useChatStore = create((set, get) => ({ activeStreams: new Map(), completedStreams: new Map(), activeSessions: new Set(), streamCompleteCallbacks: new Set(), + activeTasks: loadPersistedTasks(), startStream: async function startStream( sessionId, @@ -85,45 +261,21 @@ export const useChatStore = create((set, get) => ({ onChunk, ) { const state = get(); - const newActiveStreams = new Map(state.activeStreams); - let newCompletedStreams = new Map(state.completedStreams); const callbacks = state.streamCompleteCallbacks; - const existingStream = newActiveStreams.get(sessionId); - if (existingStream) { - existingStream.abortController.abort(); - const normalizedStatus = - existingStream.status === "streaming" - ? "completed" - : existingStream.status; - const result: StreamResult = { - sessionId, - status: normalizedStatus, - chunks: existingStream.chunks, - completedAt: Date.now(), - error: existingStream.error, - }; - newCompletedStreams.set(sessionId, result); - newActiveStreams.delete(sessionId); - newCompletedStreams = cleanupExpiredStreams(newCompletedStreams); - if (normalizedStatus === "completed" || normalizedStatus === "error") { - notifyStreamComplete(callbacks, sessionId); - } - } - - const abortController = new AbortController(); - const initialCallbacks = new Set<(chunk: StreamChunk) => void>(); - if (onChunk) initialCallbacks.add(onChunk); - - const stream: ActiveStream = { + // Clean up any existing stream for this session + const { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + } = cleanupExistingStream( sessionId, - abortController, - status: "streaming", - startedAt: Date.now(), - chunks: [], - onChunkCallbacks: initialCallbacks, - }; + state.activeStreams, + state.completedStreams, + callbacks, + ); + // Create new stream + const stream = createActiveStream(sessionId, onChunk); newActiveStreams.set(sessionId, stream); set({ activeStreams: newActiveStreams, @@ -133,36 +285,7 @@ export const useChatStore = create((set, get) => ({ try { await executeStream(stream, message, isUserMessage, context); } finally { - if (onChunk) stream.onChunkCallbacks.delete(onChunk); - if (stream.status !== "streaming") { - const currentState = get(); - const finalActiveStreams = new Map(currentState.activeStreams); - let finalCompletedStreams = new Map(currentState.completedStreams); - - const storedStream = finalActiveStreams.get(sessionId); - if (storedStream === stream) { - const result: StreamResult = { - sessionId, - status: stream.status, - chunks: stream.chunks, - completedAt: Date.now(), - error: stream.error, - }; - finalCompletedStreams.set(sessionId, result); - finalActiveStreams.delete(sessionId); - finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams); - set({ - activeStreams: finalActiveStreams, - completedStreams: finalCompletedStreams, - }); - if (stream.status === "completed" || stream.status === "error") { - notifyStreamComplete( - currentState.streamCompleteCallbacks, - sessionId, - ); - } - } - } + finalizeStream(sessionId, stream, onChunk, get, set); } }, @@ -286,4 +409,93 @@ export const useChatStore = create((set, get) => ({ set({ streamCompleteCallbacks: cleanedCallbacks }); }; }, + + setActiveTask: function setActiveTask(sessionId, taskInfo) { + const state = get(); + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.set(sessionId, { + ...taskInfo, + sessionId, + startedAt: Date.now(), + }); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, + + getActiveTask: function getActiveTask(sessionId) { + return get().activeTasks.get(sessionId); + }, + + clearActiveTask: function clearActiveTask(sessionId) { + const state = get(); + if (!state.activeTasks.has(sessionId)) return; + + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.delete(sessionId); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, + + reconnectToTask: async function reconnectToTask( + sessionId, + taskId, + lastMessageId = INITIAL_STREAM_ID, + onChunk, + ) { + const state = get(); + const callbacks = state.streamCompleteCallbacks; + + // Clean up any existing stream for this session + const { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + } = cleanupExistingStream( + sessionId, + state.activeStreams, + state.completedStreams, + callbacks, + ); + + // Create new stream for reconnection + const stream = createActiveStream(sessionId, onChunk); + newActiveStreams.set(sessionId, stream); + set({ + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + }); + + try { + await executeTaskReconnect(stream, taskId, lastMessageId); + } finally { + finalizeStream(sessionId, stream, onChunk, get, set); + + // Clear active task on completion + if (stream.status === "completed" || stream.status === "error") { + const taskState = get(); + if (taskState.activeTasks.has(sessionId)) { + const newActiveTasks = new Map(taskState.activeTasks); + newActiveTasks.delete(sessionId); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + } + } + } + }, + + updateTaskLastMessageId: function updateTaskLastMessageId( + sessionId, + lastMessageId, + ) { + const state = get(); + const task = state.activeTasks.get(sessionId); + if (!task) return; + + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.set(sessionId, { + ...task, + lastMessageId, + }); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, })); diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts index 8c8aa7b704..34813e17fe 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts @@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error"; export interface StreamChunk { type: + | "stream_start" | "text_chunk" | "text_ended" | "tool_call" @@ -15,6 +16,7 @@ export interface StreamChunk { | "error" | "usage" | "stream_end"; + taskId?: string; timestamp?: string; content?: string; message?: string; @@ -41,7 +43,7 @@ export interface StreamChunk { } export type VercelStreamChunk = - | { type: "start"; messageId: string } + | { type: "start"; messageId: string; taskId?: string } | { type: "finish" } | { type: "text-start"; id: string } | { type: "text-delta"; id: string; delta: string } @@ -92,3 +94,70 @@ export interface StreamResult { } export type StreamCompleteCallback = (sessionId: string) => void; + +// Type guards for message types + +/** + * Check if a message has a toolId property. + */ +export function hasToolId( + msg: T, +): msg is T & { toolId: string } { + return ( + "toolId" in msg && + typeof (msg as Record).toolId === "string" + ); +} + +/** + * Check if a message has an operationId property. + */ +export function hasOperationId( + msg: T, +): msg is T & { operationId: string } { + return ( + "operationId" in msg && + typeof (msg as Record).operationId === "string" + ); +} + +/** + * Check if a message has a toolCallId property. + */ +export function hasToolCallId( + msg: T, +): msg is T & { toolCallId: string } { + return ( + "toolCallId" in msg && + typeof (msg as Record).toolCallId === "string" + ); +} + +/** + * Check if a message is an operation message type. + */ +export function isOperationMessage( + msg: T, +): msg is T & { + type: "operation_started" | "operation_pending" | "operation_in_progress"; +} { + return ( + msg.type === "operation_started" || + msg.type === "operation_pending" || + msg.type === "operation_in_progress" + ); +} + +/** + * Get the tool ID from a message if available. + * Checks toolId, operationId, and toolCallId properties. + */ +export function getToolIdFromMessage( + msg: T, +): string | undefined { + const record = msg as Record; + if (typeof record.toolId === "string") return record.toolId; + if (typeof record.operationId === "string") return record.operationId; + if (typeof record.toolCallId === "string") return record.toolCallId; + return undefined; +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx index dec221338a..fbf2d5d143 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx @@ -2,7 +2,6 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi import { Button } from "@/components/atoms/Button/Button"; import { Text } from "@/components/atoms/Text/Text"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; -import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { cn } from "@/lib/utils"; import { GlobeHemisphereEastIcon } from "@phosphor-icons/react"; import { useEffect } from "react"; @@ -17,6 +16,13 @@ export interface ChatContainerProps { className?: string; onStreamingChange?: (isStreaming: boolean) => void; onOperationStarted?: () => void; + /** Active stream info from the server for reconnection */ + activeStream?: { + taskId: string; + lastMessageId: string; + operationId: string; + toolName: string; + }; } export function ChatContainer({ @@ -26,6 +32,7 @@ export function ChatContainer({ className, onStreamingChange, onOperationStarted, + activeStream, }: ChatContainerProps) { const { messages, @@ -41,16 +48,13 @@ export function ChatContainer({ initialMessages, initialPrompt, onOperationStarted, + activeStream, }); useEffect(() => { onStreamingChange?.(isStreaming); }, [isStreaming, onStreamingChange]); - const breakpoint = useBreakpoint(); - const isMobile = - breakpoint === "base" || breakpoint === "sm" || breakpoint === "md"; - return (
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts index 82e9b05e88..af3b3329b7 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts @@ -2,6 +2,7 @@ import { toast } from "sonner"; import type { StreamChunk } from "../../chat-types"; import type { HandlerDependencies } from "./handlers"; import { + getErrorDisplayMessage, handleError, handleLoginNeeded, handleStreamEnd, @@ -24,16 +25,22 @@ export function createStreamEventDispatcher( chunk.type === "need_login" || chunk.type === "error" ) { - if (!deps.hasResponseRef.current) { - console.info("[ChatStream] First response chunk:", { - type: chunk.type, - sessionId: deps.sessionId, - }); - } deps.hasResponseRef.current = true; } switch (chunk.type) { + case "stream_start": + // Store task ID for SSE reconnection + if (chunk.taskId && deps.onActiveTaskStarted) { + deps.onActiveTaskStarted({ + taskId: chunk.taskId, + operationId: chunk.taskId, + toolName: "chat", + toolCallId: "chat_stream", + }); + } + break; + case "text_chunk": handleTextChunk(chunk, deps); break; @@ -56,11 +63,7 @@ export function createStreamEventDispatcher( break; case "stream_end": - console.info("[ChatStream] Stream ended:", { - sessionId: deps.sessionId, - hasResponse: deps.hasResponseRef.current, - chunkCount: deps.streamingChunksRef.current.length, - }); + // Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk handleStreamEnd(chunk, deps); break; @@ -70,7 +73,7 @@ export function createStreamEventDispatcher( // Show toast at dispatcher level to avoid circular dependencies if (!isRegionBlocked) { toast.error("Chat Error", { - description: chunk.message || chunk.content || "An error occurred", + description: getErrorDisplayMessage(chunk), }); } break; diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts index f3cac01f96..5aec5b9818 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts @@ -18,11 +18,19 @@ export interface HandlerDependencies { setStreamingChunks: Dispatch>; streamingChunksRef: MutableRefObject; hasResponseRef: MutableRefObject; + textFinalizedRef: MutableRefObject; + streamEndedRef: MutableRefObject; setMessages: Dispatch>; setIsStreamingInitiated: Dispatch>; setIsRegionBlockedModalOpen: Dispatch>; sessionId: string; onOperationStarted?: () => void; + onActiveTaskStarted?: (taskInfo: { + taskId: string; + operationId: string; + toolName: string; + toolCallId: string; + }) => void; } export function isRegionBlockedError(chunk: StreamChunk): boolean { @@ -32,6 +40,25 @@ export function isRegionBlockedError(chunk: StreamChunk): boolean { return message.toLowerCase().includes("not available in your region"); } +export function getUserFriendlyErrorMessage( + code: string | undefined, +): string | undefined { + switch (code) { + case "TASK_EXPIRED": + return "This operation has expired. Please try again."; + case "TASK_NOT_FOUND": + return "Could not find the requested operation."; + case "ACCESS_DENIED": + return "You do not have access to this operation."; + case "QUEUE_OVERFLOW": + return "Connection was interrupted. Please refresh to continue."; + case "MODEL_NOT_AVAILABLE_REGION": + return "This model is not available in your region."; + default: + return undefined; + } +} + export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) { if (!chunk.content) return; deps.setHasTextChunks(true); @@ -46,10 +73,15 @@ export function handleTextEnded( _chunk: StreamChunk, deps: HandlerDependencies, ) { + if (deps.textFinalizedRef.current) { + return; + } + const completedText = deps.streamingChunksRef.current.join(""); if (completedText.trim()) { + deps.textFinalizedRef.current = true; + deps.setMessages((prev) => { - // Check if this exact message already exists to prevent duplicates const exists = prev.some( (msg) => msg.type === "message" && @@ -76,9 +108,14 @@ export function handleToolCallStart( chunk: StreamChunk, deps: HandlerDependencies, ) { + // Use deterministic fallback instead of Date.now() to ensure same ID on replay + const toolId = + chunk.tool_id || + `tool-${deps.sessionId}-${chunk.idx ?? "unknown"}-${chunk.tool_name || "unknown"}`; + const toolCallMessage: Extract = { type: "tool_call", - toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`, + toolId, toolName: chunk.tool_name || "Executing", arguments: chunk.arguments || {}, timestamp: new Date(), @@ -111,6 +148,29 @@ export function handleToolCallStart( deps.setMessages(updateToolCallMessages); } +const TOOL_RESPONSE_TYPES = new Set([ + "tool_response", + "operation_started", + "operation_pending", + "operation_in_progress", + "execution_started", + "agent_carousel", + "clarification_needed", +]); + +function hasResponseForTool( + messages: ChatMessageData[], + toolId: string, +): boolean { + return messages.some((msg) => { + if (!TOOL_RESPONSE_TYPES.has(msg.type)) return false; + const msgToolId = + (msg as { toolId?: string }).toolId || + (msg as { toolCallId?: string }).toolCallId; + return msgToolId === toolId; + }); +} + export function handleToolResponse( chunk: StreamChunk, deps: HandlerDependencies, @@ -152,31 +212,49 @@ export function handleToolResponse( ) { const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name); if (inputsMessage) { - deps.setMessages((prev) => [...prev, inputsMessage]); + deps.setMessages((prev) => { + // Check for duplicate inputs_needed message + const exists = prev.some((msg) => msg.type === "inputs_needed"); + if (exists) return prev; + return [...prev, inputsMessage]; + }); } const credentialsMessage = extractCredentialsNeeded( parsedResult, chunk.tool_name, ); if (credentialsMessage) { - deps.setMessages((prev) => [...prev, credentialsMessage]); + deps.setMessages((prev) => { + // Check for duplicate credentials_needed message + const exists = prev.some((msg) => msg.type === "credentials_needed"); + if (exists) return prev; + return [...prev, credentialsMessage]; + }); } } return; } - // Trigger polling when operation_started is received if (responseMessage.type === "operation_started") { deps.onOperationStarted?.(); + const taskId = (responseMessage as { taskId?: string }).taskId; + if (taskId && deps.onActiveTaskStarted) { + deps.onActiveTaskStarted({ + taskId, + operationId: + (responseMessage as { operationId?: string }).operationId || "", + toolName: (responseMessage as { toolName?: string }).toolName || "", + toolCallId: (responseMessage as { toolId?: string }).toolId || "", + }); + } } deps.setMessages((prev) => { const toolCallIndex = prev.findIndex( (msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id, ); - const hasResponse = prev.some( - (msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id, - ); - if (hasResponse) return prev; + if (hasResponseForTool(prev, chunk.tool_id!)) { + return prev; + } if (toolCallIndex !== -1) { const newMessages = [...prev]; newMessages.splice(toolCallIndex + 1, 0, responseMessage); @@ -198,28 +276,48 @@ export function handleLoginNeeded( agentInfo: chunk.agent_info, timestamp: new Date(), }; - deps.setMessages((prev) => [...prev, loginNeededMessage]); + deps.setMessages((prev) => { + // Check for duplicate login_needed message + const exists = prev.some((msg) => msg.type === "login_needed"); + if (exists) return prev; + return [...prev, loginNeededMessage]; + }); } export function handleStreamEnd( _chunk: StreamChunk, deps: HandlerDependencies, ) { + if (deps.streamEndedRef.current) { + return; + } + deps.streamEndedRef.current = true; + const completedContent = deps.streamingChunksRef.current.join(""); if (!completedContent.trim() && !deps.hasResponseRef.current) { - deps.setMessages((prev) => [ - ...prev, - { - type: "message", - role: "assistant", - content: "No response received. Please try again.", - timestamp: new Date(), - }, - ]); - } - if (completedContent.trim()) { deps.setMessages((prev) => { - // Check if this exact message already exists to prevent duplicates + const exists = prev.some( + (msg) => + msg.type === "message" && + msg.role === "assistant" && + msg.content === "No response received. Please try again.", + ); + if (exists) return prev; + return [ + ...prev, + { + type: "message", + role: "assistant", + content: "No response received. Please try again.", + timestamp: new Date(), + }, + ]; + }); + } + if (completedContent.trim() && !deps.textFinalizedRef.current) { + deps.textFinalizedRef.current = true; + + deps.setMessages((prev) => { const exists = prev.some( (msg) => msg.type === "message" && @@ -244,8 +342,6 @@ export function handleStreamEnd( } export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { - const errorMessage = chunk.message || chunk.content || "An error occurred"; - console.error("Stream error:", errorMessage); if (isRegionBlockedError(chunk)) { deps.setIsRegionBlockedModalOpen(true); } @@ -253,4 +349,14 @@ export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { deps.setHasTextChunks(false); deps.setStreamingChunks([]); deps.streamingChunksRef.current = []; + deps.textFinalizedRef.current = false; + deps.streamEndedRef.current = true; +} + +export function getErrorDisplayMessage(chunk: StreamChunk): string { + const friendlyMessage = getUserFriendlyErrorMessage(chunk.code); + if (friendlyMessage) { + return friendlyMessage; + } + return chunk.message || chunk.content || "An error occurred"; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts index e744c9bc34..f1e94cea17 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts @@ -349,6 +349,7 @@ export function parseToolResponse( toolName: (parsedResult.tool_name as string) || toolName, toolId, operationId: (parsedResult.operation_id as string) || "", + taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection message: (parsedResult.message as string) || "Operation started. You can close this tab.", diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts index 46f384d055..248383df42 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts @@ -1,10 +1,17 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; import { useEffect, useMemo, useRef, useState } from "react"; +import { INITIAL_STREAM_ID } from "../../chat-constants"; import { useChatStore } from "../../chat-store"; import { toast } from "sonner"; import { useChatStream } from "../../useChatStream"; import { usePageContext } from "../../usePageContext"; import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +import { + getToolIdFromMessage, + hasToolId, + isOperationMessage, + type StreamChunk, +} from "../../chat-types"; import { createStreamEventDispatcher } from "./createStreamEventDispatcher"; import { createUserMessage, @@ -14,6 +21,13 @@ import { processInitialMessages, } from "./helpers"; +const TOOL_RESULT_TYPES = new Set([ + "tool_response", + "agent_carousel", + "execution_started", + "clarification_needed", +]); + // Helper to generate deduplication key for a message function getMessageKey(msg: ChatMessageData): string { if (msg.type === "message") { @@ -23,14 +37,18 @@ function getMessageKey(msg: ChatMessageData): string { return `msg:${msg.role}:${msg.content}`; } else if (msg.type === "tool_call") { return `toolcall:${msg.toolId}`; - } else if (msg.type === "tool_response") { - return `toolresponse:${(msg as any).toolId}`; - } else if ( - msg.type === "operation_started" || - msg.type === "operation_pending" || - msg.type === "operation_in_progress" - ) { - return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`; + } else if (TOOL_RESULT_TYPES.has(msg.type)) { + // Unified key for all tool result types - same toolId with different types + // (tool_response vs agent_carousel) should deduplicate to the same key + const toolId = getToolIdFromMessage(msg); + // If no toolId, fall back to content-based key to avoid empty key collisions + if (!toolId) { + return `toolresult:content:${JSON.stringify(msg).slice(0, 200)}`; + } + return `toolresult:${toolId}`; + } else if (isOperationMessage(msg)) { + const toolId = getToolIdFromMessage(msg) || ""; + return `op:${toolId}:${msg.toolName}`; } else { return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`; } @@ -41,6 +59,13 @@ interface Args { initialMessages: SessionDetailResponse["messages"]; initialPrompt?: string; onOperationStarted?: () => void; + /** Active stream info from the server for reconnection */ + activeStream?: { + taskId: string; + lastMessageId: string; + operationId: string; + toolName: string; + }; } export function useChatContainer({ @@ -48,6 +73,7 @@ export function useChatContainer({ initialMessages, initialPrompt, onOperationStarted, + activeStream, }: Args) { const [messages, setMessages] = useState([]); const [streamingChunks, setStreamingChunks] = useState([]); @@ -57,6 +83,8 @@ export function useChatContainer({ useState(false); const hasResponseRef = useRef(false); const streamingChunksRef = useRef([]); + const textFinalizedRef = useRef(false); + const streamEndedRef = useRef(false); const previousSessionIdRef = useRef(null); const { error, @@ -65,44 +93,182 @@ export function useChatContainer({ } = useChatStream(); const activeStreams = useChatStore((s) => s.activeStreams); const subscribeToStream = useChatStore((s) => s.subscribeToStream); + const setActiveTask = useChatStore((s) => s.setActiveTask); + const getActiveTask = useChatStore((s) => s.getActiveTask); + const reconnectToTask = useChatStore((s) => s.reconnectToTask); const isStreaming = isStreamingInitiated || hasTextChunks; + // Track whether we've already connected to this activeStream to avoid duplicate connections + const connectedActiveStreamRef = useRef(null); + // Track if component is mounted to prevent state updates after unmount + const isMountedRef = useRef(true); + // Track current dispatcher to prevent multiple dispatchers from adding messages + const currentDispatcherIdRef = useRef(0); + + // Set mounted flag - reset on every mount, cleanup on unmount + useEffect(function trackMountedState() { + isMountedRef.current = true; + return function cleanup() { + isMountedRef.current = false; + }; + }, []); + + // Callback to store active task info for SSE reconnection + function handleActiveTaskStarted(taskInfo: { + taskId: string; + operationId: string; + toolName: string; + toolCallId: string; + }) { + if (!sessionId) return; + setActiveTask(sessionId, { + taskId: taskInfo.taskId, + operationId: taskInfo.operationId, + toolName: taskInfo.toolName, + lastMessageId: INITIAL_STREAM_ID, + }); + } + + // Create dispatcher for stream events - stable reference for current sessionId + // Each dispatcher gets a unique ID to prevent stale dispatchers from updating state + function createDispatcher() { + if (!sessionId) return () => {}; + // Increment dispatcher ID - only the most recent dispatcher should update state + const dispatcherId = ++currentDispatcherIdRef.current; + + const baseDispatcher = createStreamEventDispatcher({ + setHasTextChunks, + setStreamingChunks, + streamingChunksRef, + hasResponseRef, + textFinalizedRef, + streamEndedRef, + setMessages, + setIsRegionBlockedModalOpen, + sessionId, + setIsStreamingInitiated, + onOperationStarted, + onActiveTaskStarted: handleActiveTaskStarted, + }); + + // Wrap dispatcher to check if it's still the current one + return function guardedDispatcher(chunk: StreamChunk) { + // Skip if component unmounted or this is a stale dispatcher + if (!isMountedRef.current) { + return; + } + if (dispatcherId !== currentDispatcherIdRef.current) { + return; + } + baseDispatcher(chunk); + }; + } useEffect( function handleSessionChange() { - if (sessionId === previousSessionIdRef.current) return; + const isSessionChange = sessionId !== previousSessionIdRef.current; - const prevSession = previousSessionIdRef.current; - if (prevSession) { - stopStreaming(prevSession); + // Handle session change - reset state + if (isSessionChange) { + const prevSession = previousSessionIdRef.current; + if (prevSession) { + stopStreaming(prevSession); + } + previousSessionIdRef.current = sessionId; + connectedActiveStreamRef.current = null; + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + setIsStreamingInitiated(false); + hasResponseRef.current = false; + textFinalizedRef.current = false; + streamEndedRef.current = false; } - previousSessionIdRef.current = sessionId; - setMessages([]); - setStreamingChunks([]); - streamingChunksRef.current = []; - setHasTextChunks(false); - setIsStreamingInitiated(false); - hasResponseRef.current = false; if (!sessionId) return; - const activeStream = activeStreams.get(sessionId); - if (!activeStream || activeStream.status !== "streaming") return; + // Priority 1: Check if server told us there's an active stream (most authoritative) + if (activeStream) { + const streamKey = `${sessionId}:${activeStream.taskId}`; - const dispatcher = createStreamEventDispatcher({ - setHasTextChunks, - setStreamingChunks, - streamingChunksRef, - hasResponseRef, - setMessages, - setIsRegionBlockedModalOpen, - sessionId, - setIsStreamingInitiated, - onOperationStarted, - }); + if (connectedActiveStreamRef.current === streamKey) { + return; + } + + // Skip if there's already an active stream for this session in the store + const existingStream = activeStreams.get(sessionId); + if (existingStream && existingStream.status === "streaming") { + connectedActiveStreamRef.current = streamKey; + return; + } + + connectedActiveStreamRef.current = streamKey; + + // Clear all state before reconnection to prevent duplicates + // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + textFinalizedRef.current = false; + streamEndedRef.current = false; + hasResponseRef.current = false; + + setIsStreamingInitiated(true); + setActiveTask(sessionId, { + taskId: activeStream.taskId, + operationId: activeStream.operationId, + toolName: activeStream.toolName, + lastMessageId: activeStream.lastMessageId, + }); + reconnectToTask( + sessionId, + activeStream.taskId, + activeStream.lastMessageId, + createDispatcher(), + ); + // Don't return cleanup here - the guarded dispatcher handles stale events + // and the stream will complete naturally. Cleanup would prematurely stop + // the stream when effect re-runs due to activeStreams changing. + return; + } + + // Only check localStorage/in-memory on session change + if (!isSessionChange) return; + + // Priority 2: Check localStorage for active task + const activeTask = getActiveTask(sessionId); + if (activeTask) { + // Clear all state before reconnection to prevent duplicates + // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + textFinalizedRef.current = false; + streamEndedRef.current = false; + hasResponseRef.current = false; + + setIsStreamingInitiated(true); + reconnectToTask( + sessionId, + activeTask.taskId, + activeTask.lastMessageId, + createDispatcher(), + ); + // Don't return cleanup here - the guarded dispatcher handles stale events + return; + } + + // Priority 3: Check for an in-memory active stream (same-tab scenario) + const inMemoryStream = activeStreams.get(sessionId); + if (!inMemoryStream || inMemoryStream.status !== "streaming") { + return; + } setIsStreamingInitiated(true); const skipReplay = initialMessages.length > 0; - return subscribeToStream(sessionId, dispatcher, skipReplay); + return subscribeToStream(sessionId, createDispatcher(), skipReplay); }, [ sessionId, @@ -110,6 +276,10 @@ export function useChatContainer({ activeStreams, subscribeToStream, onOperationStarted, + getActiveTask, + reconnectToTask, + activeStream, + setActiveTask, ], ); @@ -124,7 +294,7 @@ export function useChatContainer({ msg.type === "agent_carousel" || msg.type === "execution_started" ) { - const toolId = (msg as any).toolId; + const toolId = hasToolId(msg) ? msg.toolId : undefined; if (toolId) { ids.add(toolId); } @@ -141,12 +311,8 @@ export function useChatContainer({ setMessages((prev) => { const filtered = prev.filter((msg) => { - if ( - msg.type === "operation_started" || - msg.type === "operation_pending" || - msg.type === "operation_in_progress" - ) { - const toolId = (msg as any).toolId || (msg as any).toolCallId; + if (isOperationMessage(msg)) { + const toolId = getToolIdFromMessage(msg); if (toolId && completedToolIds.has(toolId)) { return false; // Remove - operation completed } @@ -174,12 +340,8 @@ export function useChatContainer({ // Filter local messages: remove duplicates and completed operation messages const newLocalMessages = messages.filter((msg) => { // Remove operation messages for completed tools - if ( - msg.type === "operation_started" || - msg.type === "operation_pending" || - msg.type === "operation_in_progress" - ) { - const toolId = (msg as any).toolId || (msg as any).toolCallId; + if (isOperationMessage(msg)) { + const toolId = getToolIdFromMessage(msg); if (toolId && completedToolIds.has(toolId)) { return false; } @@ -190,7 +352,70 @@ export function useChatContainer({ }); // Server messages first (correct order), then new local messages - return [...processedInitial, ...newLocalMessages]; + const combined = [...processedInitial, ...newLocalMessages]; + + // Post-processing: Remove duplicate assistant messages that can occur during + // race conditions (e.g., rapid screen switching during SSE reconnection). + // Two assistant messages are considered duplicates if: + // - They are both text messages with role "assistant" + // - One message's content starts with the other's content (partial vs complete) + // - Or they have very similar content (>80% overlap at the start) + const deduplicated: ChatMessageData[] = []; + for (let i = 0; i < combined.length; i++) { + const current = combined[i]; + + // Check if this is an assistant text message + if (current.type !== "message" || current.role !== "assistant") { + deduplicated.push(current); + continue; + } + + // Look for duplicate assistant messages in the rest of the array + let dominated = false; + for (let j = 0; j < combined.length; j++) { + if (i === j) continue; + const other = combined[j]; + if (other.type !== "message" || other.role !== "assistant") continue; + + const currentContent = current.content || ""; + const otherContent = other.content || ""; + + // Skip empty messages + if (!currentContent.trim() || !otherContent.trim()) continue; + + // Check if current is a prefix of other (current is incomplete version) + if ( + otherContent.length > currentContent.length && + otherContent.startsWith(currentContent.slice(0, 100)) + ) { + // Current is a shorter/incomplete version of other - skip it + dominated = true; + break; + } + + // Check if messages are nearly identical (within a small difference) + // This catches cases where content differs only slightly + const minLen = Math.min(currentContent.length, otherContent.length); + const compareLen = Math.min(minLen, 200); // Compare first 200 chars + if ( + compareLen > 50 && + currentContent.slice(0, compareLen) === + otherContent.slice(0, compareLen) + ) { + // Same prefix - keep the longer one + if (otherContent.length > currentContent.length) { + dominated = true; + break; + } + } + } + + if (!dominated) { + deduplicated.push(current); + } + } + + return deduplicated; }, [initialMessages, messages, completedToolIds]); async function sendMessage( @@ -198,10 +423,8 @@ export function useChatContainer({ isUserMessage: boolean = true, context?: { url: string; content: string }, ) { - if (!sessionId) { - console.error("[useChatContainer] Cannot send message: no session ID"); - return; - } + if (!sessionId) return; + setIsRegionBlockedModalOpen(false); if (isUserMessage) { const userMessage = createUserMessage(content); @@ -214,31 +437,19 @@ export function useChatContainer({ setHasTextChunks(false); setIsStreamingInitiated(true); hasResponseRef.current = false; - - const dispatcher = createStreamEventDispatcher({ - setHasTextChunks, - setStreamingChunks, - streamingChunksRef, - hasResponseRef, - setMessages, - setIsRegionBlockedModalOpen, - sessionId, - setIsStreamingInitiated, - onOperationStarted, - }); + textFinalizedRef.current = false; + streamEndedRef.current = false; try { await sendStreamMessage( sessionId, content, - dispatcher, + createDispatcher(), isUserMessage, context, ); } catch (err) { - console.error("[useChatContainer] Failed to send message:", err); setIsStreamingInitiated(false); - if (err instanceof Error && err.name === "AbortError") return; const errorMessage = diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx index c45e8dc250..bac004f6ed 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx @@ -1,7 +1,14 @@ import { Button } from "@/components/atoms/Button/Button"; import { cn } from "@/lib/utils"; -import { ArrowUpIcon, StopIcon } from "@phosphor-icons/react"; +import { + ArrowUpIcon, + CircleNotchIcon, + MicrophoneIcon, + StopIcon, +} from "@phosphor-icons/react"; +import { RecordingIndicator } from "./components/RecordingIndicator"; import { useChatInput } from "./useChatInput"; +import { useVoiceRecording } from "./useVoiceRecording"; export interface Props { onSend: (message: string) => void; @@ -21,13 +28,37 @@ export function ChatInput({ className, }: Props) { const inputId = "chat-input"; - const { value, handleKeyDown, handleSubmit, handleChange, hasMultipleLines } = - useChatInput({ - onSend, - disabled: disabled || isStreaming, - maxRows: 4, - inputId, - }); + const { + value, + setValue, + handleKeyDown: baseHandleKeyDown, + handleSubmit, + handleChange, + hasMultipleLines, + } = useChatInput({ + onSend, + disabled: disabled || isStreaming, + maxRows: 4, + inputId, + }); + + const { + isRecording, + isTranscribing, + elapsedTime, + toggleRecording, + handleKeyDown, + showMicButton, + isInputDisabled, + audioStream, + } = useVoiceRecording({ + setValue, + disabled: disabled || isStreaming, + isStreaming, + value, + baseHandleKeyDown, + inputId, + }); return (
@@ -35,59 +66,110 @@ export function ChatInput({
+ {!value && !isRecording && ( + + )}