mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-30 09:28:19 -05:00)

Compare commits: feat/text- ... feat/sub-a
45 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 1ad8fde75d |  |
|  | aef705007b |  |
|  | be7e1ad9b6 |  |
|  | ce050abff9 |  |
|  | 79eb2889ab |  |
|  | 5bc5e02dcb |  |
|  | f83366d08d |  |
|  | 350ad3591b |  |
|  | de0ec3d388 |  |
|  | 7cb1e588b0 |  |
|  | 16ae8ddbe0 |  |
|  | 4b04ae2147 |  |
|  | de71d6134a |  |
|  | e6eb8a3f57 |  |
|  | 582c6cad36 |  |
|  | 0d1d275e8d |  |
|  | dc92a7b520 |  |
|  | d4047b5439 |  |
|  | f00678fd1c |  |
|  | aa175e0f4e |  |
|  | 9a8838c69a |  |
|  | 41beae1122 |  |
|  | e810f7b0d7 |  |
|  | 9c3822fffe |  |
|  | c039a2e3ad |  |
|  | 3b822cdaf7 |  |
|  | a3fe1ede55 |  |
|  | 552d069a9d |  |
|  | b2eb4831bd |  |
|  | 4cd5da678d |  |
|  | b94c83aacc |  |
|  | 7668c17d9c |  |
|  | e0dfae5732 |  |
|  | 7df867d645 |  |
|  | d855f79874 |  |
|  | dac99694fe |  |
|  | 0953983944 |  |
|  | 0058cd3ba6 |  |
|  | ea035224bc |  |
|  | 62813a1ea6 |  |
|  | 67405f7eb9 |  |
|  | 171ff6e776 |  |
|  | 349b1f9c79 |  |
|  | 277b0537e9 |  |
|  | 071b3bb5cd |  |
@@ -29,8 +29,7 @@
   "postCreateCmd": [
     "cd autogpt_platform/autogpt_libs && poetry install",
     "cd autogpt_platform/backend && poetry install && poetry run prisma generate",
-    "cd autogpt_platform/frontend && pnpm install",
-    "cd docs && pip install -r requirements.txt"
+    "cd autogpt_platform/frontend && pnpm install"
   ],
   "terminalCommand": "code .",
   "deleteBranchWithWorktree": false
6  .github/copilot-instructions.md  vendored
@@ -160,7 +160,7 @@ pnpm storybook # Start component development server
 
 **Backend Entry Points:**
 
-- `backend/backend/server/server.py` - FastAPI application setup
+- `backend/backend/api/rest_api.py` - FastAPI application setup
 - `backend/backend/data/` - Database models and user management
 - `backend/blocks/` - Agent execution blocks and logic
 
@@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s
 
 ### API Development
 
-1. Update routes in `/backend/backend/server/routers/`
+1. Update routes in `/backend/backend/api/features/`
 2. Add/update Pydantic models in same directory
 3. Write tests alongside route files
 4. For `data/*.py` changes, validate user ID checks
 
@@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s
 
 ### Security Guidelines
 
-**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
+**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`):
 
 - Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
 - Uses allow list approach for cacheable paths (static assets, health checks, public pages)
1  .gitignore  vendored
@@ -178,4 +178,5 @@ autogpt_platform/backend/settings.py
 *.ign.*
 .test-contents
 .claude/settings.local.json
+CLAUDE.local.md
 /autogpt_platform/backend/logs
24  AGENTS.md
@@ -16,7 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
 - Format Python code with `poetry run format`.
 - Format frontend code using `pnpm format`.
 
-
 ## Frontend guidelines:
 
 See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
@@ -33,14 +32,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
 4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
 5. **Testing**: Add Storybook stories for new components, Playwright for E2E
 6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
 
 - Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
 - Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
 - Colocate state when possible and avoid creating large components, use sub-components (local `/components` folder next to the parent component) when sensible
 - Avoid large hooks, abstract logic into `helpers.ts` files when sensible
 - Use function declarations for components, arrow functions only for callbacks
 - No barrel files or `index.ts` re-exports
-- Do not use `useCallback` or `useMemo` unless strictly needed
 - Avoid comments at all times unless the code is very complex
+- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
+- Do not type hook returns, let TypeScript infer as much as possible
+- Never type with `any`; if no types are available, use `unknown`
 
 ## Testing
@@ -49,22 +51,8 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
 
 Always run the relevant linters and tests before committing.
 Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
-Types:
-
-- feat
-- fix
-- refactor
-- ci
-- dx (developer experience)
-
-Scopes:
-
-- platform
-- platform/library
-- platform/marketplace
-- backend
-- backend/executor
-- frontend
-- frontend/library
-- frontend/marketplace
-- blocks
+Types: - feat - fix - refactor - ci - dx (developer experience)
+Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
 
 ## Pull requests
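An illustrative commit message combining one of these types with one of these scopes (the summary text here is invented for the example):

```
feat(backend/executor): add retry for failed node executions
```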
@@ -6,152 +6,30 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 
 AutoGPT Platform is a monorepo containing:
 
-- **Backend** (`/backend`): Python FastAPI server with async support
-- **Frontend** (`/frontend`): Next.js React application
-- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
+- **Backend** (`backend`): Python FastAPI server with async support
+- **Frontend** (`frontend`): Next.js React application
+- **Shared Libraries** (`autogpt_libs`): Common Python utilities
 
-## Essential Commands
+## Component Documentation
 
-### Backend Development
+- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
+- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
 
-```bash
-# Install dependencies
-cd backend && poetry install
-
-# Run database migrations
-poetry run prisma migrate dev
-
-# Start all services (database, redis, rabbitmq, clamav)
-docker compose up -d
-
-# Run the backend server
-poetry run serve
-
-# Run tests
-poetry run test
-
-# Run specific test
-poetry run pytest path/to/test_file.py::test_function_name
-
-# Run block tests (tests that validate all blocks work correctly)
-poetry run pytest backend/blocks/test/test_block.py -xvs
-
-# Run tests for a specific block (e.g., GetCurrentTimeBlock)
-poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
-
-# Lint and format
-# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
-poetry run format  # Black + isort
-poetry run lint  # ruff
-```
-
-More details can be found in TESTING.md
-
-#### Creating/Updating Snapshots
-
-When you first write a test or when the expected output changes:
-
-```bash
-poetry run pytest path/to/test.py --snapshot-update
-```
-
-⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
-
-### Frontend Development
-
-```bash
-# Install dependencies
-cd frontend && pnpm i
-
-# Generate API client from OpenAPI spec
-pnpm generate:api
-
-# Start development server
-pnpm dev
-
-# Run E2E tests
-pnpm test
-
-# Run Storybook for component development
-pnpm storybook
-
-# Build production
-pnpm build
-
-# Format and lint
-pnpm format
-
-# Type checking
-pnpm types
-```
-
-**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
-
-**Key Frontend Conventions:**
-
-- Separate render logic from data/behavior in components
-- Use generated API hooks from `@/app/api/__generated__/endpoints/`
-- Use function declarations (not arrow functions) for components/handlers
-- Use design system components from `src/components/` (atoms, molecules, organisms)
-- Only use Phosphor Icons
-- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
-
-## Architecture Overview
-
-### Backend Architecture
-
-- **API Layer**: FastAPI with REST and WebSocket endpoints
-- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
-- **Queue System**: RabbitMQ for async task processing
-- **Execution Engine**: Separate executor service processes agent workflows
-- **Authentication**: JWT-based with Supabase integration
-- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
-
-### Frontend Architecture
-
-- **Framework**: Next.js 15 App Router (client-first approach)
-- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
-- **State Management**: React Query for server state, co-located UI state in components/hooks
-- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
-- **Workflow Builder**: Visual graph editor using @xyflow/react
-- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
-- **Icons**: Phosphor Icons only
-- **Feature Flags**: LaunchDarkly integration
-- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
-- **Testing**: Playwright for E2E, Storybook for component development
-
-### Key Concepts
+## Key Concepts
 
 1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
-2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
+2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
 3. **Integrations**: OAuth and API connections stored per user
 4. **Store**: Marketplace for sharing agent templates
 5. **Virus Scanning**: ClamAV integration for file upload security
 
-### Testing Approach
-
-- Backend uses pytest with snapshot testing for API responses
-- Test files are colocated with source files (`*_test.py`)
-- Frontend uses Playwright for E2E tests
-- Component testing via Storybook
-
-### Database Schema
-
-Key models (defined in `/backend/schema.prisma`):
-
-- `User`: Authentication and profile data
-- `AgentGraph`: Workflow definitions with version control
-- `AgentGraphExecution`: Execution history and results
-- `AgentNode`: Individual nodes in a workflow
-- `StoreListing`: Marketplace listings for sharing agents
-
 ### Environment Configuration
 
 #### Configuration Files
 
-- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
-- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
-- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
+- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
+- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
+- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
 
 #### Docker Environment Loading Order
@@ -167,83 +45,12 @@ Key models (defined in `/backend/schema.prisma`):
 - Backend/Frontend services use YAML anchors for consistent configuration
 - Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
 
-### Common Development Tasks
-
-**Adding a new block:**
-
-Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
-
-- Provider configuration with `ProviderBuilder`
-- Block schema definition
-- Authentication (API keys, OAuth, webhooks)
-- Testing and validation
-- File organization
-
-Quick steps:
-
-1. Create new file in `/backend/backend/blocks/`
-2. Configure provider using `ProviderBuilder` in `_config.py`
-3. Inherit from `Block` base class
-4. Define input/output schemas using `BlockSchema`
-5. Implement async `run` method
-6. Generate unique block ID using `uuid.uuid4()`
-7. Test with `poetry run pytest backend/blocks/test/test_block.py`
-
-Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
-ex: do the inputs and outputs tie well together?
-
-If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
-
-**Modifying the API:**
-
-1. Update route in `/backend/backend/server/routers/`
-2. Add/update Pydantic models in same directory
-3. Write tests alongside the route file
-4. Run `poetry run test` to verify
-
-### Frontend guidelines:
-
-See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
-
-1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
-   - Add `usePageName.ts` hook for logic
-   - Put sub-components in local `components/` folder
-2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
-   - Use design system components from `src/components/` (atoms, molecules, organisms)
-   - Never use `src/components/__legacy__/*`
-3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
-   - Regenerate with `pnpm generate:api`
-   - Pattern: `use{Method}{Version}{OperationName}`
-4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
-5. **Testing**: Add Storybook stories for new components, Playwright for E2E
-6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
-   - Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
-   - Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
-   - Colocate state when possible and avoid creating large components, use sub-components (local `/components` folder next to the parent component) when sensible
-   - Avoid large hooks, abstract logic into `helpers.ts` files when sensible
-   - Use function declarations for components, arrow functions only for callbacks
-   - No barrel files or `index.ts` re-exports
-   - Do not use `useCallback` or `useMemo` unless strictly needed
-   - Avoid comments at all times unless the code is very complex
-
-### Security Implementation
-
-**Cache Protection Middleware:**
-
-- Located in `/backend/backend/server/middleware/security.py`
-- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
-- Uses an allow list approach - only explicitly permitted paths can be cached
-- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
-- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
-- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
-- Applied to both main API server and external API applications
-
 ### Creating Pull Requests
 
-- Create the PR aginst the `dev` branch of the repository.
-- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
-- Use conventional commit messages (see below)/
-- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
+- Create the PR against the `dev` branch of the repository.
+- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
+- Use conventional commit messages (see below)
+- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
 - Run the github pre-commit hooks to ensure code quality.
 
 ### Reviewing/Revising Pull Requests
170  autogpt_platform/backend/CLAUDE.md  (new file)
@@ -0,0 +1,170 @@
+# CLAUDE.md - Backend
+
+This file provides guidance to Claude Code when working with the backend.
+
+## Essential Commands
+
+To run something with Python package dependencies you MUST use `poetry run ...`.
+
+```bash
+# Install dependencies
+poetry install
+
+# Run database migrations
+poetry run prisma migrate dev
+
+# Start all services (database, redis, rabbitmq, clamav)
+docker compose up -d
+
+# Run the backend as a whole
+poetry run app
+
+# Run tests
+poetry run test
+
+# Run specific test
+poetry run pytest path/to/test_file.py::test_function_name
+
+# Run block tests (tests that validate all blocks work correctly)
+poetry run pytest backend/blocks/test/test_block.py -xvs
+
+# Run tests for a specific block (e.g., GetCurrentTimeBlock)
+poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
+
+# Lint and format
+# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
+poetry run format  # Black + isort
+poetry run lint  # ruff
+```
+
+More details can be found in @TESTING.md
+
+### Creating/Updating Snapshots
+
+When you first write a test or when the expected output changes:
+
+```bash
+poetry run pytest path/to/test.py --snapshot-update
+```
+
+⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
+
+## Architecture
+
+- **API Layer**: FastAPI with REST and WebSocket endpoints
+- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
+- **Queue System**: RabbitMQ for async task processing
+- **Execution Engine**: Separate executor service processes agent workflows
+- **Authentication**: JWT-based with Supabase integration
+- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
+
+## Testing Approach
+
+- Uses pytest with snapshot testing for API responses
+- Test files are colocated with source files (`*_test.py`)
+
+## Database Schema
+
+Key models (defined in `schema.prisma`):
+
+- `User`: Authentication and profile data
+- `AgentGraph`: Workflow definitions with version control
+- `AgentGraphExecution`: Execution history and results
+- `AgentNode`: Individual nodes in a workflow
+- `StoreListing`: Marketplace listings for sharing agents
+
+## Environment Configuration
+
+- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
+
+## Common Development Tasks
+
+### Adding a new block
+
+Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
+
+- Provider configuration with `ProviderBuilder`
+- Block schema definition
+- Authentication (API keys, OAuth, webhooks)
+- Testing and validation
+- File organization
+
+Quick steps:
+
+1. Create new file in `backend/blocks/`
+2. Configure provider using `ProviderBuilder` in `_config.py`
+3. Inherit from `Block` base class
+4. Define input/output schemas using `BlockSchema`
+5. Implement async `run` method
+6. Generate unique block ID using `uuid.uuid4()`
+7. Test with `poetry run pytest backend/blocks/test/test_block.py`
+
+Note: when making many new blocks, analyze the interfaces for each of these blocks and picture whether they would go well together in a graph-based editor or would struggle to connect productively,
+e.g. do the inputs and outputs tie well together?
+
+If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
+
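To make the quick steps above concrete, here is a minimal illustrative sketch of a block that follows steps 1-7. The import paths and the `SchemaField` helper are assumptions inferred from the step names, not verified against the repository; the Block SDK Guide is the authoritative reference:

```python
import uuid

# Assumed module paths; check the Block SDK Guide for the real ones.
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class GreetBlock(Block):
    class Input(BlockSchema):
        name: str = SchemaField(description="Who to greet")

    class Output(BlockSchema):
        greeting: str = SchemaField(description="The generated greeting")

    def __init__(self):
        super().__init__(
            # Step 6: generate a unique ID once with uuid.uuid4(), then hardcode it
            id=str(uuid.uuid4()),
            input_schema=GreetBlock.Input,
            output_schema=GreetBlock.Output,
        )

    # Step 5: the async run method yields (output_name, value) pairs
    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "greeting", f"Hello, {input_data.name}!"
```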
+#### Handling files in blocks with `store_media_file()`
+
+When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
+
+| Format | Use When | Returns |
+|--------|----------|---------|
+| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
+| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
+| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
+
+**Examples:**
+
+```python
+# INPUT: Need to process file locally with ffmpeg
+local_path = await store_media_file(
+    file=input_data.video,
+    execution_context=execution_context,
+    return_format="for_local_processing",
+)
+# local_path = "video.mp4" - use with Path/ffmpeg/etc
+
+# INPUT: Need to send to external API like Replicate
+image_b64 = await store_media_file(
+    file=input_data.image,
+    execution_context=execution_context,
+    return_format="for_external_api",
+)
+# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
+
+# OUTPUT: Returning result from block
+result_url = await store_media_file(
+    file=generated_image_url,
+    execution_context=execution_context,
+    return_format="for_block_output",
+)
+yield "image_url", result_url
+# In CoPilot: result_url = "workspace://abc123"
+# In graphs: result_url = "data:image/png;base64,..."
+```
+
+**Key points:**
+
+- `for_block_output` is the ONLY format that auto-adapts to execution context
+- Always use `for_block_output` for block outputs unless you have a specific reason not to
+- Never hardcode workspace checks - let `for_block_output` handle it
+
+### Modifying the API
+
+1. Update route in `backend/api/features/`
+2. Add/update Pydantic models in same directory
+3. Write tests alongside the route file
+4. Run `poetry run test` to verify
+
+## Security Implementation
+
+### Cache Protection Middleware
+
+- Located in `backend/api/middleware/security.py`
+- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
+- Uses an allow list approach - only explicitly permitted paths can be cached
+- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
+- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
+- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
+- Applied to both main API server and external API applications
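The bullets above describe the behavior but not the shape. A minimal sketch of an allow-list cache-protection middleware in Starlette/FastAPI terms (an assumed shape for illustration, not the repository's actual `security.py`):

```python
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

# Illustrative allow list; the real CACHEABLE_PATHS lives in the middleware module.
CACHEABLE_PATHS = ("/static/", "/_next/static/", "/health")


class CacheProtectionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        # Default-deny: any path not explicitly allow-listed must not be cached.
        if not request.url.path.startswith(CACHEABLE_PATHS):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response
```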
@@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as
 
 #### Using Global Auth Fixtures
 
-Two global auth fixtures are provided by `backend/server/conftest.py`:
+Two global auth fixtures are provided by `backend/api/conftest.py`:
 
 - `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
 - `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
@@ -17,7 +17,7 @@ router = fastapi.APIRouter(
 )
 
 
-# Taken from backend/server/v2/store/db.py
+# Taken from backend/api/features/store/db.py
 def sanitize_query(query: str | None) -> str | None:
     if query is None:
         return query
@@ -33,9 +33,15 @@ class ChatConfig(BaseSettings):
 
     stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
     max_retries: int = Field(default=3, description="Maximum number of retries")
-    max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
+    max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
     max_agent_schedules: int = Field(
-        default=3, description="Maximum number of agent schedules"
+        default=30, description="Maximum number of agent schedules"
+    )
+
+    # Long-running operation configuration
+    long_running_operation_ttl: int = Field(
+        default=600,
+        description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
     )
 
     # Langfuse Prompt Management Configuration
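Because `ChatConfig` extends Pydantic's `BaseSettings`, these defaults are deployment-tunable via environment variables. A self-contained sketch using the field names from the diff (assuming no custom `env_prefix`; the real class may configure one):

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class ChatConfig(BaseSettings):
    # Defaults raised from 3 to 30 in this change
    max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
    # New Redis safety-net TTL for long-running operation tracking
    long_running_operation_ttl: int = Field(default=600)


os.environ["MAX_AGENT_RUNS"] = "50"  # environment overrides the coded default
print(ChatConfig().max_agent_runs)  # -> 50
```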
@@ -247,3 +247,45 @@ async def get_chat_session_message_count(session_id: str) -> int:
     """Get the number of messages in a chat session."""
     count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
     return count
+
+
+async def update_tool_message_content(
+    session_id: str,
+    tool_call_id: str,
+    new_content: str,
+) -> bool:
+    """Update the content of a tool message in chat history.
+
+    Used by background tasks to update pending operation messages with final results.
+
+    Args:
+        session_id: The chat session ID.
+        tool_call_id: The tool call ID to find the message.
+        new_content: The new content to set.
+
+    Returns:
+        True if a message was updated, False otherwise.
+    """
+    try:
+        result = await PrismaChatMessage.prisma().update_many(
+            where={
+                "sessionId": session_id,
+                "toolCallId": tool_call_id,
+            },
+            data={
+                "content": new_content,
+            },
+        )
+        if result == 0:
+            logger.warning(
+                f"No message found to update for session {session_id}, "
+                f"tool_call_id {tool_call_id}"
+            )
+            return False
+        return True
+    except Exception as e:
+        logger.error(
+            f"Failed to update tool message for session {session_id}, "
+            f"tool_call_id {tool_call_id}: {e}"
+        )
+        return False
@@ -295,6 +295,21 @@ async def cache_chat_session(session: ChatSession) -> None:
     await _cache_session(session)
 
 
+async def invalidate_session_cache(session_id: str) -> None:
+    """Invalidate a chat session from Redis cache.
+
+    Used by background tasks to ensure fresh data is loaded on next access.
+    This is best-effort - Redis failures are logged but don't fail the operation.
+    """
+    try:
+        redis_key = _get_session_cache_key(session_id)
+        async_redis = await get_redis_async()
+        await async_redis.delete(redis_key)
+    except Exception as e:
+        # Best-effort: log but don't fail - cache will expire naturally
+        logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
+
+
 async def _get_session_from_db(session_id: str) -> ChatSession | None:
     """Get a chat session from the database."""
     prisma_session = await chat_db.get_chat_session(session_id)
@@ -17,6 +17,7 @@ from openai import (
 )
 from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
 
+from backend.data.redis_client import get_redis_async
 from backend.data.understanding import (
     format_understanding_for_prompt,
     get_business_understanding,
@@ -24,6 +25,7 @@ from backend.data.understanding import (
 from backend.util.exceptions import NotFoundError
 from backend.util.settings import Settings
 
+from . import db as chat_db
 from .config import ChatConfig
 from .model import (
     ChatMessage,
@@ -31,6 +33,7 @@ from .model import (
     Usage,
     cache_chat_session,
     get_chat_session,
+    invalidate_session_cache,
     update_session_title,
     upsert_chat_session,
 )
@@ -48,8 +51,13 @@ from .response_model import (
     StreamToolOutputAvailable,
     StreamUsage,
 )
-from .tools import execute_tool, tools
-from .tools.models import ErrorResponse
+from .tools import execute_tool, get_tool, tools
+from .tools.models import (
+    ErrorResponse,
+    OperationInProgressResponse,
+    OperationPendingResponse,
+    OperationStartedResponse,
+)
 from .tracking import track_user_message
 
 logger = logging.getLogger(__name__)
@@ -61,11 +69,126 @@ client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
 
 langfuse = get_client()
 
+# Redis key prefix for tracking running long-running operations
+# Used for idempotency across Kubernetes pods - prevents duplicate executions on browser refresh
+RUNNING_OPERATION_PREFIX = "chat:running_operation:"
 
-class LangfuseNotConfiguredError(Exception):
-    """Raised when Langfuse is required but not configured."""
-
-    pass
+# Default system prompt used when Langfuse is not configured
+# This is a snapshot of the "CoPilot Prompt" from Langfuse (version 11)
+DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
+
+Here is everything you know about the current user from previous interactions:
+
+<users_information>
+{users_information}
+</users_information>
+
+## YOUR CORE MANDATE
+
+You are action-oriented. Your success is measured by:
+- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
+- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
+- **Time Saved**: Focus on tangible efficiency gains
+- **Quality Output**: Deliver results that meet or exceed expectations
+
+## YOUR WORKFLOW
+
+Adapt flexibly to the conversation context. Not every interaction requires all stages:
+
+1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
+
+2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
+
+3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
+
+4. **Discover or Create Agents**:
+   - **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
+   - Search the marketplace with `find_agent` for pre-built automations
+   - Find reusable components with `find_block`
+   - Create custom solutions with `create_agent` if nothing suitable exists
+   - Modify existing library agents with `edit_agent`
+
+5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
+
+6. **Show Results**: Display outputs using `agent_output`.
+
+## AVAILABLE TOOLS
+
+**Understanding & Discovery:**
+- `add_understanding`: Create a memory about the user's business or use cases for future sessions
+- `search_docs`: Search platform documentation for specific technical information
+- `get_doc_page`: Retrieve full text of a specific documentation page
+
+**Agent Discovery:**
+- `find_library_agent`: Search the user's existing agents (CHECK HERE FIRST—these may be customized)
+- `find_agent`: Search the marketplace for pre-built automations
+- `find_block`: Find pre-written code units that perform specific tasks (agents are built from blocks)
+
+**Agent Creation & Editing:**
+- `create_agent`: Create a new automation agent
+- `edit_agent`: Modify an agent in the user's library
+
+**Execution & Output:**
+- `run_agent`: Run an agent now, schedule it, or set up a webhook trigger
+- `run_block`: Test or run a specific block independently
+- `agent_output`: View results from previous agent runs
+
+## BEHAVIORAL GUIDELINES
+
+**Be Concise:**
+- Target 2-5 short lines maximum
+- Make every word count—no repetition or filler
+- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
+- Avoid jargon (blocks, slugs, cron) unless the user asks
+
+**Be Proactive:**
+- Suggest next steps before being asked
+- Anticipate needs based on conversation context and user information
+- Look for opportunities to expand scope when relevant
+- Reveal capabilities through action, not explanation
+
+**Use Tools Effectively:**
+- Select the right tool for each task
+- **Always check `find_library_agent` before searching the marketplace**
+- Use `add_understanding` to capture valuable business context
+- When tool calls fail, try alternative approaches
+
+## CRITICAL REMINDER
+
+You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
+
+# Module-level set to hold strong references to background tasks.
+# This prevents asyncio from garbage collecting tasks before they complete.
+# Tasks are automatically removed on completion via done_callback.
+_background_tasks: set[asyncio.Task] = set()
+
+
+async def _mark_operation_started(tool_call_id: str) -> bool:
+    """Mark a long-running operation as started (Redis-based).
+
+    Returns True if successfully marked (operation was not already running),
+    False if operation was already running (lost race condition).
+    Raises exception if Redis is unavailable (fail-closed).
+    """
+    redis = await get_redis_async()
+    key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
+    # SETNX with TTL - atomic "set if not exists"
+    result = await redis.set(key, "1", ex=config.long_running_operation_ttl, nx=True)
+    return result is not None
+
+
+async def _mark_operation_completed(tool_call_id: str) -> None:
+    """Mark a long-running operation as completed (remove Redis key).
+
+    This is best-effort - if Redis fails, the TTL will eventually clean up.
+    """
+    try:
+        redis = await get_redis_async()
+        key = f"{RUNNING_OPERATION_PREFIX}{tool_call_id}"
+        await redis.delete(key)
+    except Exception as e:
+        # Non-critical: TTL will clean up eventually
+        logger.warning(f"Failed to delete running operation key {tool_call_id}: {e}")
+
 
 def _is_langfuse_configured() -> bool:
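Read together with `update_tool_message_content` and `invalidate_session_cache` from the earlier hunks, these helpers form an at-most-once guard for long-running tools. A hedged sketch of the intended lifecycle (`do_expensive_work` is a hypothetical stand-in for the real tool body; the actual call site may differ):

```python
import asyncio


async def _run_long_operation(session_id: str, tool_call_id: str) -> None:
    # SETNX guard: only one pod may start this operation, so a browser
    # refresh cannot trigger a duplicate execution.
    if not await _mark_operation_started(tool_call_id):
        return  # already running elsewhere; the key's TTL is the safety net
    try:
        result = await do_expensive_work()  # hypothetical tool body
        # Replace the pending tool message with the final result...
        await update_tool_message_content(session_id, tool_call_id, result)
        # ...and drop the cached session so the next read is fresh.
        await invalidate_session_cache(session_id)
    finally:
        await _mark_operation_completed(tool_call_id)


# Keep a strong reference so asyncio cannot garbage-collect the task mid-flight.
task = asyncio.create_task(_run_long_operation("session-1", "call-abc"))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
```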
@@ -75,6 +198,30 @@ def _is_langfuse_configured() -> bool:
     )
 
 
+async def _get_system_prompt_template(context: str) -> str:
+    """Get the system prompt, trying Langfuse first with fallback to default.
+
+    Args:
+        context: The user context/information to compile into the prompt.
+
+    Returns:
+        The compiled system prompt string.
+    """
+    if _is_langfuse_configured():
+        try:
+            # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
+            # Use asyncio.to_thread to avoid blocking the event loop
+            prompt = await asyncio.to_thread(
+                langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
+            )
+            return prompt.compile(users_information=context)
+        except Exception as e:
+            logger.warning(f"Failed to fetch prompt from Langfuse, using default: {e}")
+
+    # Fallback to default prompt
+    return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
+
+
 async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
     """Build the full system prompt including business understanding if available.
 
@@ -83,12 +230,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
     If "default" and this is the user's first session, will use "onboarding" instead.
 
     Returns:
-        Tuple of (compiled prompt string, Langfuse prompt object for tracing)
+        Tuple of (compiled prompt string, business understanding object)
     """
-
-    # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
-    prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
-
     # If user is authenticated, try to fetch their business understanding
     understanding = None
     if user_id:
@@ -97,12 +240,13 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
         except Exception as e:
             logger.warning(f"Failed to fetch business understanding: {e}")
             understanding = None
 
     if understanding:
         context = format_understanding_for_prompt(understanding)
     else:
         context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
 
-    compiled = prompt.compile(users_information=context)
+    compiled = await _get_system_prompt_template(context)
     return compiled, understanding
 
@@ -210,16 +354,6 @@ async def stream_chat_completion(
         f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
     )
 
-    # Check if Langfuse is configured - required for chat functionality
-    if not _is_langfuse_configured():
-        logger.error("Chat request failed: Langfuse is not configured")
-        yield StreamError(
-            errorText="Chat service is not available. Langfuse must be configured "
-            "with LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
-        )
-        yield StreamFinish()
-        return
-
     # Only fetch from Redis if session not provided (initial call)
     if session is None:
         session = await get_chat_session(session_id, user_id)
@@ -315,6 +449,7 @@ async def stream_chat_completion(
     has_yielded_end = False
     has_yielded_error = False
     has_done_tool_call = False
+    has_long_running_tool_call = False  # Track if we had a long-running tool call
     has_received_text = False
     text_streaming_ended = False
     tool_response_messages: list[ChatMessage] = []
@@ -336,7 +471,6 @@ async def stream_chat_completion(
         system_prompt=system_prompt,
         text_block_id=text_block_id,
     ):
-
         if isinstance(chunk, StreamTextStart):
             # Emit text-start before first text delta
             if not has_received_text:
@@ -394,13 +528,34 @@ async def stream_chat_completion(
                     if isinstance(chunk.output, str)
                     else orjson.dumps(chunk.output).decode("utf-8")
                 )
-                tool_response_messages.append(
-                    ChatMessage(
-                        role="tool",
-                        content=result_content,
-                        tool_call_id=chunk.toolCallId,
-                    )
-                )
+                # Skip saving long-running operation responses - messages already saved in _yield_tool_call
+                # Use JSON parsing instead of substring matching to avoid false positives
+                is_long_running_response = False
+                try:
+                    parsed = orjson.loads(result_content)
+                    if isinstance(parsed, dict) and parsed.get("type") in (
+                        "operation_started",
+                        "operation_in_progress",
+                    ):
+                        is_long_running_response = True
+                except (orjson.JSONDecodeError, TypeError):
+                    pass  # Not JSON or not a dict - treat as regular response
+                if is_long_running_response:
+                    # Remove from accumulated_tool_calls since assistant message was already saved
+                    accumulated_tool_calls[:] = [
+                        tc
+                        for tc in accumulated_tool_calls
+                        if tc["id"] != chunk.toolCallId
+                    ]
+                    has_long_running_tool_call = True
+                else:
+                    tool_response_messages.append(
+                        ChatMessage(
+                            role="tool",
+                            content=result_content,
+                            tool_call_id=chunk.toolCallId,
+                        )
+                    )
                 has_done_tool_call = True
                 # Track if any tool execution failed
                 if not chunk.success:
@@ -576,7 +731,14 @@ async def stream_chat_completion(
         logger.info(
             f"Extended session messages, new message_count={len(session.messages)}"
         )
-    if messages_to_save or has_appended_streaming_message:
+    # Save if there are regular (non-long-running) tool responses or streaming message.
+    # Long-running tools save their own state, but we still need to save regular tools
+    # that may be in the same response.
+    has_regular_tool_responses = len(tool_response_messages) > 0
+    if has_regular_tool_responses or (
+        not has_long_running_tool_call
+        and (messages_to_save or has_appended_streaming_message)
+    ):
         await upsert_chat_session(session)
     else:
         logger.info(
@@ -585,7 +747,9 @@ async def stream_chat_completion(
         )
 
     # If we did a tool call, stream the chat completion again to get the next response
-    if has_done_tool_call:
+    # Skip only if ALL tools were long-running (they handle their own completion)
+    has_regular_tools = len(tool_response_messages) > 0
+    if has_done_tool_call and (has_regular_tools or not has_long_running_tool_call):
         logger.info(
             "Tool call executed, streaming chat completion again to get assistant response"
         )
@@ -725,6 +889,114 @@ async def _summarize_messages(
|
|||||||
return summary or "No summary available."
|
return summary or "No summary available."
|
||||||
|
|
||||||
+def _ensure_tool_pairs_intact(
+    recent_messages: list[dict],
+    all_messages: list[dict],
+    start_index: int,
+) -> list[dict]:
+    """
+    Ensure tool_call/tool_response pairs stay together after slicing.
+
+    When slicing messages for context compaction, a naive slice can separate
+    an assistant message containing tool_calls from its corresponding tool
+    response messages. This causes API validation errors (e.g., Anthropic's
+    "unexpected tool_use_id found in tool_result blocks").
+
+    This function checks for orphan tool responses in the slice and extends
+    backwards to include their corresponding assistant messages.
+
+    Args:
+        recent_messages: The sliced messages to validate
+        all_messages: The complete message list (for looking up missing assistants)
+        start_index: The index in all_messages where recent_messages begins
+
+    Returns:
+        A potentially extended list of messages with tool pairs intact
+    """
+    if not recent_messages:
+        return recent_messages
+
+    # Collect all tool_call_ids from assistant messages in the slice
+    available_tool_call_ids: set[str] = set()
+    for msg in recent_messages:
+        if msg.get("role") == "assistant" and msg.get("tool_calls"):
+            for tc in msg["tool_calls"]:
+                tc_id = tc.get("id")
+                if tc_id:
+                    available_tool_call_ids.add(tc_id)
+
+    # Find orphan tool responses (tool messages whose tool_call_id is missing)
+    orphan_tool_call_ids: set[str] = set()
+    for msg in recent_messages:
+        if msg.get("role") == "tool":
+            tc_id = msg.get("tool_call_id")
+            if tc_id and tc_id not in available_tool_call_ids:
+                orphan_tool_call_ids.add(tc_id)
+
+    if not orphan_tool_call_ids:
+        # No orphans, slice is valid
+        return recent_messages
+
+    # Find the assistant messages that contain the orphan tool_call_ids
+    # Search backwards from start_index in all_messages
+    messages_to_prepend: list[dict] = []
+    for i in range(start_index - 1, -1, -1):
+        msg = all_messages[i]
+        if msg.get("role") == "assistant" and msg.get("tool_calls"):
+            msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
+            if msg_tool_ids & orphan_tool_call_ids:
+                # This assistant message has tool_calls we need
+                # Also collect its contiguous tool responses that follow it
+                assistant_and_responses: list[dict] = [msg]
+
+                # Scan forward from this assistant to collect tool responses
+                for j in range(i + 1, start_index):
+                    following_msg = all_messages[j]
+                    if following_msg.get("role") == "tool":
+                        tool_id = following_msg.get("tool_call_id")
+                        if tool_id and tool_id in msg_tool_ids:
+                            assistant_and_responses.append(following_msg)
+                    else:
+                        # Stop at first non-tool message
+                        break
+
+                # Prepend the assistant and its tool responses (maintain order)
+                messages_to_prepend = assistant_and_responses + messages_to_prepend
+                # Mark these as found
+                orphan_tool_call_ids -= msg_tool_ids
+                # Also add this assistant's tool_call_ids to available set
+                available_tool_call_ids |= msg_tool_ids
+
+                if not orphan_tool_call_ids:
+                    # Found all missing assistants
+                    break
+
+    if orphan_tool_call_ids:
+        # Some tool_call_ids couldn't be resolved - remove those tool responses
+        # This shouldn't happen in normal operation but handles edge cases
+        logger.warning(
+            f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
+            "Removing orphan tool responses."
+        )
+        recent_messages = [
+            msg
+            for msg in recent_messages
+            if not (
+                msg.get("role") == "tool"
+                and msg.get("tool_call_id") in orphan_tool_call_ids
+            )
+        ]
+
+    if messages_to_prepend:
+        logger.info(
+            f"Extended recent messages by {len(messages_to_prepend)} to preserve "
+            f"tool_call/tool_response pairs"
+        )
+        return messages_to_prepend + recent_messages
+
+    return recent_messages
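A minimal usage sketch of the repair above, using toy messages (the message shapes are illustrative, not real chat data):

```python
# Toy history: assistant issues a tool call, tool responds, user replies.
all_messages = [
    {"role": "assistant", "tool_calls": [{"id": "call_1", "function": {}}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "ok"},
    {"role": "user", "content": "thanks"},
]
recent = all_messages[-2:]  # naive slice orphans the tool response
fixed = _ensure_tool_pairs_intact(recent, all_messages, start_index=1)
assert fixed[0]["role"] == "assistant"  # the paired assistant was pulled back in
```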
 async def _stream_chat_chunks(
     session: ChatSession,
     tools: list[ChatCompletionToolParam],
@@ -816,7 +1088,15 @@ async def _stream_chat_chunks(
     # Always attempt mitigation when over limit, even with few messages
     if messages:
         # Split messages based on whether system prompt exists
-        recent_messages = messages[-KEEP_RECENT:]
+        # Calculate start index for the slice
+        slice_start = max(0, len(messages_dict) - KEEP_RECENT)
+        recent_messages = messages_dict[-KEEP_RECENT:]
+
+        # Ensure tool_call/tool_response pairs stay together
+        # This prevents API errors from orphan tool responses
+        recent_messages = _ensure_tool_pairs_intact(
+            recent_messages, messages_dict, slice_start
+        )
+
     if has_system_prompt:
         # Keep system prompt separate, summarize everything between system and recent
@@ -903,6 +1183,13 @@ async def _stream_chat_chunks(
             if len(recent_messages) >= keep_count
             else recent_messages
         )
+        # Ensure tool pairs stay intact in the reduced slice
+        reduced_slice_start = max(
+            0, len(recent_messages) - keep_count
+        )
+        reduced_recent = _ensure_tool_pairs_intact(
+            reduced_recent, recent_messages, reduced_slice_start
+        )
         if has_system_prompt:
             messages = [
                 system_msg,
@@ -961,7 +1248,10 @@ async def _stream_chat_chunks(
 
         # Create a base list excluding system prompt to avoid duplication
         # This is the pool of messages we'll slice from in the loop
-        base_msgs = messages[1:] if has_system_prompt else messages
+        # Use messages_dict for type consistency with _ensure_tool_pairs_intact
+        base_msgs = (
+            messages_dict[1:] if has_system_prompt else messages_dict
+        )
 
         # Try progressively smaller keep counts
         new_token_count = token_count  # Initialize with current count
@@ -984,6 +1274,12 @@ async def _stream_chat_chunks(
             # Slice from base_msgs to get recent messages (without system prompt)
             recent_messages = base_msgs[-keep_count:]
+
+            # Ensure tool pairs stay intact in the reduced slice
+            reduced_slice_start = max(0, len(base_msgs) - keep_count)
+            recent_messages = _ensure_tool_pairs_intact(
+                recent_messages, base_msgs, reduced_slice_start
+            )
+
             if has_system_prompt:
                 messages = [system_msg] + recent_messages
             else:
@@ -1260,17 +1556,19 @@ async def _yield_tool_call(
     """
     Yield a tool call and its execution result.
 
-    For long-running tools, yields heartbeat events every 15 seconds to keep
-    the SSE connection alive through proxies and load balancers.
+    For tools marked with `is_long_running=True` (like agent generation), spawns a
+    background task so the operation survives SSE disconnections. For other tools,
+    yields heartbeat events every 15 seconds to keep the SSE connection alive.
 
     Raises:
         orjson.JSONDecodeError: If tool call arguments cannot be parsed as JSON
         KeyError: If expected tool call fields are missing
         TypeError: If tool call structure is invalid
     """
+    import uuid as uuid_module
+
     tool_name = tool_calls[yield_idx]["function"]["name"]
     tool_call_id = tool_calls[yield_idx]["id"]
-    logger.info(f"Yielding tool call: {tool_calls[yield_idx]}")
 
     # Parse tool call arguments - handle empty arguments gracefully
     raw_arguments = tool_calls[yield_idx]["function"]["arguments"]
@@ -1285,7 +1583,151 @@ async def _yield_tool_call(
         input=arguments,
     )
 
-    # Run tool execution in background task with heartbeats to keep connection alive
+    # Check if this tool is long-running (survives SSE disconnection)
+    tool = get_tool(tool_name)
+    if tool and tool.is_long_running:
+        # Atomic check-and-set: returns False if operation already running (lost race)
+        if not await _mark_operation_started(tool_call_id):
+            logger.info(
+                f"Tool call {tool_call_id} already in progress, returning status"
+            )
+            # Build dynamic message based on tool name
+            if tool_name == "create_agent":
+                in_progress_msg = "Agent creation already in progress. Please wait..."
+            elif tool_name == "edit_agent":
+                in_progress_msg = "Agent edit already in progress. Please wait..."
+            else:
+                in_progress_msg = f"{tool_name} already in progress. Please wait..."
+
+            yield StreamToolOutputAvailable(
+                toolCallId=tool_call_id,
+                toolName=tool_name,
+                output=OperationInProgressResponse(
+                    message=in_progress_msg,
+                    tool_call_id=tool_call_id,
+                ).model_dump_json(),
+                success=True,
+            )
+            return
+
+        # Generate operation ID
+        operation_id = str(uuid_module.uuid4())
+
+        # Build a user-friendly message based on tool and arguments
+        if tool_name == "create_agent":
+            agent_desc = arguments.get("description", "")
+            # Truncate long descriptions for the message
+            desc_preview = (
+                (agent_desc[:100] + "...") if len(agent_desc) > 100 else agent_desc
+            )
+            pending_msg = (
+                f"Creating your agent: {desc_preview}"
+                if desc_preview
+                else "Creating agent... This may take a few minutes."
+            )
+            started_msg = (
+                "Agent creation started. You can close this tab - "
+                "check your library in a few minutes."
+            )
+        elif tool_name == "edit_agent":
+            changes = arguments.get("changes", "")
+            changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
+            pending_msg = (
+                f"Editing agent: {changes_preview}"
+                if changes_preview
+                else "Editing agent... This may take a few minutes."
+            )
+            started_msg = (
+                "Agent edit started. You can close this tab - "
+                "check your library in a few minutes."
+            )
+        else:
+            pending_msg = f"Running {tool_name}... This may take a few minutes."
+            started_msg = (
+                f"{tool_name} started. You can close this tab - "
+                "check back in a few minutes."
+            )
+
+        # Track appended messages for rollback on failure
+        assistant_message: ChatMessage | None = None
+        pending_message: ChatMessage | None = None
+
+        # Wrap session save and task creation in try-except to release lock on failure
+        try:
+            # Save assistant message with tool_call FIRST (required by LLM)
+            assistant_message = ChatMessage(
+                role="assistant",
+                content="",
+                tool_calls=[tool_calls[yield_idx]],
+            )
+            session.messages.append(assistant_message)
+
+            # Then save pending tool result
+            pending_message = ChatMessage(
+                role="tool",
+                content=OperationPendingResponse(
+                    message=pending_msg,
+                    operation_id=operation_id,
+                    tool_name=tool_name,
+                ).model_dump_json(),
+                tool_call_id=tool_call_id,
+            )
+            session.messages.append(pending_message)
+            await upsert_chat_session(session)
+            logger.info(
+                f"Saved pending operation {operation_id} for tool {tool_name} "
+                f"in session {session.session_id}"
+            )
+
+            # Store task reference in module-level set to prevent GC before completion
+            task = asyncio.create_task(
+                _execute_long_running_tool(
+                    tool_name=tool_name,
+                    parameters=arguments,
+                    tool_call_id=tool_call_id,
+                    operation_id=operation_id,
+                    session_id=session.session_id,
+                    user_id=session.user_id,
+                )
+            )
+            _background_tasks.add(task)
+            task.add_done_callback(_background_tasks.discard)
+        except Exception as e:
+            # Roll back appended messages to prevent data corruption on subsequent saves
+            if (
+                pending_message
+                and session.messages
+                and session.messages[-1] == pending_message
+            ):
+                session.messages.pop()
+            if (
+                assistant_message
+                and session.messages
+                and session.messages[-1] == assistant_message
+            ):
+                session.messages.pop()
+
+            # Release the Redis lock since the background task won't be spawned
+            await _mark_operation_completed(tool_call_id)
+            logger.error(
+                f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
+            )
+            raise
+
+        # Return immediately - don't wait for completion
+        yield StreamToolOutputAvailable(
+            toolCallId=tool_call_id,
+            toolName=tool_name,
+            output=OperationStartedResponse(
+                message=started_msg,
+                operation_id=operation_id,
+                tool_name=tool_name,
+            ).model_dump_json(),
+            success=True,
+        )
+        return
+
+    # Normal flow: Run tool execution in background task with heartbeats
     tool_task = asyncio.create_task(
         execute_tool(
             tool_name=tool_name,
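`_mark_operation_started` / `_mark_operation_completed` are used above but defined outside this hunk. A minimal sketch of the atomic check-and-set they imply, assuming redis-py's asyncio client and a hypothetical key scheme:

```python
# Hypothetical sketch - the real helpers live elsewhere in this module.
from redis.asyncio import Redis

redis = Redis()
OPERATION_TTL_SECONDS = 15 * 60  # assumed safety TTL so a crashed task releases the lock


async def _mark_operation_started(tool_call_id: str) -> bool:
    # SET key value NX EX is atomic: only the first caller gets True,
    # so a concurrent duplicate tool call loses the race and gets False.
    return bool(
        await redis.set(
            f"chat:operation:{tool_call_id}", "1", nx=True, ex=OPERATION_TTL_SECONDS
        )
    )


async def _mark_operation_completed(tool_call_id: str) -> None:
    # Delete the lock key so the tool call can run again later.
    await redis.delete(f"chat:operation:{tool_call_id}")
```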
@@ -1335,3 +1777,195 @@ async def _yield_tool_call(
     )
 
     yield tool_execution_response
+
+
+async def _execute_long_running_tool(
+    tool_name: str,
+    parameters: dict[str, Any],
+    tool_call_id: str,
+    operation_id: str,
+    session_id: str,
+    user_id: str | None,
+) -> None:
+    """Execute a long-running tool in background and update chat history with result.
+
+    This function runs independently of the SSE connection, so the operation
+    survives if the user closes their browser tab.
+    """
+    try:
+        # Load fresh session (not stale reference)
+        session = await get_chat_session(session_id, user_id)
+        if not session:
+            logger.error(f"Session {session_id} not found for background tool")
+            return
+
+        # Execute the actual tool
+        result = await execute_tool(
+            tool_name=tool_name,
+            parameters=parameters,
+            tool_call_id=tool_call_id,
+            user_id=user_id,
+            session=session,
+        )
+
+        # Update the pending message with result
+        await _update_pending_operation(
+            session_id=session_id,
+            tool_call_id=tool_call_id,
+            result=(
+                result.output
+                if isinstance(result.output, str)
+                else orjson.dumps(result.output).decode("utf-8")
+            ),
+        )
+
+        logger.info(f"Background tool {tool_name} completed for session {session_id}")
+
+        # Generate LLM continuation so user sees response when they poll/refresh
+        await _generate_llm_continuation(session_id=session_id, user_id=user_id)
+
+    except Exception as e:
+        logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
+        error_response = ErrorResponse(
+            message=f"Tool {tool_name} failed: {str(e)}",
+        )
+        await _update_pending_operation(
+            session_id=session_id,
+            tool_call_id=tool_call_id,
+            result=error_response.model_dump_json(),
+        )
+        # Generate LLM continuation so user sees explanation even for errors
+        try:
+            await _generate_llm_continuation(session_id=session_id, user_id=user_id)
+        except Exception as llm_err:
+            logger.warning(f"Failed to generate LLM continuation for error: {llm_err}")
+    finally:
+        await _mark_operation_completed(tool_call_id)
+
+
+async def _update_pending_operation(
+    session_id: str,
+    tool_call_id: str,
+    result: str,
+) -> None:
+    """Update the pending tool message with final result.
+
+    This is called by background tasks when long-running operations complete.
+    """
+    # Update the message in database
+    updated = await chat_db.update_tool_message_content(
+        session_id=session_id,
+        tool_call_id=tool_call_id,
+        new_content=result,
+    )
+
+    if updated:
+        # Invalidate Redis cache so next load gets fresh data
+        # Wrap in try/except to prevent cache failures from triggering error handling
+        # that would overwrite our successful DB update
+        try:
+            await invalidate_session_cache(session_id)
+        except Exception as e:
+            # Non-critical: cache will eventually be refreshed on next load
+            logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
+        logger.info(
+            f"Updated pending operation for tool_call_id {tool_call_id} "
+            f"in session {session_id}"
+        )
+    else:
+        logger.warning(
+            f"Failed to update pending operation for tool_call_id {tool_call_id} "
+            f"in session {session_id}"
+        )
+
+
+async def _generate_llm_continuation(
+    session_id: str,
+    user_id: str | None,
+) -> None:
+    """Generate an LLM response after a long-running tool completes.
+
+    This is called by background tasks to continue the conversation
+    after a tool result is saved. The response is saved to the database
+    so users see it when they refresh or poll.
+    """
+    try:
+        # Load fresh session from DB (bypass cache to get the updated tool result)
+        await invalidate_session_cache(session_id)
+        session = await get_chat_session(session_id, user_id)
+        if not session:
+            logger.error(f"Session {session_id} not found for LLM continuation")
+            return
+
+        # Build system prompt
+        system_prompt, _ = await _build_system_prompt(user_id)
+
+        # Build messages in OpenAI format
+        messages = session.to_openai_messages()
+        if system_prompt:
+            from openai.types.chat import ChatCompletionSystemMessageParam
+
+            system_message = ChatCompletionSystemMessageParam(
+                role="system",
+                content=system_prompt,
+            )
+            messages = [system_message] + messages
+
+        # Build extra_body for tracing
+        extra_body: dict[str, Any] = {
+            "posthogProperties": {
+                "environment": settings.config.app_env.value,
+            },
+        }
+        if user_id:
+            extra_body["user"] = user_id[:128]
+            extra_body["posthogDistinctId"] = user_id
+        if session_id:
+            extra_body["session_id"] = session_id[:128]
+
+        # Make non-streaming LLM call (no tools - just text response)
+        from typing import cast
+
+        from openai.types.chat import ChatCompletionMessageParam
+
+        # No tools parameter = text-only response (no tool calls)
+        response = await client.chat.completions.create(
+            model=config.model,
+            messages=cast(list[ChatCompletionMessageParam], messages),
+            extra_body=extra_body,
+        )
+
+        if response.choices and response.choices[0].message.content:
+            assistant_content = response.choices[0].message.content
+
+            # Reload session from DB to avoid race condition with user messages
+            # that may have been sent while we were generating the LLM response
+            fresh_session = await get_chat_session(session_id, user_id)
+            if not fresh_session:
+                logger.error(
+                    f"Session {session_id} disappeared during LLM continuation"
+                )
+                return
+
+            # Save assistant message to database
+            assistant_message = ChatMessage(
+                role="assistant",
+                content=assistant_content,
+            )
+            fresh_session.messages.append(assistant_message)
+
+            # Save to database (not cache) to persist the response
+            await upsert_chat_session(fresh_session)
+
+            # Invalidate cache so next poll/refresh gets fresh data
+            await invalidate_session_cache(session_id)
+
+            logger.info(
+                f"Generated LLM continuation for session {session_id}, "
+                f"response length: {len(assistant_content)}"
+            )
+        else:
+            logger.warning(f"LLM continuation returned empty response for {session_id}")
+
+    except Exception as e:
+        logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
@@ -0,0 +1,79 @@
# CoPilot Tools - Future Ideas

## Multimodal Image Support for CoPilot

**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).

**Backend Solution:**
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:

```python
# Before sending to LLM, scan for workspace image references
# and inject them as image content parts

# Example message transformation:
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
# TO:   {"role": "assistant", "content": [
#           {"type": "text", "text": "Generated image: workspace://abc123"},
#           {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
#       ]}
```

**Where to implement:**
- In the chat stream handler before calling the LLM
- Or in a message preprocessing step
- Need to fetch the image from the workspace, convert it to base64, and add it as image content (see the sketch after this section)

**Considerations:**
- Only do this for image MIME types (image/png, image/jpeg, etc.)
- May want a size limit (don't pass 10MB images)
- Track which images were "shown" to the AI for a frontend indicator
- Cost implications - vision API calls are more expensive

**Frontend Solution:**
Show a visual indicator on workspace files in chat:
- If the AI saw the image: normal display
- If the AI didn't see it: overlay icon saying "AI can't see this image"

Requires response metadata indicating which `workspace://` refs were passed to the model.
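A hypothetical sketch of that backend transformation; `workspace.get_mime_type` and `workspace.read_file` stand in for whatever the real workspace API exposes:

```python
# Hypothetical sketch of the transformation described above - the workspace
# helper methods are assumptions, not the platform's actual interface.
import base64
import re

WORKSPACE_REF = re.compile(r"workspace://[a-zA-Z0-9-]+")


async def inject_workspace_images(message: dict, workspace) -> dict:
    content = message.get("content")
    if not isinstance(content, str):
        return message
    parts = [{"type": "text", "text": content}]
    for ref in WORKSPACE_REF.findall(content):
        mime = await workspace.get_mime_type(ref)  # assumed helper
        if not mime or not mime.startswith("image/"):
            continue  # only inline actual images
        data = await workspace.read_file(ref)  # assumed helper, returns bytes
        b64 = base64.b64encode(data).decode("ascii")
        parts.append(
            {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}
        )
    # Only switch to multi-part content if at least one image was resolved
    if len(parts) > 1:
        return {**message, "content": parts}
    return message
```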
---

## Output Post-Processing Layer for run_block

**Problem:** Many blocks produce large outputs that:
- Consume massive context (a 100KB image becomes ~133KB of base64 text)
- Can't fit in the conversation
- Break things and cause high LLM costs

**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.

**Benefits:**
1. **Centralized** - one place to handle all output processing
2. **Future-proof** - new blocks automatically get output processing
3. **Keeps blocks pure** - they don't need to know about context constraints
4. **Handles all large outputs** - not just images

**Processing Rules:**
- Detect base64 data URIs → save to workspace, return a `workspace://` reference
- Truncate very long strings (>N chars) with a truncation note
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
- Handle nested large outputs in dicts recursively
- Cap total output size

**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`

**Example:** (`_process_value`, which this relies on, is sketched after this code block)

```python
def _process_outputs_for_context(
    outputs: dict[str, list[Any]],
    workspace_manager: WorkspaceManager,
    max_string_length: int = 10000,
    max_array_preview: int = 5,
) -> dict[str, list[Any]]:
    """Process block outputs to prevent context bloat."""
    processed = {}
    for name, values in outputs.items():
        processed[name] = [_process_value(v, workspace_manager) for v in values]
    return processed
```
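`_process_value` is left undefined in the example above. One possible sketch, with all thresholds and the `save_data_uri` workspace call as illustrative assumptions:

```python
# Hypothetical helper for the example above - names and thresholds are illustrative.
import re
from typing import Any

DATA_URI = re.compile(r"^data:([\w/+.-]+);base64,")


def _process_value(value: Any, workspace_manager, max_string_length: int = 10000) -> Any:
    if isinstance(value, str):
        if DATA_URI.match(value):
            # Assumed workspace API: persist the blob, hand back a reference
            ref = workspace_manager.save_data_uri(value)
            return f"workspace://{ref}"
        if len(value) > max_string_length:
            extra = len(value) - max_string_length
            return value[:max_string_length] + f"... [truncated {extra} chars]"
        return value
    if isinstance(value, list) and len(value) > 5:
        # Summarize large arrays instead of inlining them
        return {"summary": f"Array with {len(value)} items, first 5", "preview": value[:5]}
    if isinstance(value, dict):
        # Recurse into nested structures
        return {k: _process_value(v, workspace_manager, max_string_length) for k, v in value.items()}
    return value
```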
@@ -18,6 +18,12 @@ from .get_doc_page import GetDocPageTool
 from .run_agent import RunAgentTool
 from .run_block import RunBlockTool
 from .search_docs import SearchDocsTool
+from .workspace_files import (
+    DeleteWorkspaceFileTool,
+    ListWorkspaceFilesTool,
+    ReadWorkspaceFileTool,
+    WriteWorkspaceFileTool,
+)
 
 if TYPE_CHECKING:
     from backend.api.features.chat.response_model import StreamToolOutputAvailable

@@ -37,6 +43,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
     "view_agent_output": AgentOutputTool(),
     "search_docs": SearchDocsTool(),
     "get_doc_page": GetDocPageTool(),
+    # Workspace tools for CoPilot file operations
+    "list_workspace_files": ListWorkspaceFilesTool(),
+    "read_workspace_file": ReadWorkspaceFileTool(),
+    "write_workspace_file": WriteWorkspaceFileTool(),
+    "delete_workspace_file": DeleteWorkspaceFileTool(),
 }
 
 # Export individual tool instances for backwards compatibility

@@ -49,6 +60,11 @@ tools: list[ChatCompletionToolParam] = [
 ]
 
 
+def get_tool(tool_name: str) -> BaseTool | None:
+    """Get a tool instance by name."""
+    return TOOL_REGISTRY.get(tool_name)
+
+
 async def execute_tool(
     tool_name: str,
     parameters: dict[str, Any],

@@ -57,7 +73,7 @@ async def execute_tool(
     tool_call_id: str,
 ) -> "StreamToolOutputAvailable":
     """Execute a tool by name."""
-    tool = TOOL_REGISTRY.get(tool_name)
+    tool = get_tool(tool_name)
     if not tool:
         raise ValueError(f"Tool {tool_name} not found")
 
@@ -2,27 +2,52 @@
 
 from .core import (
     AgentGeneratorNotConfiguredError,
+    AgentSummary,
+    DecompositionResult,
+    DecompositionStep,
+    LibraryAgentSummary,
+    MarketplaceAgentSummary,
     decompose_goal,
+    enrich_library_agents_from_steps,
+    extract_search_terms_from_steps,
+    extract_uuids_from_text,
     generate_agent,
     generate_agent_patch,
     get_agent_as_json,
+    get_all_relevant_agents_for_generation,
+    get_library_agent_by_graph_id,
+    get_library_agent_by_id,
+    get_library_agents_for_generation,
     json_to_graph,
     save_agent_to_library,
+    search_marketplace_agents_for_generation,
 )
+from .errors import get_user_message_for_error
 from .service import health_check as check_external_service_health
 from .service import is_external_service_configured
 
 __all__ = [
-    # Core functions
+    "AgentGeneratorNotConfiguredError",
+    "AgentSummary",
+    "DecompositionResult",
+    "DecompositionStep",
+    "LibraryAgentSummary",
+    "MarketplaceAgentSummary",
+    "check_external_service_health",
     "decompose_goal",
+    "enrich_library_agents_from_steps",
+    "extract_search_terms_from_steps",
+    "extract_uuids_from_text",
     "generate_agent",
     "generate_agent_patch",
-    "save_agent_to_library",
     "get_agent_as_json",
-    "json_to_graph",
-    # Exceptions
-    "AgentGeneratorNotConfiguredError",
-    # Service
+    "get_all_relevant_agents_for_generation",
+    "get_library_agent_by_graph_id",
+    "get_library_agent_by_id",
+    "get_library_agents_for_generation",
+    "get_user_message_for_error",
     "is_external_service_configured",
-    "check_external_service_health",
+    "json_to_graph",
+    "save_agent_to_library",
+    "search_marketplace_agents_for_generation",
 ]
@@ -1,11 +1,21 @@
 """Core agent generation functions."""
 
 import logging
+import re
 import uuid
-from typing import Any
+from typing import Any, TypedDict
 
 from backend.api.features.library import db as library_db
-from backend.data.graph import Graph, Link, Node, create_graph
+from backend.api.features.store import db as store_db
+from backend.data.graph import (
+    Graph,
+    Link,
+    Node,
+    create_graph,
+    get_graph,
+    get_graph_all_versions,
+)
+from backend.util.exceptions import DatabaseError, NotFoundError
 
 from .service import (
     decompose_goal_external,
@@ -17,6 +27,60 @@ from .service import (
 logger = logging.getLogger(__name__)
 
 
+class LibraryAgentSummary(TypedDict):
+    """Summary of a library agent for sub-agent composition."""
+
+    graph_id: str
+    graph_version: int
+    name: str
+    description: str
+    input_schema: dict[str, Any]
+    output_schema: dict[str, Any]
+
+
+class MarketplaceAgentSummary(TypedDict):
+    """Summary of a marketplace agent for sub-agent composition."""
+
+    name: str
+    description: str
+    sub_heading: str
+    creator: str
+    is_marketplace_agent: bool
+
+
+class DecompositionStep(TypedDict, total=False):
+    """A single step in decomposed instructions."""
+
+    description: str
+    action: str
+    block_name: str
+    tool: str
+    name: str
+
+
+class DecompositionResult(TypedDict, total=False):
+    """Result from decompose_goal - can be instructions, questions, or error."""
+
+    type: str  # "instructions", "clarifying_questions", "error", etc.
+    steps: list[DecompositionStep]
+    questions: list[dict[str, Any]]
+    error: str
+    error_type: str
+
+
+# Type alias for agent summaries (can be either library or marketplace)
+AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
+
+
+def _to_dict_list(
+    agents: list[AgentSummary] | list[dict[str, Any]] | None,
+) -> list[dict[str, Any]] | None:
+    """Convert typed agent summaries to plain dicts for external service calls."""
+    if agents is None:
+        return None
+    return [dict(a) for a in agents]
+
+
 class AgentGeneratorNotConfiguredError(Exception):
     """Raised when the external Agent Generator service is not configured."""
 
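Illustrative only: how one of these summaries might be built and flattened for the external service (all values are placeholders):

```python
# Placeholder values - shows the TypedDict shape and the _to_dict_list conversion.
summary = LibraryAgentSummary(
    graph_id="00000000-0000-4000-8000-000000000000",
    graph_version=1,
    name="Daily Digest",
    description="Summarizes yesterday's news",
    input_schema={"type": "object", "properties": {}},
    output_schema={"type": "object", "properties": {}},
)
payload = _to_dict_list([summary])  # -> [{"graph_id": "...", "graph_version": 1, ...}]
```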
@@ -36,15 +100,394 @@ def _check_service_configured() -> None:
     )
 
 
-async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
+_UUID_PATTERN = re.compile(
+    r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
+    re.IGNORECASE,
+)
+
+
+def extract_uuids_from_text(text: str) -> list[str]:
+    """Extract all UUID v4 strings from text.
+
+    Args:
+        text: Text that may contain UUIDs (e.g., user's goal description)
+
+    Returns:
+        List of unique UUIDs found in the text (lowercase)
+    """
+    matches = _UUID_PATTERN.findall(text)
+    return list({m.lower() for m in matches})
+
+
+async def get_library_agent_by_id(
+    user_id: str, agent_id: str
+) -> LibraryAgentSummary | None:
+    """Fetch a specific library agent by its ID (library agent ID or graph_id).
+
+    This function tries multiple lookup strategies:
+    1. First tries to find by graph_id (AgentGraph primary key)
+    2. If not found, tries to find by library agent ID (LibraryAgent primary key)
+
+    This handles both cases:
+    - User provides graph_id (e.g., from AgentExecutorBlock)
+    - User provides library agent ID (e.g., from library URL)
+
+    Args:
+        user_id: The user ID
+        agent_id: The ID to look up (can be graph_id or library agent ID)
+
+    Returns:
+        LibraryAgentSummary if found, None otherwise
+    """
+    try:
+        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
+        if agent:
+            logger.debug(f"Found library agent by graph_id: {agent.name}")
+            return LibraryAgentSummary(
+                graph_id=agent.graph_id,
+                graph_version=agent.graph_version,
+                name=agent.name,
+                description=agent.description,
+                input_schema=agent.input_schema,
+                output_schema=agent.output_schema,
+            )
+    except DatabaseError:
+        raise
+    except Exception as e:
+        logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
+
+    try:
+        agent = await library_db.get_library_agent(agent_id, user_id)
+        if agent:
+            logger.debug(f"Found library agent by library_id: {agent.name}")
+            return LibraryAgentSummary(
+                graph_id=agent.graph_id,
+                graph_version=agent.graph_version,
+                name=agent.name,
+                description=agent.description,
+                input_schema=agent.input_schema,
+                output_schema=agent.output_schema,
+            )
+    except NotFoundError:
+        logger.debug(f"Library agent not found by library_id: {agent_id}")
+    except DatabaseError:
+        raise
+    except Exception as e:
+        logger.warning(
+            f"Could not fetch library agent by library_id {agent_id}: {e}",
+            exc_info=True,
+        )
+
+    return None
+
+
+# Alias for backward compatibility
+get_library_agent_by_graph_id = get_library_agent_by_id
+
+
+async def get_library_agents_for_generation(
+    user_id: str,
+    search_query: str | None = None,
+    exclude_graph_id: str | None = None,
+    max_results: int = 15,
+) -> list[LibraryAgentSummary]:
+    """Fetch user's library agents formatted for Agent Generator.
+
+    Uses search-based fetching to return relevant agents instead of all agents.
+    This is more scalable for users with large libraries.
+
+    Args:
+        user_id: The user ID
+        search_query: Optional search term to find relevant agents (user's goal/description)
+        exclude_graph_id: Optional graph ID to exclude (prevents circular references)
+        max_results: Maximum number of agents to return (default 15)
+
+    Returns:
+        List of LibraryAgentSummary with schemas for sub-agent composition
+
+    Note:
+        Future enhancement: Add quality filtering based on execution success rate
+        or correctness_score from AgentGraphExecution stats. The current
+        LibraryAgentStatus.ERROR is too aggressive (1 failed run = ERROR).
+        Better approach: filter by success rate (e.g., >50% successful runs)
+        or require at least 1 successful execution.
+    """
+    try:
+        response = await library_db.list_library_agents(
+            user_id=user_id,
+            search_term=search_query,
+            page=1,
+            page_size=max_results,
+        )
+
+        results: list[LibraryAgentSummary] = []
+        for agent in response.agents:
+            if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
+                continue
+
+            results.append(
+                LibraryAgentSummary(
+                    graph_id=agent.graph_id,
+                    graph_version=agent.graph_version,
+                    name=agent.name,
+                    description=agent.description,
+                    input_schema=agent.input_schema,
+                    output_schema=agent.output_schema,
+                )
+            )
+        return results
+    except Exception as e:
+        logger.warning(f"Failed to fetch library agents: {e}")
+        return []
+
+
+async def search_marketplace_agents_for_generation(
+    search_query: str,
+    max_results: int = 10,
+) -> list[MarketplaceAgentSummary]:
+    """Search marketplace agents formatted for Agent Generator.
+
+    Note: This returns basic agent info. Full input/output schemas would require
+    additional graph fetches and is a potential future enhancement.
+
+    Args:
+        search_query: Search term to find relevant public agents
+        max_results: Maximum number of agents to return (default 10)
+
+    Returns:
+        List of MarketplaceAgentSummary (without detailed schemas for now)
+    """
+    try:
+        response = await store_db.get_store_agents(
+            search_query=search_query,
+            page=1,
+            page_size=max_results,
+        )
+
+        results: list[MarketplaceAgentSummary] = []
+        for agent in response.agents:
+            results.append(
+                MarketplaceAgentSummary(
+                    name=agent.agent_name,
+                    description=agent.description,
+                    sub_heading=agent.sub_heading,
+                    creator=agent.creator,
+                    is_marketplace_agent=True,
+                )
+            )
+        return results
+    except Exception as e:
+        logger.warning(f"Failed to search marketplace agents: {e}")
+        return []
+
+
+async def get_all_relevant_agents_for_generation(
+    user_id: str,
+    search_query: str | None = None,
+    exclude_graph_id: str | None = None,
+    include_library: bool = True,
+    include_marketplace: bool = True,
+    max_library_results: int = 15,
+    max_marketplace_results: int = 10,
+) -> list[AgentSummary]:
+    """Fetch relevant agents from library and/or marketplace.
+
+    Searches both user's library and marketplace by default.
+    Explicitly mentioned UUIDs in the search query are always looked up.
+
+    Args:
+        user_id: The user ID
+        search_query: Search term to find relevant agents (user's goal/description)
+        exclude_graph_id: Optional graph ID to exclude (prevents circular references)
+        include_library: Whether to search user's library (default True)
+        include_marketplace: Whether to also search marketplace (default True)
+        max_library_results: Max library agents to return (default 15)
+        max_marketplace_results: Max marketplace agents to return (default 10)
+
+    Returns:
+        List of AgentSummary, library agents first (with full schemas),
+        then marketplace agents (basic info only)
+    """
+    agents: list[AgentSummary] = []
+    seen_graph_ids: set[str] = set()
+
+    if search_query:
+        mentioned_uuids = extract_uuids_from_text(search_query)
+        for graph_id in mentioned_uuids:
+            if graph_id == exclude_graph_id:
+                continue
+            agent = await get_library_agent_by_graph_id(user_id, graph_id)
+            if agent and agent["graph_id"] not in seen_graph_ids:
+                agents.append(agent)
+                seen_graph_ids.add(agent["graph_id"])
+                logger.debug(f"Found explicitly mentioned agent: {agent['name']}")
+
+    if include_library:
+        library_agents = await get_library_agents_for_generation(
+            user_id=user_id,
+            search_query=search_query,
+            exclude_graph_id=exclude_graph_id,
+            max_results=max_library_results,
+        )
+        for agent in library_agents:
+            if agent["graph_id"] not in seen_graph_ids:
+                agents.append(agent)
+                seen_graph_ids.add(agent["graph_id"])
+
+    if include_marketplace and search_query:
+        marketplace_agents = await search_marketplace_agents_for_generation(
+            search_query=search_query,
+            max_results=max_marketplace_results,
+        )
+        library_names = {a["name"].lower() for a in agents if a.get("name")}
+        for agent in marketplace_agents:
+            agent_name = agent.get("name")
+            if agent_name and agent_name.lower() not in library_names:
+                agents.append(agent)
+
+    return agents
+
+
+def extract_search_terms_from_steps(
+    decomposition_result: DecompositionResult | dict[str, Any],
+) -> list[str]:
+    """Extract search terms from decomposed instruction steps.
+
+    Analyzes the decomposition result to extract relevant keywords
+    for additional library agent searches.
+
+    Args:
+        decomposition_result: Result from decompose_goal containing steps
+
+    Returns:
+        List of unique search terms extracted from steps
+    """
+    search_terms: list[str] = []
+
+    if decomposition_result.get("type") != "instructions":
+        return search_terms
+
+    steps = decomposition_result.get("steps", [])
+    if not steps:
+        return search_terms
+
+    step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
+
+    for step in steps:
+        for key in step_keys:
+            value = step.get(key)  # type: ignore[union-attr]
+            if isinstance(value, str) and len(value) > 3:
+                search_terms.append(value)
+
+    seen: set[str] = set()
+    unique_terms: list[str] = []
+    for term in search_terms:
+        term_lower = term.lower()
+        if term_lower not in seen:
+            seen.add(term_lower)
+            unique_terms.append(term)
+
+    return unique_terms
+
+
+async def enrich_library_agents_from_steps(
+    user_id: str,
+    decomposition_result: DecompositionResult | dict[str, Any],
+    existing_agents: list[AgentSummary] | list[dict[str, Any]],
+    exclude_graph_id: str | None = None,
+    include_marketplace: bool = True,
+    max_additional_results: int = 10,
+) -> list[AgentSummary] | list[dict[str, Any]]:
+    """Enrich library agents list with additional searches based on decomposed steps.
+
+    This implements two-phase search: after decomposition, we search for additional
+    relevant agents based on the specific steps identified.
+
+    Args:
+        user_id: The user ID
+        decomposition_result: Result from decompose_goal containing steps
+        existing_agents: Already fetched library agents from initial search
+        exclude_graph_id: Optional graph ID to exclude
+        include_marketplace: Whether to also search marketplace
+        max_additional_results: Max additional agents per search term (default 10)
+
+    Returns:
+        Combined list of library agents (existing + newly discovered)
+    """
+    search_terms = extract_search_terms_from_steps(decomposition_result)
+
+    if not search_terms:
+        return existing_agents
+
+    existing_ids: set[str] = set()
+    existing_names: set[str] = set()
+
+    for agent in existing_agents:
+        agent_name = agent.get("name", "")
+        if agent_name:
+            existing_names.add(agent_name.lower())
+        graph_id = agent.get("graph_id")  # type: ignore[call-overload]
+        if graph_id:
+            existing_ids.add(graph_id)
+
+    all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
+
+    for term in search_terms[:3]:
+        try:
+            additional_agents = await get_all_relevant_agents_for_generation(
+                user_id=user_id,
+                search_query=term,
+                exclude_graph_id=exclude_graph_id,
+                include_marketplace=include_marketplace,
+                max_library_results=max_additional_results,
+                max_marketplace_results=5,
+            )
+
+            for agent in additional_agents:
+                agent_name = agent.get("name", "")
+                if not agent_name:
+                    continue
+                agent_name_lower = agent_name.lower()
+
+                if agent_name_lower in existing_names:
+                    continue
+
+                graph_id = agent.get("graph_id")  # type: ignore[call-overload]
+                if graph_id and graph_id in existing_ids:
+                    continue
+
+                all_agents.append(agent)
+                existing_names.add(agent_name_lower)
+                if graph_id:
+                    existing_ids.add(graph_id)
+
+        except Exception as e:
+            logger.warning(
+                f"Failed to search for additional agents with term '{term}': {e}"
+            )
+
+    logger.debug(
+        f"Enriched library agents: {len(existing_agents)} initial + "
+        f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
+    )
+
+    return all_agents
+
+
+async def decompose_goal(
+    description: str,
+    context: str = "",
+    library_agents: list[AgentSummary] | None = None,
+) -> DecompositionResult | None:
     """Break down a goal into steps or return clarifying questions.
 
     Args:
         description: Natural language goal description
         context: Additional context (e.g., answers to previous questions)
+        library_agents: User's library agents available for sub-agent composition
 
     Returns:
-        Dict with either:
+        DecompositionResult with either:
         - {"type": "clarifying_questions", "questions": [...]}
         - {"type": "instructions", "steps": [...]}
         Or None on error
@@ -54,26 +497,41 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
     """
     _check_service_configured()
     logger.info("Calling external Agent Generator service for decompose_goal")
-    return await decompose_goal_external(description, context)
+    # Convert typed dicts to plain dicts for external service
+    result = await decompose_goal_external(
+        description, context, _to_dict_list(library_agents)
+    )
+    # Cast the result to DecompositionResult (external service returns dict)
+    return result  # type: ignore[return-value]
 
 
-async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
+async def generate_agent(
+    instructions: DecompositionResult | dict[str, Any],
+    library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
+) -> dict[str, Any] | None:
     """Generate agent JSON from instructions.
 
     Args:
         instructions: Structured instructions from decompose_goal
+        library_agents: User's library agents available for sub-agent composition
 
     Returns:
-        Agent JSON dict or None on error
+        Agent JSON dict, error dict {"type": "error", ...}, or None on error
 
     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
     """
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent")
-    result = await generate_agent_external(instructions)
+    # Convert typed dicts to plain dicts for external service
+    result = await generate_agent_external(
+        dict(instructions), _to_dict_list(library_agents)
+    )
     if result:
-        # Ensure required fields
+        # Check if it's an error response - pass through as-is
+        if isinstance(result, dict) and result.get("type") == "error":
+            return result
+        # Ensure required fields for successful agent generation
         if "id" not in result:
             result["id"] = str(uuid.uuid4())
         if "version" not in result:
@@ -159,8 +617,6 @@ async def save_agent_to_library(
     Returns:
         Tuple of (created Graph, LibraryAgent)
     """
-    from backend.data.graph import get_graph_all_versions
-
     graph = json_to_graph(agent_json)
 
     if is_update:
@@ -197,25 +653,31 @@ async def save_agent_to_library(
 
 
 async def get_agent_as_json(
-    graph_id: str, user_id: str | None
+    agent_id: str, user_id: str | None
 ) -> dict[str, Any] | None:
     """Fetch an agent and convert to JSON format for editing.
 
     Args:
-        graph_id: Graph ID or library agent ID
+        agent_id: Graph ID or library agent ID
         user_id: User ID
 
     Returns:
         Agent as JSON dict or None if not found
     """
-    from backend.data.graph import get_graph
-
-    # Try to get the graph (version=None gets the active version)
-    graph = await get_graph(graph_id, version=None, user_id=user_id)
+    graph = await get_graph(agent_id, version=None, user_id=user_id)
+
+    if not graph and user_id:
+        try:
+            library_agent = await library_db.get_library_agent(agent_id, user_id)
+            graph = await get_graph(
+                library_agent.graph_id, version=None, user_id=user_id
+            )
+        except NotFoundError:
+            pass
+
     if not graph:
         return None
 
-    # Convert to JSON format
     nodes = []
     for node in graph.nodes:
         nodes.append(
@@ -253,7 +715,9 @@ async def get_agent_as_json(
 
 
 async def generate_agent_patch(
-    update_request: str, current_agent: dict[str, Any]
+    update_request: str,
+    current_agent: dict[str, Any],
+    library_agents: list[AgentSummary] | None = None,
 ) -> dict[str, Any] | None:
     """Update an existing agent using natural language.
 

@@ -265,13 +729,18 @@ async def generate_agent_patch(
     Args:
         update_request: Natural language description of changes
         current_agent: Current agent JSON
+        library_agents: User's library agents available for sub-agent composition
 
     Returns:
-        Updated agent JSON, clarifying questions dict, or None on error
+        Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
+        error dict {"type": "error", ...}, or None on unexpected error
 
     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
     """
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent_patch")
-    return await generate_agent_patch_external(update_request, current_agent)
+    # Convert typed dicts to plain dicts for external service
+    return await generate_agent_patch_external(
+        update_request, current_agent, _to_dict_list(library_agents)
+    )
@@ -0,0 +1,104 @@
+"""Error handling utilities for agent generator."""
+
+import re
+
+
+def _sanitize_error_details(details: str) -> str:
+    """Sanitize error details to remove sensitive information.
+
+    Strips common patterns that could expose internal system info:
+    - File paths (Unix and Windows)
+    - Database connection strings
+    - URLs with credentials
+    - Stack trace internals
+
+    Args:
+        details: Raw error details string
+
+    Returns:
+        Sanitized error details safe for user display
+    """
+    # Remove file paths (Unix-style)
+    sanitized = re.sub(
+        r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
+    )
+    # Remove file paths (Windows-style)
+    sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
+    # Remove database URLs
+    sanitized = re.sub(
+        r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
+    )
+    # Remove URLs with credentials
+    sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
+    # Remove line numbers from stack traces
+    sanitized = re.sub(r", line \d+", "", sanitized)
+    # Remove "File" references from stack traces
+    sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
+
+    return sanitized.strip()
+
+
+def get_user_message_for_error(
+    error_type: str,
+    operation: str = "process the request",
+    llm_parse_message: str | None = None,
+    validation_message: str | None = None,
+    error_details: str | None = None,
+) -> str:
+    """Get a user-friendly error message based on error type.
+
+    This function maps internal error types to user-friendly messages,
+    providing a consistent experience across different agent operations.
+
+    Args:
+        error_type: The error type from the external service
+            (e.g., "llm_parse_error", "timeout", "rate_limit")
+        operation: Description of what operation failed, used in the default
+            message (e.g., "analyze the goal", "generate the agent")
+        llm_parse_message: Custom message for llm_parse_error type
+        validation_message: Custom message for validation_error type
+        error_details: Optional additional details about the error
+
+    Returns:
+        User-friendly error message suitable for display to the user
+    """
+    base_message = ""
+
+    if error_type == "llm_parse_error":
+        base_message = (
+            llm_parse_message
+            or "The AI had trouble processing this request. Please try again."
+        )
+    elif error_type == "validation_error":
+        base_message = (
+            validation_message
+            or "The generated agent failed validation. "
+            "This usually happens when the agent structure doesn't match "
+            "what the platform expects. Please try simplifying your goal "
+            "or breaking it into smaller parts."
+        )
+    elif error_type == "patch_error":
+        base_message = (
+            "Failed to apply the changes. The modification couldn't be "
+            "validated. Please try a different approach or simplify the change."
+        )
+    elif error_type in ("timeout", "llm_timeout"):
+        base_message = (
+            "The request took too long to process. This can happen with "
+            "complex agents. Please try again or simplify your goal."
+        )
+    elif error_type in ("rate_limit", "llm_rate_limit"):
+        base_message = "The service is currently busy. Please try again in a moment."
+    else:
+        base_message = f"Failed to {operation}. Please try again."
+
+    # Add error details if provided (sanitized and truncated)
+    if error_details:
+        # Sanitize to remove sensitive information
+        details = _sanitize_error_details(error_details)
+        # Truncate long error details
+        if len(details) > 200:
+            details = details[:200] + "..."
+        base_message += f"\n\nTechnical details: {details}"
+
+    return base_message
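
Taken together, these helpers turn a raw service failure into something safe to show. A minimal sketch of the intended call pattern, assuming a hypothetical validation failure (the error string below is made up):

```python
# Hypothetical raw error detail from the external service.
raw = 'File "/app/backend/generator.py", line 42: node schema mismatch'

message = get_user_message_for_error(
    error_type="validation_error",
    operation="generate the agent",
    error_details=raw,
)
# The sanitizer strips the file path and line number, so the appended
# "Technical details:" section reads roughly ": node schema mismatch".
print(message)
```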
@@ -14,6 +14,70 @@ from backend.util.settings import Settings

 logger = logging.getLogger(__name__)


+def _create_error_response(
+    error_message: str,
+    error_type: str = "unknown",
+    details: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Create a standardized error response dict.
+
+    Args:
+        error_message: Human-readable error message
+        error_type: Machine-readable error type
+        details: Optional additional error details
+
+    Returns:
+        Error dict with type="error" and error details
+    """
+    response: dict[str, Any] = {
+        "type": "error",
+        "error": error_message,
+        "error_type": error_type,
+    }
+    if details:
+        response["details"] = details
+    return response
+
+
+def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
+    """Classify an HTTP error into error_type and message.
+
+    Args:
+        e: The HTTP status error
+
+    Returns:
+        Tuple of (error_type, error_message)
+    """
+    status = e.response.status_code
+    if status == 429:
+        return "rate_limit", f"Agent Generator rate limited: {e}"
+    elif status == 503:
+        return "service_unavailable", f"Agent Generator unavailable: {e}"
+    elif status == 504 or status == 408:
+        return "timeout", f"Agent Generator timed out: {e}"
+    else:
+        return "http_error", f"HTTP error calling Agent Generator: {e}"
+
+
+def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
+    """Classify a request error into error_type and message.
+
+    Args:
+        e: The request error
+
+    Returns:
+        Tuple of (error_type, error_message)
+    """
+    error_str = str(e).lower()
+    if "timeout" in error_str or "timed out" in error_str:
+        return "timeout", f"Agent Generator request timed out: {e}"
+    elif "connect" in error_str:
+        return "connection_error", f"Could not connect to Agent Generator: {e}"
+    else:
+        return "request_error", f"Request error calling Agent Generator: {e}"
+
+
 _client: httpx.AsyncClient | None = None
 _settings: Settings | None = None
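
For reference, the error dict that every failure path below funnels into has this exact shape (the message text here is illustrative):

```python
_create_error_response("Agent Generator unavailable: 503", "service_unavailable")
# -> {
#     "type": "error",
#     "error": "Agent Generator unavailable: 503",
#     "error_type": "service_unavailable",
# }
```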
@@ -53,13 +117,16 @@ def _get_client() -> httpx.AsyncClient:


 async def decompose_goal_external(
-    description: str, context: str = ""
+    description: str,
+    context: str = "",
+    library_agents: list[dict[str, Any]] | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to decompose a goal.

     Args:
         description: Natural language goal description
         context: Additional context (e.g., answers to previous questions)
+        library_agents: User's library agents available for sub-agent composition

     Returns:
         Dict with either:
@@ -67,7 +134,8 @@ async def decompose_goal_external(
         - {"type": "instructions", "steps": [...]}
         - {"type": "unachievable_goal", ...}
         - {"type": "vague_goal", ...}
-        Or None on error
+        - {"type": "error", "error": "...", "error_type": "..."} on error
+        Or None on unexpected error
     """
     client = _get_client()

@@ -76,6 +144,8 @@ async def decompose_goal_external(
     if context:
         # The external service uses user_instruction for additional context
         payload["user_instruction"] = context
+    if library_agents:
+        payload["library_agents"] = library_agents

     try:
         response = await client.post("/api/decompose-description", json=payload)
@@ -83,8 +153,13 @@ async def decompose_goal_external(
         data = response.json()

         if not data.get("success"):
-            logger.error(f"External service returned error: {data.get('error')}")
-            return None
+            error_msg = data.get("error", "Unknown error from Agent Generator")
+            error_type = data.get("error_type", "unknown")
+            logger.error(
+                f"Agent Generator decomposition failed: {error_msg} "
+                f"(type: {error_type})"
+            )
+            return _create_error_response(error_msg, error_type)

         # Map the response to the expected format
         response_type = data.get("type")
@@ -106,88 +181,120 @@ async def decompose_goal_external(
                 "type": "vague_goal",
                 "suggested_goal": data.get("suggested_goal"),
             }
+        elif response_type == "error":
+            # Pass through error from the service
+            return _create_error_response(
+                data.get("error", "Unknown error"),
+                data.get("error_type", "unknown"),
+            )
         else:
             logger.error(
                 f"Unknown response type from external service: {response_type}"
             )
-            return None
+            return _create_error_response(
+                f"Unknown response type from Agent Generator: {response_type}",
+                "invalid_response",
+            )

     except httpx.HTTPStatusError as e:
-        logger.error(f"HTTP error calling external agent generator: {e}")
-        return None
+        error_type, error_msg = _classify_http_error(e)
+        logger.error(error_msg)
+        return _create_error_response(error_msg, error_type)
     except httpx.RequestError as e:
-        logger.error(f"Request error calling external agent generator: {e}")
-        return None
+        error_type, error_msg = _classify_request_error(e)
+        logger.error(error_msg)
+        return _create_error_response(error_msg, error_type)
     except Exception as e:
-        logger.error(f"Unexpected error calling external agent generator: {e}")
-        return None
+        error_msg = f"Unexpected error calling Agent Generator: {e}"
+        logger.error(error_msg)
+        return _create_error_response(error_msg, "unexpected_error")


 async def generate_agent_external(
-    instructions: dict[str, Any]
+    instructions: dict[str, Any],
+    library_agents: list[dict[str, Any]] | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate an agent from instructions.

     Args:
         instructions: Structured instructions from decompose_goal
+        library_agents: User's library agents available for sub-agent composition

     Returns:
-        Agent JSON dict or None on error
+        Agent JSON dict on success, or error dict {"type": "error", ...} on error
     """
     client = _get_client()

+    payload: dict[str, Any] = {"instructions": instructions}
+    if library_agents:
+        payload["library_agents"] = library_agents
+
     try:
-        response = await client.post(
-            "/api/generate-agent", json={"instructions": instructions}
-        )
+        response = await client.post("/api/generate-agent", json=payload)
         response.raise_for_status()
         data = response.json()

         if not data.get("success"):
-            logger.error(f"External service returned error: {data.get('error')}")
-            return None
+            error_msg = data.get("error", "Unknown error from Agent Generator")
+            error_type = data.get("error_type", "unknown")
+            logger.error(
+                f"Agent Generator generation failed: {error_msg} (type: {error_type})"
+            )
+            return _create_error_response(error_msg, error_type)

         return data.get("agent_json")

     except httpx.HTTPStatusError as e:
-        logger.error(f"HTTP error calling external agent generator: {e}")
-        return None
+        error_type, error_msg = _classify_http_error(e)
+        logger.error(error_msg)
+        return _create_error_response(error_msg, error_type)
     except httpx.RequestError as e:
-        logger.error(f"Request error calling external agent generator: {e}")
-        return None
+        error_type, error_msg = _classify_request_error(e)
+        logger.error(error_msg)
+        return _create_error_response(error_msg, error_type)
     except Exception as e:
-        logger.error(f"Unexpected error calling external agent generator: {e}")
-        return None
+        error_msg = f"Unexpected error calling Agent Generator: {e}"
+        logger.error(error_msg)
+        return _create_error_response(error_msg, "unexpected_error")


 async def generate_agent_patch_external(
-    update_request: str, current_agent: dict[str, Any]
+    update_request: str,
+    current_agent: dict[str, Any],
+    library_agents: list[dict[str, Any]] | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate a patch for an existing agent.

     Args:
         update_request: Natural language description of changes
         current_agent: Current agent JSON
+        library_agents: User's library agents available for sub-agent composition

     Returns:
-        Updated agent JSON, clarifying questions dict, or None on error
+        Updated agent JSON, clarifying questions dict, or error dict on error
     """
     client = _get_client()

+    payload: dict[str, Any] = {
+        "update_request": update_request,
+        "current_agent_json": current_agent,
+    }
+    if library_agents:
+        payload["library_agents"] = library_agents
+
     try:
-        response = await client.post(
-            "/api/update-agent",
-            json={
-                "update_request": update_request,
-                "current_agent_json": current_agent,
-            },
-        )
+        response = await client.post("/api/update-agent", json=payload)
         response.raise_for_status()
         data = response.json()

         if not data.get("success"):
-            logger.error(f"External service returned error: {data.get('error')}")
-            return None
+            error_msg = data.get("error", "Unknown error from Agent Generator")
+            error_type = data.get("error_type", "unknown")
+            logger.error(
+                f"Agent Generator patch generation failed: {error_msg} "
+                f"(type: {error_type})"
+            )
+            return _create_error_response(error_msg, error_type)

         # Check if it's clarifying questions
         if data.get("type") == "clarifying_questions":
@@ -196,18 +303,28 @@ async def generate_agent_patch_external(
                 "questions": data.get("questions", []),
             }

+        # Check if it's an error passed through
+        if data.get("type") == "error":
+            return _create_error_response(
+                data.get("error", "Unknown error"),
+                data.get("error_type", "unknown"),
+            )
+
         # Otherwise return the updated agent JSON
         return data.get("agent_json")

     except httpx.HTTPStatusError as e:
-        logger.error(f"HTTP error calling external agent generator: {e}")
-        return None
+        error_type, error_msg = _classify_http_error(e)
+        logger.error(error_msg)
+        return _create_error_response(error_msg, error_type)
     except httpx.RequestError as e:
-        logger.error(f"Request error calling external agent generator: {e}")
-        return None
+        error_type, error_msg = _classify_request_error(e)
+        logger.error(error_msg)
+        return _create_error_response(error_msg, error_type)
     except Exception as e:
-        logger.error(f"Unexpected error calling external agent generator: {e}")
-        return None
+        error_msg = f"Unexpected error calling Agent Generator: {e}"
+        logger.error(error_msg)
+        return _create_error_response(error_msg, "unexpected_error")


 async def get_blocks_external() -> list[dict[str, Any]] | None:
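
With these changes a caller can branch on the returned shape instead of a bare `None`. A sketch of the dispatch, with an illustrative goal string:

```python
result = await decompose_goal_external("summarize my inbox every morning")
if result is None:
    ...  # unexpected failure with no structured information
elif result["type"] == "error":
    ...  # structured failure: inspect result["error"] and result["error_type"]
elif result["type"] in ("unachievable_goal", "vague_goal"):
    ...  # the goal needs revising before generation can proceed
else:
    ...  # success: {"type": "instructions", "steps": [...]}
```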
@@ -1,6 +1,7 @@
 """Shared agent search functionality for find_agent and find_library_agent tools."""

 import logging
+import re
 from typing import Literal

 from backend.api.features.library import db as library_db
@@ -19,6 +20,86 @@ logger = logging.getLogger(__name__)

 SearchSource = Literal["marketplace", "library"]

+# UUID v4 pattern for direct agent ID lookup
+_UUID_PATTERN = re.compile(
+    r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
+    re.IGNORECASE,
+)
+
+
+def _is_uuid(text: str) -> bool:
+    """Check if text is a valid UUID v4."""
+    return bool(_UUID_PATTERN.match(text.strip()))
+
+
+async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
+    """Fetch a library agent by ID (library agent ID or graph_id).
+
+    Tries multiple lookup strategies:
+    1. First by graph_id (AgentGraph primary key)
+    2. Then by library agent ID (LibraryAgent primary key)
+
+    Args:
+        user_id: The user ID
+        agent_id: The ID to look up (can be graph_id or library agent ID)
+
+    Returns:
+        AgentInfo if found, None otherwise
+    """
+    try:
+        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
+        if agent:
+            logger.debug(f"Found library agent by graph_id: {agent.name}")
+            return AgentInfo(
+                id=agent.id,
+                name=agent.name,
+                description=agent.description or "",
+                source="library",
+                in_library=True,
+                creator=agent.creator_name,
+                status=agent.status.value,
+                can_access_graph=agent.can_access_graph,
+                has_external_trigger=agent.has_external_trigger,
+                new_output=agent.new_output,
+                graph_id=agent.graph_id,
+            )
+    except DatabaseError:
+        raise
+    except Exception as e:
+        logger.warning(
+            f"Could not fetch library agent by graph_id {agent_id}: {e}",
+            exc_info=True,
+        )
+
+    try:
+        agent = await library_db.get_library_agent(agent_id, user_id)
+        if agent:
+            logger.debug(f"Found library agent by library_id: {agent.name}")
+            return AgentInfo(
+                id=agent.id,
+                name=agent.name,
+                description=agent.description or "",
+                source="library",
+                in_library=True,
+                creator=agent.creator_name,
+                status=agent.status.value,
+                can_access_graph=agent.can_access_graph,
+                has_external_trigger=agent.has_external_trigger,
+                new_output=agent.new_output,
+                graph_id=agent.graph_id,
+            )
+    except NotFoundError:
+        logger.debug(f"Library agent not found by library_id: {agent_id}")
+    except DatabaseError:
+        raise
+    except Exception as e:
+        logger.warning(
+            f"Could not fetch library agent by library_id {agent_id}: {e}",
+            exc_info=True,
+        )
+
+    return None
+
+
 async def search_agents(
     query: str,
@@ -70,28 +151,38 @@ async def search_agents(
             )
         )
     else:  # library
-        logger.info(f"Searching user library for: {query}")
-        results = await library_db.list_library_agents(
-            user_id=user_id,  # type: ignore[arg-type]
-            search_term=query,
-            page_size=10,
-        )
-        for agent in results.agents:
-            agents.append(
-                AgentInfo(
-                    id=agent.id,
-                    name=agent.name,
-                    description=agent.description or "",
-                    source="library",
-                    in_library=True,
-                    creator=agent.creator_name,
-                    status=agent.status.value,
-                    can_access_graph=agent.can_access_graph,
-                    has_external_trigger=agent.has_external_trigger,
-                    new_output=agent.new_output,
-                    graph_id=agent.graph_id,
-                )
-            )
+        # If query looks like a UUID, try direct lookup first
+        if _is_uuid(query):
+            logger.info(f"Query looks like UUID, trying direct lookup: {query}")
+            agent = await _get_library_agent_by_id(user_id, query)  # type: ignore[arg-type]
+            if agent:
+                agents.append(agent)
+                logger.info(f"Found agent by direct ID lookup: {agent.name}")
+
+        # If no results from UUID lookup, do text search
+        if not agents:
+            logger.info(f"Searching user library for: {query}")
+            results = await library_db.list_library_agents(
+                user_id=user_id,  # type: ignore[arg-type]
+                search_term=query,
+                page_size=10,
+            )
+            for agent in results.agents:
+                agents.append(
+                    AgentInfo(
+                        id=agent.id,
+                        name=agent.name,
+                        description=agent.description or "",
+                        source="library",
+                        in_library=True,
+                        creator=agent.creator_name,
+                        status=agent.status.value,
+                        can_access_graph=agent.can_access_graph,
+                        has_external_trigger=agent.has_external_trigger,
+                        new_output=agent.new_output,
+                        graph_id=agent.graph_id,
+                    )
+                )
     logger.info(f"Found {len(agents)} agents in {source}")
 except NotFoundError:
     pass
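
The UUID gate is strict about the version and variant nibbles, so ordinary search text cannot accidentally trigger a direct lookup. Illustrative inputs (the IDs are made up):

```python
_is_uuid("3f2b8c1a-9d4e-4f6a-8b2c-1d3e5f7a9b0c")  # True: version 4, variant in [89ab]
_is_uuid("3F2B8C1A-9D4E-4F6A-8B2C-1D3E5F7A9B0C")  # True: pattern is case-insensitive
_is_uuid("weather report agent")                   # False: falls through to text search
```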
@@ -36,6 +36,16 @@ class BaseTool:
         """Whether this tool requires authentication."""
         return False

+    @property
+    def is_long_running(self) -> bool:
+        """Whether this tool is long-running and should execute in background.
+
+        Long-running tools (like agent generation) are executed via background
+        tasks to survive SSE disconnections. The result is persisted to chat
+        history and visible when the user refreshes.
+        """
+        return False
+
     def as_openai_tool(self) -> ChatCompletionToolParam:
         """Convert to OpenAI tool format."""
         return ChatCompletionToolParam(
@@ -8,7 +8,10 @@ from backend.api.features.chat.model import ChatSession
 from .agent_generator import (
     AgentGeneratorNotConfiguredError,
     decompose_goal,
+    enrich_library_agents_from_steps,
     generate_agent,
+    get_all_relevant_agents_for_generation,
+    get_user_message_for_error,
     save_agent_to_library,
 )
 from .base import BaseTool
@@ -42,6 +45,10 @@ class CreateAgentTool(BaseTool):
     def requires_auth(self) -> bool:
         return True

+    @property
+    def is_long_running(self) -> bool:
+        return True
+
     @property
     def parameters(self) -> dict[str, Any]:
         return {
@@ -98,9 +105,27 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

+        # Fetch relevant library and marketplace agents for sub-agent composition
+        library_agents = None
+        if user_id:
+            try:
+                library_agents = await get_all_relevant_agents_for_generation(
+                    user_id=user_id,
+                    search_query=description,  # Use goal as search term
+                    include_marketplace=True,
+                )
+                logger.debug(
+                    f"Found {len(library_agents)} relevant agents for sub-agent composition"
+                )
+            except Exception as e:
+                # Log but don't fail - agent generation can work without sub-agents
+                logger.warning(f"Failed to fetch library agents: {e}")
+
         # Step 1: Decompose goal into steps
         try:
-            decomposition_result = await decompose_goal(description, context)
+            decomposition_result = await decompose_goal(
+                description, context, library_agents
+            )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -113,11 +138,29 @@ class CreateAgentTool(BaseTool):

         if decomposition_result is None:
             return ErrorResponse(
-                message="Failed to analyze the goal. The agent generation service may be unavailable or timed out. Please try again.",
+                message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
                 error="decomposition_failed",
+                details={"description": description[:100]},
+                session_id=session_id,
+            )
+
+        # Check if the result is an error from the external service
+        if decomposition_result.get("type") == "error":
+            error_msg = decomposition_result.get("error", "Unknown error")
+            error_type = decomposition_result.get("error_type", "unknown")
+            user_message = get_user_message_for_error(
+                error_type,
+                operation="analyze the goal",
+                llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
+            )
+            return ErrorResponse(
+                message=user_message,
+                error=f"decomposition_failed:{error_type}",
                 details={
-                    "description": description[:100]
-                },  # Include context for debugging
+                    "description": description[:100],
+                    "service_error": error_msg,
+                    "error_type": error_type,
+                },
                 session_id=session_id,
             )

@@ -167,9 +210,26 @@ class CreateAgentTool(BaseTool):
             session_id=session_id,
         )

+        # Step 1.5: Enrich library agents with step-based search (two-phase search)
+        # After decomposition, search for additional relevant agents based on the steps
+        if user_id and library_agents is not None:
+            try:
+                library_agents = await enrich_library_agents_from_steps(
+                    user_id=user_id,
+                    decomposition_result=decomposition_result,
+                    existing_agents=library_agents,
+                    include_marketplace=True,
+                )
+                logger.debug(
+                    f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
+                )
+            except Exception as e:
+                # Log but don't fail - continue with existing agents
+                logger.warning(f"Failed to enrich library agents from steps: {e}")
+
         # Step 2: Generate agent JSON (external service handles fixing and validation)
         try:
-            agent_json = await generate_agent(decomposition_result)
+            agent_json = await generate_agent(decomposition_result, library_agents)
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -182,11 +242,35 @@ class CreateAgentTool(BaseTool):

         if agent_json is None:
             return ErrorResponse(
-                message="Failed to generate the agent. The agent generation service may be unavailable or timed out. Please try again.",
+                message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
                 error="generation_failed",
+                details={"description": description[:100]},
+                session_id=session_id,
+            )
+
+        # Check if the result is an error from the external service
+        if isinstance(agent_json, dict) and agent_json.get("type") == "error":
+            error_msg = agent_json.get("error", "Unknown error")
+            error_type = agent_json.get("error_type", "unknown")
+            user_message = get_user_message_for_error(
+                error_type,
+                operation="generate the agent",
+                llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
+                validation_message=(
+                    "I wasn't able to create a valid agent for this request. "
+                    "The generated workflow had some structural issues. "
+                    "Please try simplifying your goal or breaking it into smaller steps."
+                ),
+                error_details=error_msg if error_type == "validation_error" else None,
+            )
+            return ErrorResponse(
+                message=user_message,
+                error=f"generation_failed:{error_type}",
                 details={
-                    "description": description[:100]
-                },  # Include context for debugging
+                    "description": description[:100],
+                    "service_error": error_msg,
+                    "error_type": error_type,
+                },
                 session_id=session_id,
             )

@@ -228,7 +312,7 @@ class CreateAgentTool(BaseTool):
             agent_id=created_graph.id,
             agent_name=created_graph.name,
             library_agent_id=library_agent.id,
-            library_agent_link=f"/library/{library_agent.id}",
+            library_agent_link=f"/library/agents/{library_agent.id}",
             agent_page_link=f"/build?flowID={created_graph.id}",
             session_id=session_id,
         )
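
The two-phase lookup above reduces to the following flow; the function names and keyword arguments come from the diff, while the goal string is illustrative:

```python
# Phase 1: seed candidate sub-agents from the raw goal text.
library_agents = await get_all_relevant_agents_for_generation(
    user_id=user_id,
    search_query="summarize my inbox every morning",  # illustrative goal
    include_marketplace=True,
)

# Let the decomposition service see the candidates.
decomposition_result = await decompose_goal(description, context, library_agents)

# Phase 2: widen the candidate set using the concrete steps produced above.
library_agents = await enrich_library_agents_from_steps(
    user_id=user_id,
    decomposition_result=decomposition_result,
    existing_agents=library_agents,
    include_marketplace=True,
)
```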
@@ -9,6 +9,8 @@ from .agent_generator import (
     AgentGeneratorNotConfiguredError,
     generate_agent_patch,
     get_agent_as_json,
+    get_all_relevant_agents_for_generation,
+    get_user_message_for_error,
     save_agent_to_library,
 )
 from .base import BaseTool
@@ -42,6 +44,10 @@ class EditAgentTool(BaseTool):
     def requires_auth(self) -> bool:
         return True

+    @property
+    def is_long_running(self) -> bool:
+        return True
+
     @property
     def parameters(self) -> dict[str, Any]:
         return {
@@ -122,6 +128,22 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

+        library_agents = None
+        if user_id:
+            try:
+                exclude_id = current_agent.get("id") or agent_id
+                library_agents = await get_all_relevant_agents_for_generation(
+                    user_id=user_id,
+                    search_query=changes,
+                    exclude_graph_id=exclude_id,
+                    include_marketplace=True,
+                )
+                logger.debug(
+                    f"Found {len(library_agents)} relevant agents for sub-agent composition"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to fetch library agents: {e}")
+
         # Build the update request with context
         update_request = changes
         if context:
@@ -129,7 +151,9 @@ class EditAgentTool(BaseTool):

         # Step 2: Generate updated agent (external service handles fixing and validation)
         try:
-            result = await generate_agent_patch(update_request, current_agent)
+            result = await generate_agent_patch(
+                update_request, current_agent, library_agents
+            )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -148,6 +172,28 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

+        # Check if the result is an error from the external service
+        if isinstance(result, dict) and result.get("type") == "error":
+            error_msg = result.get("error", "Unknown error")
+            error_type = result.get("error_type", "unknown")
+            user_message = get_user_message_for_error(
+                error_type,
+                operation="generate the changes",
+                llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
+                validation_message="The generated changes failed validation. Please try rephrasing your request.",
+            )
+            return ErrorResponse(
+                message=user_message,
+                error=f"update_generation_failed:{error_type}",
+                details={
+                    "agent_id": agent_id,
+                    "changes": changes[:100],
+                    "service_error": error_msg,
+                    "error_type": error_type,
+                },
+                session_id=session_id,
+            )
+
         # Check if LLM returned clarifying questions
         if result.get("type") == "clarifying_questions":
             questions = result.get("questions", [])
@@ -209,7 +255,7 @@ class EditAgentTool(BaseTool):
             agent_id=created_graph.id,
             agent_name=created_graph.name,
             library_agent_id=library_agent.id,
-            library_agent_link=f"/library/{library_agent.id}",
+            library_agent_link=f"/library/agents/{library_agent.id}",
             agent_page_link=f"/build?flowID={created_graph.id}",
             session_id=session_id,
         )
@@ -28,6 +28,16 @@ class ResponseType(str, Enum):
     BLOCK_OUTPUT = "block_output"
     DOC_SEARCH_RESULTS = "doc_search_results"
     DOC_PAGE = "doc_page"
+    # Workspace response types
+    WORKSPACE_FILE_LIST = "workspace_file_list"
+    WORKSPACE_FILE_CONTENT = "workspace_file_content"
+    WORKSPACE_FILE_METADATA = "workspace_file_metadata"
+    WORKSPACE_FILE_WRITTEN = "workspace_file_written"
+    WORKSPACE_FILE_DELETED = "workspace_file_deleted"
+    # Long-running operation types
+    OPERATION_STARTED = "operation_started"
+    OPERATION_PENDING = "operation_pending"
+    OPERATION_IN_PROGRESS = "operation_in_progress"


 # Base response model
@@ -334,3 +344,39 @@ class BlockOutputResponse(ToolResponseBase):
     block_name: str
     outputs: dict[str, list[Any]]
     success: bool = True
+
+
+# Long-running operation models
+class OperationStartedResponse(ToolResponseBase):
+    """Response when a long-running operation has been started in the background.
+
+    This is returned immediately to the client while the operation continues
+    to execute. The user can close the tab and check back later.
+    """
+
+    type: ResponseType = ResponseType.OPERATION_STARTED
+    operation_id: str
+    tool_name: str
+
+
+class OperationPendingResponse(ToolResponseBase):
+    """Response stored in chat history while a long-running operation is executing.
+
+    This is persisted to the database so users see a pending state when they
+    refresh before the operation completes.
+    """
+
+    type: ResponseType = ResponseType.OPERATION_PENDING
+    operation_id: str
+    tool_name: str
+
+
+class OperationInProgressResponse(ToolResponseBase):
+    """Response when an operation is already in progress.
+
+    Returned for idempotency when the same tool_call_id is requested again
+    while the background task is still running.
+    """
+
+    type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
+    tool_call_id: str
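
A sketch of how these three models map onto the lifecycle their docstrings describe; the IDs are placeholders, and fields inherited from `ToolResponseBase` other than `session_id` are omitted here:

```python
# 1. A long-running tool call is accepted: answer immediately, keep working.
started = OperationStartedResponse(
    operation_id="op-123", tool_name="create_agent", session_id="sess-1"
)

# 2. A pending marker is persisted so a page refresh shows work in flight.
pending = OperationPendingResponse(
    operation_id="op-123", tool_name="create_agent", session_id="sess-1"
)

# 3. Replaying the same tool_call_id while the task still runs yields an
#    in-progress marker instead of a duplicate execution.
dup = OperationInProgressResponse(tool_call_id="call-abc", session_id="sess-1")
```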
@@ -1,6 +1,7 @@
 """Tool for executing blocks directly."""

 import logging
+import uuid
 from collections import defaultdict
 from typing import Any

@@ -8,6 +9,7 @@ from backend.api.features.chat.model import ChatSession
 from backend.data.block import get_block
 from backend.data.execution import ExecutionContext
 from backend.data.model import CredentialsMetaInput
+from backend.data.workspace import get_or_create_workspace
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import BlockError

@@ -223,11 +225,48 @@ class RunBlockTool(BaseTool):
         )

         try:
-            # Fetch actual credentials and prepare kwargs for block execution
-            # Create execution context with defaults (blocks may require it)
+            # Get or create user's workspace for CoPilot file operations
+            workspace = await get_or_create_workspace(user_id)
+
+            # Generate synthetic IDs for CoPilot context
+            # Each chat session is treated as its own agent with one continuous run
+            # This means:
+            # - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
+            # - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
+            # - node_exec_id = unique per block execution
+            synthetic_graph_id = f"copilot-session-{session.session_id}"
+            synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
+            synthetic_node_id = f"copilot-node-{block_id}"
+            synthetic_node_exec_id = (
+                f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
+            )
+
+            # Create unified execution context with all required fields
+            execution_context = ExecutionContext(
+                # Execution identity
+                user_id=user_id,
+                graph_id=synthetic_graph_id,
+                graph_exec_id=synthetic_graph_exec_id,
+                graph_version=1,  # Versions are 1-indexed
+                node_id=synthetic_node_id,
+                node_exec_id=synthetic_node_exec_id,
+                # Workspace with session scoping
+                workspace_id=workspace.id,
+                session_id=session.session_id,
+            )
+
+            # Prepare kwargs for block execution
+            # Keep individual kwargs for backwards compatibility with existing blocks
             exec_kwargs: dict[str, Any] = {
                 "user_id": user_id,
-                "execution_context": ExecutionContext(),
+                "execution_context": execution_context,
+                # Legacy: individual kwargs for blocks not yet using execution_context
+                "workspace_id": workspace.id,
+                "graph_exec_id": synthetic_graph_exec_id,
+                "node_exec_id": synthetic_node_exec_id,
+                "node_id": synthetic_node_id,
+                "graph_version": 1,  # Versions are 1-indexed
+                "graph_id": synthetic_graph_id,
             }

             for field_name, cred_meta in matched_credentials.items():
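
To make the synthetic-ID scheme concrete, here is what the values work out to for a hypothetical session and block (all values below are made up):

```python
# session.session_id == "9fd3", block_id == "fetch_rss"
# synthetic_graph_id       == "copilot-session-9fd3"   # the session acts as one agent
# synthetic_graph_exec_id  == "copilot-session-9fd3"   # ...with one continuous run
# synthetic_node_id        == "copilot-node-fetch_rss"
# synthetic_node_exec_id   == "copilot-9fd3-1a2b3c4d"  # unique per block execution
```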
@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
 from backend.api.features.store import db as store_db
 from backend.data import graph as graph_db
 from backend.data.graph import GraphModel
-from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
+from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import NotFoundError

@@ -266,13 +266,14 @@ async def match_user_credentials_to_graph(
         credential_requirements,
         _node_fields,
     ) in aggregated_creds.items():
-        # Find first matching credential by provider and type
+        # Find first matching credential by provider, type, and scopes
         matching_cred = next(
             (
                 cred
                 for cred in available_creds
                 if cred.provider in credential_requirements.provider
                 and cred.type in credential_requirements.supported_types
+                and _credential_has_required_scopes(cred, credential_requirements)
             ),
             None,
         )
@@ -296,10 +297,17 @@ async def match_user_credentials_to_graph(
                     f"{credential_field_name} (validation failed: {e})"
                 )
         else:
+            # Build a helpful error message including scope requirements
+            error_parts = [
+                f"provider in {list(credential_requirements.provider)}",
+                f"type in {list(credential_requirements.supported_types)}",
+            ]
+            if credential_requirements.required_scopes:
+                error_parts.append(
+                    f"scopes including {list(credential_requirements.required_scopes)}"
+                )
             missing_creds.append(
-                f"{credential_field_name} "
-                f"(requires provider in {list(credential_requirements.provider)}, "
-                f"type in {list(credential_requirements.supported_types)})"
+                f"{credential_field_name} (requires {', '.join(error_parts)})"
             )

     logger.info(
@@ -309,6 +317,28 @@ async def match_user_credentials_to_graph(
     return graph_credentials_inputs, missing_creds


+def _credential_has_required_scopes(
+    credential: Credentials,
+    requirements: CredentialsFieldInfo,
+) -> bool:
+    """
+    Check if a credential has all the scopes required by the block.
+
+    For OAuth2 credentials, verifies that the credential's scopes are a superset
+    of the required scopes. For other credential types, returns True (no scope check).
+    """
+    # Only OAuth2 credentials have scopes to check
+    if credential.type != "oauth2":
+        return True
+
+    # If no scopes are required, any credential matches
+    if not requirements.required_scopes:
+        return True
+
+    # Check that credential scopes are a superset of required scopes
+    return set(credential.scopes).issuperset(requirements.required_scopes)
+
+
 async def check_user_has_required_credentials(
     user_id: str,
     required_credentials: list[CredentialsMetaInput],
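
The scope check is plain set containment; a small sketch of the semantics with illustrative scope strings:

```python
granted = {"repo", "read:org", "workflow"}  # scopes held by the stored credential
required = {"repo", "read:org"}             # scopes the block declares

granted.issuperset(required)  # True: extra granted scopes are fine
missing = required - granted  # empty set whenever the check passes
```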
@@ -0,0 +1,620 @@
|
|||||||
|
"""CoPilot tools for workspace file operations."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.data.workspace import get_or_create_workspace
|
||||||
|
from backend.util.settings import Config
|
||||||
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
from backend.util.workspace import WorkspaceManager
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileInfoData(BaseModel):
|
||||||
|
"""Data model for workspace file information (not a response itself)."""
|
||||||
|
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileListResponse(ToolResponseBase):
|
||||||
|
"""Response containing list of workspace files."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
||||||
|
files: list[WorkspaceFileInfoData]
|
||||||
|
total_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileContentResponse(ToolResponseBase):
|
||||||
|
"""Response containing workspace file content (legacy, for small text files)."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
content_base64: str
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||||
|
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
size_bytes: int
|
||||||
|
download_url: str
|
||||||
|
preview: str | None = None # First 500 chars for text files
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceWriteResponse(ToolResponseBase):
|
||||||
|
"""Response after writing a file to workspace."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||||
|
"""Response after deleting a file from workspace."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
||||||
|
file_id: str
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
class ListWorkspaceFilesTool(BaseTool):
|
||||||
|
"""Tool for listing files in user's workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "list_workspace_files"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"List files in the user's workspace. "
|
||||||
|
"Returns file names, paths, sizes, and metadata. "
|
||||||
|
"Optionally filter by path prefix."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path_prefix": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional path prefix to filter files "
|
||||||
|
"(e.g., '/documents/' to list only files in documents folder). "
|
||||||
|
"By default, only files from the current session are listed."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of files to return (default 50, max 100)",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 100,
|
||||||
|
},
|
||||||
|
"include_all_sessions": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"If true, list files from all sessions. "
|
||||||
|
"Default is false (only current session's files)."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||||
|
limit = min(kwargs.get("limit", 50), 100)
|
||||||
|
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
files = await manager.list_files(
|
||||||
|
path=path_prefix,
|
||||||
|
limit=limit,
|
||||||
|
include_all_sessions=include_all_sessions,
|
||||||
|
)
|
||||||
|
total = await manager.get_file_count(
|
||||||
|
path=path_prefix,
|
||||||
|
include_all_sessions=include_all_sessions,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_infos = [
|
||||||
|
WorkspaceFileInfoData(
|
||||||
|
file_id=f.id,
|
||||||
|
name=f.name,
|
||||||
|
path=f.path,
|
||||||
|
mime_type=f.mimeType,
|
||||||
|
size_bytes=f.sizeBytes,
|
||||||
|
)
|
||||||
|
for f in files
|
||||||
|
]
|
||||||
|
|
||||||
|
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||||
|
return WorkspaceFileListResponse(
|
||||||
|
files=file_infos,
|
||||||
|
total_count=total,
|
||||||
|
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to list workspace files: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadWorkspaceFileTool(BaseTool):
    """Tool for reading file content from workspace."""

    # Size threshold for returning full content vs metadata+URL
    # Files larger than this return metadata with download URL to prevent context bloat
    MAX_INLINE_SIZE_BYTES = 32 * 1024  # 32KB
    # Preview size for text files
    PREVIEW_SIZE = 500

    @property
    def name(self) -> str:
        return "read_workspace_file"

    @property
    def description(self) -> str:
        return (
            "Read a file from the user's workspace. "
            "Specify either file_id or path to identify the file. "
            "For small text files, returns content directly. "
            "For large or binary files, returns metadata and a download URL. "
            "Paths are scoped to the current session by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "file_id": {
                    "type": "string",
                    "description": "The file's unique ID (from list_workspace_files)",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "The virtual file path (e.g., '/documents/report.pdf'). "
                        "Scoped to current session by default."
                    ),
                },
                "force_download_url": {
                    "type": "boolean",
                    "description": (
                        "If true, always return metadata+URL instead of inline content. "
                        "Default is false (auto-selects based on file size/type)."
                    ),
                },
            },
            "required": [],  # At least one must be provided
        }

    @property
    def requires_auth(self) -> bool:
        return True

    def _is_text_mime_type(self, mime_type: str) -> bool:
        """Check if the MIME type is a text-based type."""
        text_types = [
            "text/",
            "application/json",
            "application/xml",
            "application/javascript",
            "application/x-python",
            "application/x-sh",
        ]
        return any(mime_type.startswith(t) for t in text_types)

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        file_id: Optional[str] = kwargs.get("file_id")
        path: Optional[str] = kwargs.get("path")
        force_download_url: bool = kwargs.get("force_download_url", False)

        if not file_id and not path:
            return ErrorResponse(
                message="Please provide either file_id or path",
                session_id=session_id,
            )

        try:
            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            # Get file info
            if file_id:
                file_info = await manager.get_file_info(file_id)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found: {file_id}",
                        session_id=session_id,
                    )
                target_file_id = file_id
            else:
                # path is guaranteed to be non-None here due to the check above
                assert path is not None
                file_info = await manager.get_file_info_by_path(path)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found at path: {path}",
                        session_id=session_id,
                    )
                target_file_id = file_info.id

            # Decide whether to return inline content or metadata+URL
            is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
            is_text_file = self._is_text_mime_type(file_info.mimeType)

            # Return inline content for small text files (unless force_download_url)
            if is_small_file and is_text_file and not force_download_url:
                content = await manager.read_file_by_id(target_file_id)
                content_b64 = base64.b64encode(content).decode("utf-8")

                return WorkspaceFileContentResponse(
                    file_id=file_info.id,
                    name=file_info.name,
                    path=file_info.path,
                    mime_type=file_info.mimeType,
                    content_base64=content_b64,
                    message=f"Successfully read file: {file_info.name}",
                    session_id=session_id,
                )

            # Return metadata + workspace:// reference for large or binary files
            # This prevents context bloat (100KB file = ~133KB as base64)
            # Use workspace:// format so frontend urlTransform can add proxy prefix
            download_url = f"workspace://{target_file_id}"

            # Generate preview for text files
            preview: str | None = None
            if is_text_file:
                try:
                    content = await manager.read_file_by_id(target_file_id)
                    preview_text = content[: self.PREVIEW_SIZE].decode(
                        "utf-8", errors="replace"
                    )
                    if len(content) > self.PREVIEW_SIZE:
                        preview_text += "..."
                    preview = preview_text
                except Exception:
                    pass  # Preview is optional

            return WorkspaceFileMetadataResponse(
                file_id=file_info.id,
                name=file_info.name,
                path=file_info.path,
                mime_type=file_info.mimeType,
                size_bytes=file_info.sizeBytes,
                download_url=download_url,
                preview=preview,
                message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
                session_id=session_id,
            )

        except FileNotFoundError as e:
            return ErrorResponse(
                message=str(e),
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error reading workspace file: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to read workspace file: {str(e)}",
                error=str(e),
                session_id=session_id,
            )
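The 32KB inline threshold and the "100KB file = ~133KB as base64" comment both follow from base64's 4/3 size overhead. A minimal standalone sketch of that arithmetic (not part of the tool itself):

```python
import base64

# base64 encodes every 3 input bytes as 4 output characters, so the
# encoded payload is roughly 33% larger than the raw file.
raw = b"x" * 100_000  # a 100KB file
encoded = base64.b64encode(raw)
print(len(encoded))  # 133336 -> ~133KB once inlined into the model context

# This is why files above MAX_INLINE_SIZE_BYTES (32KB) are returned as
# metadata plus a workspace:// download URL instead of inline content.
```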
class WriteWorkspaceFileTool(BaseTool):
    """Tool for writing files to workspace."""

    @property
    def name(self) -> str:
        return "write_workspace_file"

    @property
    def description(self) -> str:
        return (
            "Write or create a file in the user's workspace. "
            "Provide the content as a base64-encoded string. "
            f"Maximum file size is {Config().max_file_size_mb}MB. "
            "Files are saved to the current session's folder by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "filename": {
                    "type": "string",
                    "description": "Name for the file (e.g., 'report.pdf')",
                },
                "content_base64": {
                    "type": "string",
                    "description": "Base64-encoded file content",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "Optional virtual path where to save the file "
                        "(e.g., '/documents/report.pdf'). "
                        "Defaults to '/{filename}'. Scoped to current session."
                    ),
                },
                "mime_type": {
                    "type": "string",
                    "description": (
                        "Optional MIME type of the file. "
                        "Auto-detected from filename if not provided."
                    ),
                },
                "overwrite": {
                    "type": "boolean",
                    "description": "Whether to overwrite if file exists at path (default: false)",
                },
            },
            "required": ["filename", "content_base64"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        filename: str = kwargs.get("filename", "")
        content_b64: str = kwargs.get("content_base64", "")
        path: Optional[str] = kwargs.get("path")
        mime_type: Optional[str] = kwargs.get("mime_type")
        overwrite: bool = kwargs.get("overwrite", False)

        if not filename:
            return ErrorResponse(
                message="Please provide a filename",
                session_id=session_id,
            )

        if not content_b64:
            return ErrorResponse(
                message="Please provide content_base64",
                session_id=session_id,
            )

        # Decode content
        try:
            content = base64.b64decode(content_b64)
        except Exception:
            return ErrorResponse(
                message="Invalid base64-encoded content",
                session_id=session_id,
            )

        # Check size
        max_file_size = Config().max_file_size_mb * 1024 * 1024
        if len(content) > max_file_size:
            return ErrorResponse(
                message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
                session_id=session_id,
            )

        try:
            # Virus scan
            await scan_content_safe(content, filename=filename)

            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            file_record = await manager.write_file(
                content=content,
                filename=filename,
                path=path,
                mime_type=mime_type,
                overwrite=overwrite,
            )

            return WorkspaceWriteResponse(
                file_id=file_record.id,
                name=file_record.name,
                path=file_record.path,
                size_bytes=file_record.sizeBytes,
                message=f"Successfully wrote file: {file_record.name}",
                session_id=session_id,
            )

        except ValueError as e:
            return ErrorResponse(
                message=str(e),
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error writing workspace file: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to write workspace file: {str(e)}",
                error=str(e),
                session_id=session_id,
            )
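A hedged sketch of how a caller might drive these two tools together. The tool names and argument names come from the parameter schemas above; the `execute_tool` dispatcher is purely hypothetical, since the real invocation path depends on the chat runtime:

```python
import base64


async def example(execute_tool):
    # Hypothetical dispatcher: execute_tool(name, **kwargs) -> response model.
    report = b"quarterly revenue: 42"
    await execute_tool(
        "write_workspace_file",
        filename="report.txt",
        content_base64=base64.b64encode(report).decode("utf-8"),
        path="/documents/report.txt",  # optional; defaults to /report.txt
        overwrite=True,
    )
    # A small text file round-trips inline through read_workspace_file,
    # assuming the response models above.
    read_back = await execute_tool("read_workspace_file", path="/documents/report.txt")
    assert base64.b64decode(read_back.content_base64) == report
```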
class DeleteWorkspaceFileTool(BaseTool):
    """Tool for deleting files from workspace."""

    @property
    def name(self) -> str:
        return "delete_workspace_file"

    @property
    def description(self) -> str:
        return (
            "Delete a file from the user's workspace. "
            "Specify either file_id or path to identify the file. "
            "Paths are scoped to the current session by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "file_id": {
                    "type": "string",
                    "description": "The file's unique ID (from list_workspace_files)",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "The virtual file path (e.g., '/documents/report.pdf'). "
                        "Scoped to current session by default."
                    ),
                },
            },
            "required": [],  # At least one must be provided
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        file_id: Optional[str] = kwargs.get("file_id")
        path: Optional[str] = kwargs.get("path")

        if not file_id and not path:
            return ErrorResponse(
                message="Please provide either file_id or path",
                session_id=session_id,
            )

        try:
            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            # Determine the file_id to delete
            target_file_id: str
            if file_id:
                target_file_id = file_id
            else:
                # path is guaranteed to be non-None here due to the check above
                assert path is not None
                file_info = await manager.get_file_info_by_path(path)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found at path: {path}",
                        session_id=session_id,
                    )
                target_file_id = file_info.id

            success = await manager.delete_file(target_file_id)

            if not success:
                return ErrorResponse(
                    message=f"File not found: {target_file_id}",
                    session_id=session_id,
                )

            return WorkspaceDeleteResponse(
                file_id=target_file_id,
                success=True,
                message="File deleted successfully",
                session_id=session_id,
            )

        except Exception as e:
            logger.error(f"Error deleting workspace file: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to delete workspace file: {str(e)}",
                error=str(e),
                session_id=session_id,
            )
@@ -21,7 +21,7 @@ from backend.data.model import CredentialsMetaInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
 from backend.util.clients import get_scheduler_client
-from backend.util.exceptions import DatabaseError, NotFoundError
+from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
 from backend.util.json import SafeJson
 from backend.util.models import Pagination
 from backend.util.settings import Config
@@ -64,11 +64,11 @@ async def list_library_agents(
 
     if page < 1 or page_size < 1:
         logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
-        raise DatabaseError("Invalid pagination input")
+        raise InvalidInputError("Invalid pagination input")
 
     if search_term and len(search_term.strip()) > 100:
         logger.warning(f"Search term too long: {repr(search_term)}")
-        raise DatabaseError("Search term is too long")
+        raise InvalidInputError("Search term is too long")
 
     where_clause: prisma.types.LibraryAgentWhereInput = {
         "userId": user_id,
@@ -77,21 +77,32 @@ async def list_library_agents(
     }
 
     # Build search filter if applicable
+    # Split into words and match ANY word in name or description
     if search_term:
-        where_clause["OR"] = [
-            {
-                "AgentGraph": {
-                    "is": {"name": {"contains": search_term, "mode": "insensitive"}}
-                }
-            },
-            {
-                "AgentGraph": {
-                    "is": {
-                        "description": {"contains": search_term, "mode": "insensitive"}
-                    }
-                }
-            },
-        ]
+        words = [w.strip() for w in search_term.split() if len(w.strip()) >= 3]
+        if words:
+            or_conditions: list[prisma.types.LibraryAgentWhereInput] = []
+            for word in words:
+                or_conditions.append(
+                    {
+                        "AgentGraph": {
+                            "is": {"name": {"contains": word, "mode": "insensitive"}}
+                        }
+                    }
+                )
+                or_conditions.append(
+                    {
+                        "AgentGraph": {
+                            "is": {
+                                "description": {
+                                    "contains": word,
+                                    "mode": "insensitive",
+                                }
+                            }
+                        }
+                    }
+                )
+            where_clause["OR"] = or_conditions
 
     # Determine sorting
     order_by: prisma.types.LibraryAgentOrderByInput | None = None
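To see what the new word-splitting does, a short standalone sketch of the same expansion: every word of three or more characters contributes two case-insensitive `contains` conditions, one on the graph name and one on its description, so the query matches ANY word rather than the whole phrase:

```python
def expand_search(search_term: str) -> list[dict]:
    # Mirrors the filter-building logic above, outside of Prisma types.
    words = [w.strip() for w in search_term.split() if len(w.strip()) >= 3]
    conditions: list[dict] = []
    for word in words:
        for field in ("name", "description"):
            conditions.append(
                {"AgentGraph": {"is": {field: {"contains": word, "mode": "insensitive"}}}}
            )
    return conditions


# "email ai bot" -> "ai" is dropped (under 3 chars); "email" and "bot"
# each yield a name and a description condition: 4 OR branches total.
assert len(expand_search("email ai bot")) == 4
```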
@@ -175,7 +186,7 @@ async def list_favorite_library_agents(
 
     if page < 1 or page_size < 1:
         logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
-        raise DatabaseError("Invalid pagination input")
+        raise InvalidInputError("Invalid pagination input")
 
     where_clause: prisma.types.LibraryAgentWhereInput = {
         "userId": user_id,
@@ -1,4 +1,3 @@
-import logging
 from typing import Literal, Optional
 
 import autogpt_libs.auth as autogpt_auth_lib
@@ -6,15 +5,11 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
 from fastapi.responses import Response
 from prisma.enums import OnboardingStep
 
-import backend.api.features.store.exceptions as store_exceptions
 from backend.data.onboarding import complete_onboarding_step
-from backend.util.exceptions import DatabaseError, NotFoundError
 
 from .. import db as library_db
 from .. import model as library_model
 
-logger = logging.getLogger(__name__)
-
 router = APIRouter(
     prefix="/agents",
     tags=["library", "private"],
@@ -26,10 +21,6 @@ router = APIRouter(
     "",
     summary="List Library Agents",
     response_model=library_model.LibraryAgentResponse,
-    responses={
-        200: {"description": "List of library agents"},
-        500: {"description": "Server error", "content": {"application/json": {}}},
-    },
 )
 async def list_library_agents(
     user_id: str = Security(autogpt_auth_lib.get_user_id),
@@ -53,43 +44,19 @@ async def list_library_agents(
 ) -> library_model.LibraryAgentResponse:
     """
     Get all agents in the user's library (both created and saved).
-
-    Args:
-        user_id: ID of the authenticated user.
-        search_term: Optional search term to filter agents by name/description.
-        filter_by: List of filters to apply (favorites, created by user).
-        sort_by: List of sorting criteria (created date, updated date).
-        page: Page number to retrieve.
-        page_size: Number of agents per page.
-
-    Returns:
-        A LibraryAgentResponse containing agents and pagination metadata.
-
-    Raises:
-        HTTPException: If a server/database error occurs.
     """
-    try:
-        return await library_db.list_library_agents(
-            user_id=user_id,
-            search_term=search_term,
-            sort_by=sort_by,
-            page=page,
-            page_size=page_size,
-        )
-    except Exception as e:
-        logger.error(f"Could not list library agents for user #{user_id}: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=str(e),
-        ) from e
+    return await library_db.list_library_agents(
+        user_id=user_id,
+        search_term=search_term,
+        sort_by=sort_by,
+        page=page,
+        page_size=page_size,
+    )
 
 
 @router.get(
     "/favorites",
     summary="List Favorite Library Agents",
-    responses={
-        500: {"description": "Server error", "content": {"application/json": {}}},
-    },
 )
 async def list_favorite_library_agents(
     user_id: str = Security(autogpt_auth_lib.get_user_id),
@@ -106,30 +73,12 @@ async def list_favorite_library_agents(
 ) -> library_model.LibraryAgentResponse:
     """
     Get all favorite agents in the user's library.
-
-    Args:
-        user_id: ID of the authenticated user.
-        page: Page number to retrieve.
-        page_size: Number of agents per page.
-
-    Returns:
-        A LibraryAgentResponse containing favorite agents and pagination metadata.
-
-    Raises:
-        HTTPException: If a server/database error occurs.
     """
-    try:
-        return await library_db.list_favorite_library_agents(
-            user_id=user_id,
-            page=page,
-            page_size=page_size,
-        )
-    except Exception as e:
-        logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=str(e),
-        ) from e
+    return await library_db.list_favorite_library_agents(
+        user_id=user_id,
+        page=page,
+        page_size=page_size,
+    )
 
 
 @router.get("/{library_agent_id}", summary="Get Library Agent")
@@ -162,10 +111,6 @@ async def get_library_agent_by_graph_id(
     summary="Get Agent By Store ID",
     tags=["store", "library"],
     response_model=library_model.LibraryAgent | None,
-    responses={
-        200: {"description": "Library agent found"},
-        404: {"description": "Agent not found"},
-    },
 )
 async def get_library_agent_by_store_listing_version_id(
     store_listing_version_id: str,
@@ -174,32 +119,15 @@ async def get_library_agent_by_store_listing_version_id(
     """
     Get Library Agent from Store Listing Version ID.
     """
-    try:
-        return await library_db.get_library_agent_by_store_version_id(
-            store_listing_version_id, user_id
-        )
-    except NotFoundError as e:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail=str(e),
-        )
-    except Exception as e:
-        logger.error(f"Could not fetch library agent from store version ID: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=str(e),
-        ) from e
+    return await library_db.get_library_agent_by_store_version_id(
+        store_listing_version_id, user_id
+    )
 
 
 @router.post(
     "",
     summary="Add Marketplace Agent",
     status_code=status.HTTP_201_CREATED,
-    responses={
-        201: {"description": "Agent added successfully"},
-        404: {"description": "Store listing version not found"},
-        500: {"description": "Server error"},
-    },
 )
 async def add_marketplace_agent_to_library(
     store_listing_version_id: str = Body(embed=True),
@@ -210,59 +138,19 @@ async def add_marketplace_agent_to_library(
 ) -> library_model.LibraryAgent:
     """
     Add an agent from the marketplace to the user's library.
-
-    Args:
-        store_listing_version_id: ID of the store listing version to add.
-        user_id: ID of the authenticated user.
-
-    Returns:
-        library_model.LibraryAgent: Agent added to the library
-
-    Raises:
-        HTTPException(404): If the listing version is not found.
-        HTTPException(500): If a server/database error occurs.
     """
-    try:
-        agent = await library_db.add_store_agent_to_library(
-            store_listing_version_id=store_listing_version_id,
-            user_id=user_id,
-        )
-        if source != "onboarding":
-            await complete_onboarding_step(
-                user_id, OnboardingStep.MARKETPLACE_ADD_AGENT
-            )
-        return agent
-
-    except store_exceptions.AgentNotFoundError as e:
-        logger.warning(
-            f"Could not find store listing version {store_listing_version_id} "
-            "to add to library"
-        )
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
-    except DatabaseError as e:
-        logger.error(f"Database error while adding agent to library: {e}", e)
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail={"message": str(e), "hint": "Inspect DB logs for details."},
-        ) from e
-    except Exception as e:
-        logger.error(f"Unexpected error while adding agent to library: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail={
-                "message": str(e),
-                "hint": "Check server logs for more information.",
-            },
-        ) from e
+    agent = await library_db.add_store_agent_to_library(
+        store_listing_version_id=store_listing_version_id,
+        user_id=user_id,
+    )
+    if source != "onboarding":
+        await complete_onboarding_step(user_id, OnboardingStep.MARKETPLACE_ADD_AGENT)
+    return agent
 
 
 @router.patch(
     "/{library_agent_id}",
     summary="Update Library Agent",
-    responses={
-        200: {"description": "Agent updated successfully"},
-        500: {"description": "Server error"},
-    },
 )
 async def update_library_agent(
     library_agent_id: str,
@@ -271,52 +159,21 @@ async def update_library_agent(
 ) -> library_model.LibraryAgent:
     """
     Update the library agent with the given fields.
-
-    Args:
-        library_agent_id: ID of the library agent to update.
-        payload: Fields to update (auto_update_version, is_favorite, etc.).
-        user_id: ID of the authenticated user.
-
-    Raises:
-        HTTPException(500): If a server/database error occurs.
     """
-    try:
-        return await library_db.update_library_agent(
-            library_agent_id=library_agent_id,
-            user_id=user_id,
-            auto_update_version=payload.auto_update_version,
-            graph_version=payload.graph_version,
-            is_favorite=payload.is_favorite,
-            is_archived=payload.is_archived,
-            settings=payload.settings,
-        )
-    except NotFoundError as e:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail=str(e),
-        ) from e
-    except DatabaseError as e:
-        logger.error(f"Database error while updating library agent: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail={"message": str(e), "hint": "Verify DB connection."},
-        ) from e
-    except Exception as e:
-        logger.error(f"Unexpected error while updating library agent: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail={"message": str(e), "hint": "Check server logs."},
-        ) from e
+    return await library_db.update_library_agent(
+        library_agent_id=library_agent_id,
+        user_id=user_id,
+        auto_update_version=payload.auto_update_version,
+        graph_version=payload.graph_version,
+        is_favorite=payload.is_favorite,
+        is_archived=payload.is_archived,
+        settings=payload.settings,
+    )
 
 
 @router.delete(
     "/{library_agent_id}",
     summary="Delete Library Agent",
-    responses={
-        204: {"description": "Agent deleted successfully"},
-        404: {"description": "Agent not found"},
-        500: {"description": "Server error"},
-    },
 )
 async def delete_library_agent(
     library_agent_id: str,
@@ -324,28 +181,11 @@ async def delete_library_agent(
 ) -> Response:
     """
     Soft-delete the specified library agent.
-
-    Args:
-        library_agent_id: ID of the library agent to delete.
-        user_id: ID of the authenticated user.
-
-    Returns:
-        204 No Content if successful.
-
-    Raises:
-        HTTPException(404): If the agent does not exist.
-        HTTPException(500): If a server/database error occurs.
     """
-    try:
-        await library_db.delete_library_agent(
-            library_agent_id=library_agent_id, user_id=user_id
-        )
-        return Response(status_code=status.HTTP_204_NO_CONTENT)
-    except NotFoundError as e:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail=str(e),
-        ) from e
+    await library_db.delete_library_agent(
+        library_agent_id=library_agent_id, user_id=user_id
+    )
+    return Response(status_code=status.HTTP_204_NO_CONTENT)
 
 
 @router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
@@ -118,21 +118,6 @@ async def test_get_library_agents_success(
     )
 
 
-def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
-    mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
-    mock_db_call.side_effect = Exception("Test error")
-
-    response = client.get("/agents?search_term=test")
-    assert response.status_code == 500
-    mock_db_call.assert_called_once_with(
-        user_id=test_user_id,
-        search_term="test",
-        sort_by=library_model.LibraryAgentSort.UPDATED_AT,
-        page=1,
-        page_size=15,
-    )
-
-
 @pytest.mark.asyncio
 async def test_get_favorite_library_agents_success(
     mocker: pytest_mock.MockFixture,
@@ -190,23 +175,6 @@ async def test_get_favorite_library_agents_success(
     )
 
 
-def test_get_favorite_library_agents_error(
-    mocker: pytest_mock.MockFixture, test_user_id: str
-):
-    mock_db_call = mocker.patch(
-        "backend.api.features.library.db.list_favorite_library_agents"
-    )
-    mock_db_call.side_effect = Exception("Test error")
-
-    response = client.get("/agents/favorites")
-    assert response.status_code == 500
-    mock_db_call.assert_called_once_with(
-        user_id=test_user_id,
-        page=1,
-        page_size=15,
-    )
-
-
 def test_add_agent_to_library_success(
     mocker: pytest_mock.MockFixture, test_user_id: str
 ):
@@ -258,19 +226,3 @@ def test_add_agent_to_library_success(
         store_listing_version_id="test-version-id", user_id=test_user_id
     )
     mock_complete_onboarding.assert_awaited_once()
-
-
-def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
-    mock_db_call = mocker.patch(
-        "backend.api.features.library.db.add_store_agent_to_library"
-    )
-    mock_db_call.side_effect = Exception("Test error")
-
-    response = client.post(
-        "/agents", json={"store_listing_version_id": "test-version-id"}
-    )
-    assert response.status_code == 500
-    assert "detail" in response.json()  # Verify error response structure
-    mock_db_call.assert_called_once_with(
-        store_listing_version_id="test-version-id", user_id=test_user_id
-    )
@@ -454,6 +454,7 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
     total_processed = 0
     total_success = 0
     total_failed = 0
+    all_errors: dict[str, int] = {}  # Aggregate errors across all content types
 
     # Process content types in explicit order
     processing_order = [
@@ -499,23 +500,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
         success = sum(1 for result in results if result is True)
         failed = len(results) - success
 
-        # Aggregate unique errors to avoid Sentry spam
+        # Aggregate errors across all content types
        if failed > 0:
-            # Group errors by type and message
-            error_summary: dict[str, int] = {}
             for result in results:
                 if isinstance(result, Exception):
                     error_key = f"{type(result).__name__}: {str(result)}"
-                    error_summary[error_key] = error_summary.get(error_key, 0) + 1
-
-            # Log aggregated error summary
-            error_details = ", ".join(
-                f"{error} ({count}x)" for error, count in error_summary.items()
-            )
-            logger.error(
-                f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
-                f"Errors: {error_details}"
-            )
+                    all_errors[error_key] = all_errors.get(error_key, 0) + 1
 
         results_by_type[content_type.value] = {
             "processed": len(missing_items),
@@ -542,6 +532,13 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
             "error": str(e),
         }
 
+    # Log aggregated errors once at the end
+    if all_errors:
+        error_details = ", ".join(
+            f"{error} ({count}x)" for error, count in all_errors.items()
+        )
+        logger.error(f"Embedding backfill errors: {error_details}")
+
     return {
         "by_type": results_by_type,
         "totals": {
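The refactor above moves per-batch error logging into a single aggregate that is logged once at the end of the backfill. A tiny standalone sketch of the counting pattern:

```python
# Count identical failures instead of logging each one (avoids Sentry spam).
results = [ValueError("boom"), True, ValueError("boom"), TimeoutError("slow")]

all_errors: dict[str, int] = {}
for result in results:
    if isinstance(result, Exception):
        error_key = f"{type(result).__name__}: {result}"
        all_errors[error_key] = all_errors.get(error_key, 0) + 1

summary = ", ".join(f"{error} ({count}x)" for error, count in all_errors.items())
print(summary)  # ValueError: boom (2x), TimeoutError: slow (1x)
```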
@@ -261,14 +261,36 @@ async def get_onboarding_agents(
     return await get_recommended_agents(user_id)
 
 
+class OnboardingStatusResponse(pydantic.BaseModel):
+    """Response for onboarding status check."""
+
+    is_onboarding_enabled: bool
+    is_chat_enabled: bool
+
+
 @v1_router.get(
     "/onboarding/enabled",
     summary="Is onboarding enabled",
     tags=["onboarding", "public"],
-    dependencies=[Security(requires_user)],
+    response_model=OnboardingStatusResponse,
 )
-async def is_onboarding_enabled() -> bool:
-    return await onboarding_enabled()
+async def is_onboarding_enabled(
+    user_id: Annotated[str, Security(get_user_id)],
+) -> OnboardingStatusResponse:
+    # Check if chat is enabled for user
+    is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
+
+    # If chat is enabled, skip legacy onboarding
+    if is_chat_enabled:
+        return OnboardingStatusResponse(
+            is_onboarding_enabled=False,
+            is_chat_enabled=True,
+        )
+
+    return OnboardingStatusResponse(
+        is_onboarding_enabled=await onboarding_enabled(),
+        is_chat_enabled=False,
+    )
 
 
 @v1_router.post(
@@ -0,0 +1 @@
+# Workspace API feature module
@@ -0,0 +1,122 @@
+"""
+Workspace API routes for managing user file storage.
+"""
+
+import logging
+import re
+from typing import Annotated
+from urllib.parse import quote
+
+import fastapi
+from autogpt_libs.auth.dependencies import get_user_id, requires_user
+from fastapi.responses import Response
+
+from backend.data.workspace import get_workspace, get_workspace_file
+from backend.util.workspace_storage import get_workspace_storage
+
+
+def _sanitize_filename_for_header(filename: str) -> str:
+    """
+    Sanitize filename for Content-Disposition header to prevent header injection.
+
+    Removes/replaces characters that could break the header or inject new headers.
+    Uses RFC5987 encoding for non-ASCII characters.
+    """
+    # Remove CR, LF, and null bytes (header injection prevention)
+    sanitized = re.sub(r"[\r\n\x00]", "", filename)
+    # Escape quotes
+    sanitized = sanitized.replace('"', '\\"')
+    # For non-ASCII, use RFC5987 filename* parameter
+    # Check if filename has non-ASCII characters
+    try:
+        sanitized.encode("ascii")
+        return f'attachment; filename="{sanitized}"'
+    except UnicodeEncodeError:
+        # Use RFC5987 encoding for UTF-8 filenames
+        encoded = quote(sanitized, safe="")
+        return f"attachment; filename*=UTF-8''{encoded}"
+
+
+logger = logging.getLogger(__name__)
+
+router = fastapi.APIRouter(
+    dependencies=[fastapi.Security(requires_user)],
+)
+
+
+def _create_streaming_response(content: bytes, file) -> Response:
+    """Create a streaming response for file content."""
+    return Response(
+        content=content,
+        media_type=file.mimeType,
+        headers={
+            "Content-Disposition": _sanitize_filename_for_header(file.name),
+            "Content-Length": str(len(content)),
+        },
+    )
+
+
+async def _create_file_download_response(file) -> Response:
+    """
+    Create a download response for a workspace file.
+
+    Handles both local storage (direct streaming) and GCS (signed URL redirect
+    with fallback to streaming).
+    """
+    storage = await get_workspace_storage()
+
+    # For local storage, stream the file directly
+    if file.storagePath.startswith("local://"):
+        content = await storage.retrieve(file.storagePath)
+        return _create_streaming_response(content, file)
+
+    # For GCS, try to redirect to signed URL, fall back to streaming
+    try:
+        url = await storage.get_download_url(file.storagePath, expires_in=300)
+        # If we got back an API path (fallback), stream directly instead
+        if url.startswith("/api/"):
+            content = await storage.retrieve(file.storagePath)
+            return _create_streaming_response(content, file)
+        return fastapi.responses.RedirectResponse(url=url, status_code=302)
+    except Exception as e:
+        # Log the signed URL failure with context
+        logger.error(
+            f"Failed to get signed URL for file {file.id} "
+            f"(storagePath={file.storagePath}): {e}",
+            exc_info=True,
+        )
+        # Fall back to streaming directly from GCS
+        try:
+            content = await storage.retrieve(file.storagePath)
+            return _create_streaming_response(content, file)
+        except Exception as fallback_error:
+            logger.error(
+                f"Fallback streaming also failed for file {file.id} "
+                f"(storagePath={file.storagePath}): {fallback_error}",
+                exc_info=True,
+            )
+            raise
+
+
+@router.get(
+    "/files/{file_id}/download",
+    summary="Download file by ID",
+)
+async def download_file(
+    user_id: Annotated[str, fastapi.Security(get_user_id)],
+    file_id: str,
+) -> Response:
+    """
+    Download a file by its ID.
+
+    Returns the file content directly or redirects to a signed URL for GCS.
+    """
+    workspace = await get_workspace(user_id)
+    if workspace is None:
+        raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
+
+    file = await get_workspace_file(file_id, workspace.id)
+    if file is None:
+        raise fastapi.HTTPException(status_code=404, detail="File not found")
+
+    return await _create_file_download_response(file)
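A quick illustration of what `_sanitize_filename_for_header` produces for the three interesting cases; the expected strings below are computed from the implementation above, not taken from its tests:

```python
from backend.api.features.workspace.routes import _sanitize_filename_for_header

# Plain ASCII names use the quoted filename parameter.
assert (
    _sanitize_filename_for_header("report.pdf")
    == 'attachment; filename="report.pdf"'
)

# CR/LF are stripped, so a crafted name cannot inject extra response headers.
assert (
    _sanitize_filename_for_header("evil\r\nSet-Cookie: x=1")
    == 'attachment; filename="evilSet-Cookie: x=1"'
)

# Non-ASCII names fall back to the RFC 5987 filename* form.
assert (
    _sanitize_filename_for_header("résumé.pdf")
    == "attachment; filename*=UTF-8''r%C3%A9sum%C3%A9.pdf"
)
```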
@@ -32,6 +32,7 @@ import backend.api.features.postmark.postmark
 import backend.api.features.store.model
 import backend.api.features.store.routes
 import backend.api.features.v1
+import backend.api.features.workspace.routes as workspace_routes
 import backend.data.block
 import backend.data.db
 import backend.data.graph
@@ -52,6 +53,7 @@ from backend.util.exceptions import (
 )
 from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
 from backend.util.service import UnhealthyServiceError
+from backend.util.workspace_storage import shutdown_workspace_storage
 
 from .external.fastapi_app import external_api
 from .features.analytics import router as analytics_router
@@ -124,6 +126,11 @@ async def lifespan_context(app: fastapi.FastAPI):
     except Exception as e:
         logger.warning(f"Error shutting down cloud storage handler: {e}")
 
+    try:
+        await shutdown_workspace_storage()
+    except Exception as e:
+        logger.warning(f"Error shutting down workspace storage: {e}")
+
     await backend.data.db.disconnect()
 
 
@@ -315,6 +322,11 @@ app.include_router(
     tags=["v2", "chat"],
     prefix="/api/chat",
 )
+app.include_router(
+    workspace_routes.router,
+    tags=["workspace"],
+    prefix="/api/workspace",
+)
 app.include_router(
     backend.api.features.oauth.router,
     tags=["oauth"],
@@ -13,6 +13,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block):
                 "credentials": TEST_CREDENTIALS_INPUT,
             },
             test_output=[
-                ("image_url", "https://replicate.delivery/generated-image.jpg"),
+                # Output will be a workspace ref or data URI depending on context
+                ("image_url", lambda x: x.startswith(("workspace://", "data:"))),
             ],
             test_mock={
+                # Use data URI to avoid HTTP requests during tests
                 "run_model": lambda *args, **kwargs: MediaFileType(
-                    "https://replicate.delivery/generated-image.jpg"
+                    "data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q=="
                 ),
             },
             test_credentials=TEST_CREDENTIALS,
@@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block):
         input_data: Input,
         *,
         credentials: APIKeyCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         try:
@@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block):
             processed_images = await asyncio.gather(
                 *(
                     store_media_file(
-                        graph_exec_id=graph_exec_id,
                         file=img,
-                        user_id=user_id,
-                        return_content=True,
+                        execution_context=execution_context,
+                        return_format="for_external_api",  # Get content for Replicate API
                     )
                    for img in input_data.images
                 )
@@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block):
                 aspect_ratio=input_data.aspect_ratio.value,
                 output_format=input_data.output_format.value,
             )
-            yield "image_url", result
+
+            # Store the generated image to the user's workspace for persistence
+            stored_url = await store_media_file(
+                file=result,
+                execution_context=execution_context,
+                return_format="for_block_output",
+            )
+            yield "image_url", stored_url
         except Exception as e:
             yield "error", str(e)
 
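The recurring pattern in these block changes is a two-step use of `store_media_file`: resolve input media into raw content for the external API, then persist the generated output and yield a stable reference. A hedged sketch of the shape, with the signature and `return_format` semantics inferred from the diff rather than from library docs:

```python
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


async def prepare_input(img, execution_context):
    # "for_external_api": resolve workspace:// or data: inputs into content
    # the provider (e.g. Replicate) can consume directly.
    return await store_media_file(
        file=img,
        execution_context=execution_context,
        return_format="for_external_api",
    )


async def persist_output(result_url, execution_context):
    # "for_block_output": store the provider's result in the user's workspace
    # and return a workspace:// (or data:) reference to yield downstream.
    return await store_media_file(
        file=MediaFileType(result_url),
        execution_context=execution_context,
        return_format="for_block_output",
    )
```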
@@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput
 
 from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
+from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -13,6 +14,8 @@ from backend.data.model import (
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
+from backend.util.file import store_media_file
+from backend.util.type import MediaFileType
 
 
 class ImageSize(str, Enum):
@@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block):
             test_output=[
                 (
                     "image_url",
-                    "https://replicate.delivery/generated-image.webp",
+                    # Test output is a data URI since we now store images
+                    lambda x: x.startswith("data:image/"),
                 ),
             ],
             test_mock={
-                "_run_client": lambda *args, **kwargs: "https://replicate.delivery/generated-image.webp"
+                # Return a data URI directly so store_media_file doesn't need to download
+                "_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
             },
         )
 
@@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block):
         style_text = style_map.get(style, "")
         return f"{style_text} of" if style_text else ""
 
-    async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        credentials: APIKeyCredentials,
+        execution_context: ExecutionContext,
+        **kwargs,
+    ):
         try:
             url = await self.generate_image(input_data, credentials)
             if url:
-                yield "image_url", url
+                # Store the generated image to the user's workspace/execution folder
+                stored_url = await store_media_file(
+                    file=MediaFileType(url),
+                    execution_context=execution_context,
+                    return_format="for_block_output",
+                )
+                yield "image_url", stored_url
             else:
                 yield "error", "Image generation returned an empty result."
         except Exception as e:
@@ -13,6 +13,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -21,7 +22,9 @@ from backend.data.model import (
|
|||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.exceptions import BlockExecutionError
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
@@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
"voice": Voice.LILY,
|
"voice": Voice.LILY,
|
||||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/video.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/video.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create a new Webhook.site URL
|
# Create a new Webhook.site URL
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
@@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
)
|
)
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
logger.debug(f"Video ready: {video_url}")
|
logger.debug(f"Video ready: {video_url}")
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|
||||||
|
|
||||||
class AIAdMakerVideoCreatorBlock(Block):
|
class AIAdMakerVideoCreatorBlock(Block):
|
||||||
@@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/ad.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
raise RuntimeError("Failed to create video: No project ID returned")
|
raise RuntimeError("Failed to create video: No project ID returned")
|
||||||
|
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|
||||||
|
|
||||||
class AIScreenshotToVideoAdBlock(Block):
|
class AIScreenshotToVideoAdBlock(Block):
|
||||||
@@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
"script": "Amazing numbers!",
|
"script": "Amazing numbers!",
|
||||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/screenshot.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
raise RuntimeError("Failed to create video: No project ID returned")
|
raise RuntimeError("Failed to create video: No project ID returned")
|
||||||
|
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|||||||
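The change repeated across all three video blocks above is one output-persistence pattern: instead of yielding the provider's temporary URL, the block stores the file and yields the stored reference. A minimal sketch of the pattern — the `store_media_file` call shape and `ExecutionContext` type are taken from the hunks above; the helper name is illustrative:

```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


async def persist_output(video_url: str, execution_context: ExecutionContext) -> str:
    # "for_block_output" yields a workspace:// URI when a workspace is
    # available (CoPilot) and falls back to a data URI in plain graph runs.
    return await store_media_file(
        file=MediaFileType(video_url),
        execution_context=execution_context,
        return_format="for_block_output",
    )
```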
@@ -6,6 +6,7 @@ if TYPE_CHECKING:

 from pydantic import SecretStr

+from backend.data.execution import ExecutionContext
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -17,6 +18,8 @@ from backend.sdk import (
     Requests,
     SchemaField,
 )
+from backend.util.file import store_media_file
+from backend.util.type import MediaFileType

 from ._config import bannerbear

@@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block):
             },
             test_output=[
                 ("success", True),
-                ("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
+                # Output will be a workspace ref or data URI depending on context
+                ("image_url", lambda x: x.startswith(("workspace://", "data:"))),
                 ("uid", "test-uid-123"),
                 ("status", "completed"),
             ],
             test_mock={
+                # Use data URI to avoid HTTP requests during tests
                 "_make_api_request": lambda *args, **kwargs: {
                     "uid": "test-uid-123",
                     "status": "completed",
-                    "image_url": "https://cdn.bannerbear.com/test-image.jpg",
+                    "image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAABAAEBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+v//Z",
                 }
             },
             test_credentials=TEST_CREDENTIALS,
@@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block):
             raise Exception(error_msg)

     async def run(
-        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
+        self,
+        input_data: Input,
+        *,
+        credentials: APIKeyCredentials,
+        execution_context: ExecutionContext,
+        **kwargs,
     ) -> BlockOutput:
         # Build the modifications array
         modifications = []
@@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block):

         # Synchronous request - image should be ready
         yield "success", True
-        yield "image_url", data.get("image_url", "")
+
+        # Store the generated image to workspace for persistence
+        image_url = data.get("image_url", "")
+        if image_url:
+            stored_url = await store_media_file(
+                file=MediaFileType(image_url),
+                execution_context=execution_context,
+                return_format="for_block_output",
+            )
+            yield "image_url", stored_url
+        else:
+            yield "image_url", ""

         yield "uid", data.get("uid", "")
         yield "status", data.get("status", "completed")
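Note the `test_output` entries now pair an output name with a predicate instead of a literal, since the stored reference is not deterministic across environments. A sketch of how such a matcher can be evaluated — assuming, as the diff implies, that the test harness treats callables as predicates:

```python
def check_output(expected, actual) -> bool:
    """Compare a block output against an expected literal or predicate."""
    name, matcher = expected
    out_name, out_value = actual
    if out_name != name:
        return False
    # A callable expectation is treated as a predicate over the value;
    # anything else must match exactly.
    return matcher(out_value) if callable(matcher) else matcher == out_value


# e.g. check_output(("image_url", lambda x: x.startswith(("workspace://", "data:"))),
#                   ("image_url", "workspace://abc/image.jpg"))  -> True
```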
@@ -9,6 +9,7 @@ from backend.data.block import (
     BlockSchemaOutput,
     BlockType,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file
 from backend.util.type import MediaFileType, convert
@@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert
 class FileStoreBlock(Block):
     class Input(BlockSchemaInput):
         file_in: MediaFileType = SchemaField(
-            description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
+            description="The file to download and store. Can be a URL (https://...), data URI, or local path."
         )
         base_64: bool = SchemaField(
-            description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
+            description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
             default=False,
             advanced=True,
             title="Produce Base64 Output",
@@ -28,13 +29,18 @@ class FileStoreBlock(Block):

     class Output(BlockSchemaOutput):
         file_out: MediaFileType = SchemaField(
-            description="The relative path to the stored file in the temporary directory."
+            description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
         )

     def __init__(self):
         super().__init__(
             id="cbb50872-625b-42f0-8203-a2ae78242d8a",
-            description="Stores the input file in the temporary directory.",
+            description=(
+                "Downloads and stores a file from a URL, data URI, or local path. "
+                "Use this to fetch images, documents, or other files for processing. "
+                "In CoPilot: saves to workspace (use list_workspace_files to see it). "
+                "In graphs: outputs a data URI to pass to other blocks."
+            ),
             categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
             input_schema=FileStoreBlock.Input,
             output_schema=FileStoreBlock.Output,
@@ -45,15 +51,18 @@ class FileStoreBlock(Block):
         self,
         input_data: Input,
         *,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
+        # Determine return format based on user preference
+        # for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
+        # for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
+        return_format = "for_external_api" if input_data.base_64 else "for_block_output"
+
         yield "file_out", await store_media_file(
-            graph_exec_id=graph_exec_id,
             file=input_data.file_in,
-            user_id=user_id,
-            return_content=input_data.base_64,
+            execution_context=execution_context,
+            return_format=return_format,
         )
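FileStoreBlock shows the whole new API in one place: callers pick one of three return formats instead of toggling a `return_content` boolean. The format names below are the ones used throughout this changeset; the `Literal` alias itself is an illustrative assumption, not necessarily how `backend.util.file` spells it:

```python
from typing import Literal

# The three formats observed in this changeset:
#   for_local_processing - relative path inside the execution's temp dir
#   for_external_api     - full content as a base64 data URI
#   for_block_output     - workspace:// URI in CoPilot, data URI in graphs
ReturnFormat = Literal["for_local_processing", "for_external_api", "for_block_output"]


def pick_return_format(base_64: bool) -> ReturnFormat:
    # Mirrors the FileStoreBlock logic: the "Produce Base64 Output" toggle
    # forces a data URI; otherwise the platform picks the smart format.
    return "for_external_api" if base_64 else "for_block_output"
```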
@@ -15,6 +15,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import APIKeyCredentials, SchemaField
 from backend.util.file import store_media_file
 from backend.util.request import Requests
@@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block):
         file: MediaFileType,
         filename: str,
         message_content: str,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
     ) -> dict:
         intents = discord.Intents.default()
         intents.guilds = True
@@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block):
             # Local file path - read from stored media file
             # This would be a path from a previous block's output
             stored_file = await store_media_file(
-                graph_exec_id=graph_exec_id,
                 file=file,
-                user_id=user_id,
-                return_content=True,  # Get as data URI
+                execution_context=execution_context,
+                return_format="for_external_api",  # Get content to send to Discord
             )
             # Now process as data URI
             header, encoded = stored_file.split(",", 1)
@@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block):
         input_data: Input,
         *,
         credentials: APIKeyCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         try:
@@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block):
                 file=input_data.file,
                 filename=input_data.filename,
                 message_content=input_data.message_content,
-                graph_exec_id=graph_exec_id,
-                user_id=user_id,
+                execution_context=execution_context,
             )

             yield "status", result.get("status", "Unknown error")
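The Discord block fetches the file as a data URI (`for_external_api`) and then splits it for upload. The `split(",", 1)` line is verbatim from the hunk; the decode step that would follow is a hedged completion using only the standard library:

```python
import base64
from io import BytesIO


def data_uri_to_bytes(stored_file: str) -> BytesIO:
    # A data URI looks like "data:video/mp4;base64,AAAA..."; splitting on the
    # first comma separates the header from the base64 payload.
    header, encoded = stored_file.split(",", 1)
    return BytesIO(base64.b64decode(encoded))
```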
@@ -1,71 +0,0 @@
-"""Text encoding block for converting special characters to escape sequences."""
-
-import codecs
-
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
-from backend.data.model import SchemaField
-
-
-class TextEncoderBlock(Block):
-    """
-    Encodes a string by converting special characters into escape sequences.
-
-    This block is the inverse of TextDecoderBlock. It takes text containing
-    special characters (like newlines, tabs, etc.) and converts them into
-    their escape sequence representations (e.g., newline becomes \\n).
-    """
-
-    class Input(BlockSchemaInput):
-        """Input schema for TextEncoderBlock."""
-
-        text: str = SchemaField(
-            description="A string containing special characters to be encoded",
-            placeholder="Your text with newlines and quotes to encode",
-        )
-
-    class Output(BlockSchemaOutput):
-        """Output schema for TextEncoderBlock."""
-
-        encoded_text: str = SchemaField(
-            description="The encoded text with special characters converted to escape sequences"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
-            description="Encodes a string by converting special characters into escape sequences",
-            categories={BlockCategory.TEXT},
-            input_schema=TextEncoderBlock.Input,
-            output_schema=TextEncoderBlock.Output,
-            test_input={
-                "text": """Hello
-World!
-This is a "quoted" string."""
-            },
-            test_output=[
-                (
-                    "encoded_text",
-                    """Hello\\nWorld!\\nThis is a "quoted" string.""",
-                )
-            ],
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        """
-        Encode the input text by converting special characters to escape sequences.
-
-        Args:
-            input_data: The input containing the text to encode.
-            **kwargs: Additional keyword arguments (unused).
-
-        Yields:
-            The encoded text with escape sequences.
-        """
-        encoded_text = codecs.encode(input_data.text, "unicode_escape").decode("utf-8")
-        yield "encoded_text", encoded_text
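For reference, the behaviour the removed TextEncoderBlock wrapped is a one-liner over `codecs`; the inverse decode shown here is the standard-library counterpart (the deleted file only contained the encode side):

```python
import codecs

text = 'Hello\nWorld!\nThis is a "quoted" string.'

# Encode: real newlines become the two characters "\n".
encoded = codecs.encode(text, "unicode_escape").decode("utf-8")
assert encoded == 'Hello\\nWorld!\\nThis is a "quoted" string.'

# Decode: the inverse operation restores the original text.
assert codecs.decode(encoded, "unicode_escape") == text
```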
@@ -17,8 +17,11 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
+from backend.util.file import store_media_file
 from backend.util.request import ClientResponseError, Requests
+from backend.util.type import MediaFileType

 logger = logging.getLogger(__name__)

@@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block):
                 "credentials": TEST_CREDENTIALS_INPUT,
             },
             test_credentials=TEST_CREDENTIALS,
-            test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
+            test_output=[
+                # Output will be a workspace ref or data URI depending on context
+                ("video_url", lambda x: x.startswith(("workspace://", "data:"))),
+            ],
             test_mock={
-                "generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
+                # Use data URI to avoid HTTP requests during tests
+                "generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
             },
         )

@@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block):
             raise RuntimeError(f"API request failed: {str(e)}")

     async def run(
-        self, input_data: Input, *, credentials: FalCredentials, **kwargs
+        self,
+        input_data: Input,
+        *,
+        credentials: FalCredentials,
+        execution_context: ExecutionContext,
+        **kwargs,
     ) -> BlockOutput:
         try:
             video_url = await self.generate_video(input_data, credentials)
-            yield "video_url", video_url
+            # Store the generated video to the user's workspace for persistence
+            stored_url = await store_media_file(
+                file=MediaFileType(video_url),
+                execution_context=execution_context,
+                return_format="for_block_output",
+            )
+            yield "video_url", stored_url
         except Exception as e:
             error_message = str(e)
             yield "error", error_message
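Every `test_mock` in this changeset swaps an https URL for a tiny data URI. The point, per the recurring comment, is that storing a URL would trigger a real download inside the test, while a data URI is decoded in place. A quick sanity check of the placeholder payload used throughout:

```python
import base64

# "AAAA" is four base64 characters encoding three zero bytes - a minimal,
# syntactically valid payload that store_media_file can decode offline.
assert base64.b64decode("AAAA") == b"\x00\x00\x00"
```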
@@ -12,6 +12,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -121,10 +122,12 @@ class AIImageEditorBlock(Block):
                 "credentials": TEST_CREDENTIALS_INPUT,
             },
             test_output=[
-                ("output_image", "https://replicate.com/output/edited-image.png"),
+                # Output will be a workspace ref or data URI depending on context
+                ("output_image", lambda x: x.startswith(("workspace://", "data:"))),
             ],
             test_mock={
-                "run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
+                # Use data URI to avoid HTTP requests during tests
+                "run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
             },
             test_credentials=TEST_CREDENTIALS,
         )
@@ -134,8 +137,7 @@ class AIImageEditorBlock(Block):
         input_data: Input,
         *,
         credentials: APIKeyCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         result = await self.run_model(
@@ -144,20 +146,25 @@ class AIImageEditorBlock(Block):
             prompt=input_data.prompt,
             input_image_b64=(
                 await store_media_file(
-                    graph_exec_id=graph_exec_id,
                     file=input_data.input_image,
-                    user_id=user_id,
-                    return_content=True,
+                    execution_context=execution_context,
+                    return_format="for_external_api",  # Get content for Replicate API
                 )
                 if input_data.input_image
                 else None
             ),
             aspect_ratio=input_data.aspect_ratio.value,
             seed=input_data.seed,
-            user_id=user_id,
-            graph_exec_id=graph_exec_id,
+            user_id=execution_context.user_id or "",
+            graph_exec_id=execution_context.graph_exec_id or "",
         )
-        yield "output_image", result
+        # Store the generated image to the user's workspace for persistence
+        stored_url = await store_media_file(
+            file=result,
+            execution_context=execution_context,
+            return_format="for_block_output",
+        )
+        yield "output_image", stored_url

     async def run_model(
         self,
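The Replicate hunk is the one place where the old string parameters survive as API arguments, so it shows how optional ExecutionContext fields get unwrapped (`or ""`). A sketch of the context shape these hunks rely on — field optionality is inferred from the asserts and fallbacks in this diff, not from the class definition itself:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExecutionContextSketch:
    # Inferred: these are Optional, since call sites either assert on them
    # or fall back to "" before passing them to APIs that require strings.
    user_id: Optional[str] = None
    graph_exec_id: Optional[str] = None
    node_exec_id: Optional[str] = None


ctx = ExecutionContextSketch(user_id=None, graph_exec_id="exec-123")
assert (ctx.user_id or "") == ""       # fallback used by the Replicate call
assert ctx.graph_exec_id is not None   # pattern used before get_exec_file_path
```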
@@ -21,6 +21,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
 from backend.util.settings import Settings
@@ -95,8 +96,7 @@ def _make_mime_text(

 async def create_mime_message(
     input_data,
-    graph_exec_id: str,
-    user_id: str,
+    execution_context: ExecutionContext,
 ) -> str:
     """Create a MIME message with attachments and return base64-encoded raw message."""

@@ -117,12 +117,12 @@ async def create_mime_message(
     if input_data.attachments:
         for attach in input_data.attachments:
             local_path = await store_media_file(
-                user_id=user_id,
-                graph_exec_id=graph_exec_id,
                 file=attach,
-                return_content=False,
+                execution_context=execution_context,
+                return_format="for_local_processing",
             )
-            abs_path = get_exec_file_path(graph_exec_id, local_path)
+            assert execution_context.graph_exec_id  # Validated by store_media_file
+            abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
             part = MIMEBase("application", "octet-stream")
             with open(abs_path, "rb") as f:
                 part.set_payload(f.read())
@@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase):
         input_data: Input,
         *,
         credentials: GoogleCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         service = self._build_service(credentials, **kwargs)
         result = await self._send_email(
             service,
             input_data,
-            graph_exec_id,
-            user_id,
+            execution_context,
         )
         yield "result", result

     async def _send_email(
-        self, service, input_data: Input, graph_exec_id: str, user_id: str
+        self, service, input_data: Input, execution_context: ExecutionContext
     ) -> dict:
         if not input_data.to or not input_data.subject or not input_data.body:
             raise ValueError(
                 "At least one recipient, subject, and body are required for sending an email"
             )
-        raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
+        raw_message = await create_mime_message(input_data, execution_context)
         sent_message = await asyncio.to_thread(
             lambda: service.users()
             .messages()
@@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase):
         input_data: Input,
         *,
         credentials: GoogleCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         service = self._build_service(credentials, **kwargs)
         result = await self._create_draft(
             service,
             input_data,
-            graph_exec_id,
-            user_id,
+            execution_context,
         )
         yield "result", GmailDraftResult(
             id=result["id"], message_id=result["message"]["id"], status="draft_created"
         )

     async def _create_draft(
-        self, service, input_data: Input, graph_exec_id: str, user_id: str
+        self, service, input_data: Input, execution_context: ExecutionContext
     ) -> dict:
         if not input_data.to or not input_data.subject:
             raise ValueError(
                 "At least one recipient and subject are required for creating a draft"
             )

-        raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
+        raw_message = await create_mime_message(input_data, execution_context)
         draft = await asyncio.to_thread(
             lambda: service.users()
             .drafts()
@@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase):


 async def _build_reply_message(
-    service, input_data, graph_exec_id: str, user_id: str
+    service, input_data, execution_context: ExecutionContext
 ) -> tuple[str, str]:
     """
     Builds a reply MIME message for Gmail threads.
@@ -1190,12 +1186,12 @@ async def _build_reply_message(
     # Handle attachments
     for attach in input_data.attachments:
         local_path = await store_media_file(
-            user_id=user_id,
-            graph_exec_id=graph_exec_id,
             file=attach,
-            return_content=False,
+            execution_context=execution_context,
+            return_format="for_local_processing",
         )
-        abs_path = get_exec_file_path(graph_exec_id, local_path)
+        assert execution_context.graph_exec_id  # Validated by store_media_file
+        abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
         part = MIMEBase("application", "octet-stream")
         with open(abs_path, "rb") as f:
             part.set_payload(f.read())
@@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase):
         input_data: Input,
         *,
         credentials: GoogleCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         service = self._build_service(credentials, **kwargs)
         message = await self._reply(
             service,
             input_data,
-            graph_exec_id,
-            user_id,
+            execution_context,
         )
         yield "messageId", message["id"]
         yield "threadId", message.get("threadId", input_data.threadId)
@@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase):
         yield "email", email

     async def _reply(
-        self, service, input_data: Input, graph_exec_id: str, user_id: str
+        self, service, input_data: Input, execution_context: ExecutionContext
     ) -> dict:
         # Build the reply message using the shared helper
         raw, thread_id = await _build_reply_message(
-            service, input_data, graph_exec_id, user_id
+            service, input_data, execution_context
         )

         # Send the message
@@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase):
         input_data: Input,
         *,
         credentials: GoogleCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         service = self._build_service(credentials, **kwargs)
         draft = await self._create_draft_reply(
             service,
             input_data,
-            graph_exec_id,
-            user_id,
+            execution_context,
         )
         yield "draftId", draft["id"]
         yield "messageId", draft["message"]["id"]
@@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase):
         yield "status", "draft_created"

     async def _create_draft_reply(
-        self, service, input_data: Input, graph_exec_id: str, user_id: str
+        self, service, input_data: Input, execution_context: ExecutionContext
     ) -> dict:
         # Build the reply message using the shared helper
         raw, thread_id = await _build_reply_message(
-            service, input_data, graph_exec_id, user_id
+            service, input_data, execution_context
         )

         # Create draft with proper thread association
@@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase):
         input_data: Input,
         *,
         credentials: GoogleCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         service = self._build_service(credentials, **kwargs)
         result = await self._forward_message(
             service,
             input_data,
-            graph_exec_id,
-            user_id,
+            execution_context,
         )
         yield "messageId", result["id"]
         yield "threadId", result.get("threadId", "")
         yield "status", "forwarded"

     async def _forward_message(
-        self, service, input_data: Input, graph_exec_id: str, user_id: str
+        self, service, input_data: Input, execution_context: ExecutionContext
     ) -> dict:
         if not input_data.to:
             raise ValueError("At least one recipient is required for forwarding")
@@ -1727,12 +1717,12 @@ To: {original_to}
     # Add any additional attachments
     for attach in input_data.additionalAttachments:
         local_path = await store_media_file(
-            user_id=user_id,
-            graph_exec_id=graph_exec_id,
             file=attach,
-            return_content=False,
+            execution_context=execution_context,
+            return_format="for_local_processing",
         )
-        abs_path = get_exec_file_path(graph_exec_id, local_path)
+        assert execution_context.graph_exec_id  # Validated by store_media_file
+        abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
         part = MIMEBase("application", "octet-stream")
         with open(abs_path, "rb") as f:
             part.set_payload(f.read())
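All three Gmail attachment loops end at the same MIMEBase snippet. The lines up to `set_payload` are verbatim from the hunks; the base64 transfer-encoding and Content-Disposition steps that normally complete such a part are a hedged reconstruction from the email stdlib, since the diff context stops early:

```python
from email import encoders
from email.mime.base import MIMEBase


def attach_file(message, abs_path: str, filename: str) -> None:
    part = MIMEBase("application", "octet-stream")
    with open(abs_path, "rb") as f:
        part.set_payload(f.read())
    # Assumed completion: encode the raw payload and label the part so mail
    # clients treat it as an attachment.
    encoders.encode_base64(part)
    part.add_header("Content-Disposition", f"attachment; filename={filename}")
    message.attach(part)
```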
@@ -15,6 +15,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import (
     CredentialsField,
     CredentialsMetaInput,
@@ -116,10 +117,9 @@ class SendWebRequestBlock(Block):

     @staticmethod
     async def _prepare_files(
-        graph_exec_id: str,
+        execution_context: ExecutionContext,
         files_name: str,
         files: list[MediaFileType],
-        user_id: str,
     ) -> list[tuple[str, tuple[str, BytesIO, str]]]:
         """
         Prepare files for the request by storing them and reading their content.
@@ -127,11 +127,16 @@ class SendWebRequestBlock(Block):
             (files_name, (filename, BytesIO, mime_type))
         """
         files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
+        graph_exec_id = execution_context.graph_exec_id
+        if graph_exec_id is None:
+            raise ValueError("graph_exec_id is required for file operations")
+
         for media in files:
             # Normalise to a list so we can repeat the same key
             rel_path = await store_media_file(
-                graph_exec_id, media, user_id, return_content=False
+                file=media,
+                execution_context=execution_context,
+                return_format="for_local_processing",
             )
             abs_path = get_exec_file_path(graph_exec_id, rel_path)
             async with aiofiles.open(abs_path, "rb") as f:
@@ -143,7 +148,7 @@ class SendWebRequestBlock(Block):
         return files_payload

     async def run(
-        self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
+        self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
     ) -> BlockOutput:
         # ─── Parse/normalise body ────────────────────────────────────
         body = input_data.body
@@ -174,7 +179,7 @@ class SendWebRequestBlock(Block):
         files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
         if use_files:
             files_payload = await self._prepare_files(
-                graph_exec_id, input_data.files_name, input_data.files, user_id
+                execution_context, input_data.files_name, input_data.files
             )

         # Enforce body format rules
@@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
         self,
         input_data: Input,
         *,
-        graph_exec_id: str,
+        execution_context: ExecutionContext,
         credentials: HostScopedCredentials,
-        user_id: str,
         **kwargs,
     ) -> BlockOutput:
         # Create SendWebRequestBlock.Input from our input (removing credentials field)
@@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):

         # Use parent class run method
         async for output_name, output_data in super().run(
-            base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
+            base_input, execution_context=execution_context, **kwargs
         ):
             yield output_name, output_data
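The `_prepare_files` docstring pins down the payload shape: `(files_name, (filename, BytesIO, mime_type))`, i.e. the same field name repeated once per file. A hedged sketch of building that list in the widely used requests-style `files=` convention (the block itself goes through the internal Requests wrapper, assumed here to accept the same shape):

```python
from io import BytesIO

# One tuple per file, all under the same form-field name ("files" here),
# which is how multipart/form-data represents repeated keys.
files_payload = [
    ("files", ("report.csv", BytesIO(b"a,b\n1,2\n"), "text/csv")),
    ("files", ("notes.txt", BytesIO(b"hello"), "text/plain")),
]

# requests-compatible consumption (illustrative):
# requests.post(url, files=files_payload)
```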
@@ -12,6 +12,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockType,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file
 from backend.util.mock import MockObject
@@ -462,18 +463,21 @@ class AgentFileInputBlock(AgentInputBlock):
         self,
         input_data: Input,
         *,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         if not input_data.value:
             return

+        # Determine return format based on user preference
+        # for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
+        # for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
+        return_format = "for_external_api" if input_data.base_64 else "for_block_output"
+
         yield "result", await store_media_file(
-            graph_exec_id=graph_exec_id,
             file=input_data.value,
-            user_id=user_id,
-            return_content=input_data.base_64,
+            execution_context=execution_context,
+            return_format=return_format,
        )
@@ -115,7 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
     CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
     CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
-    CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
     CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
     # AI/ML API models
     AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -280,9 +279,6 @@ MODEL_METADATA = {
     LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
         "anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
     ),  # claude-haiku-4-5-20251001
-    LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
-        "anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
-    ),  # claude-3-7-sonnet-20250219
     LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
         "anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
     ),  # claude-3-haiku-20240307
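Retiring a model means touching two places: the enum member and its MODEL_METADATA entry. A small invariant check in this spirit can catch a half-done removal — illustrative only, not a test that exists in the repo:

```python
def check_model_metadata(llm_model_enum, model_metadata: dict) -> None:
    # Every enum member must have metadata, and no metadata entry may point
    # at a removed member; either mismatch means an incomplete edit.
    members = set(llm_model_enum)
    missing = members - model_metadata.keys()
    stale = model_metadata.keys() - members
    assert not missing, f"models without metadata: {missing}"
    assert not stale, f"metadata for removed models: {stale}"
```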
@@ -1,6 +1,6 @@
 import os
 import tempfile
-from typing import Literal, Optional
+from typing import Optional

 from moviepy.audio.io.AudioFileClip import AudioFileClip
 from moviepy.video.fx.Loop import Loop
@@ -13,6 +13,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import MediaFileType, get_exec_file_path, store_media_file

@@ -46,18 +47,19 @@ class MediaDurationBlock(Block):
         self,
         input_data: Input,
         *,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         # 1) Store the input media locally
         local_media_path = await store_media_file(
-            graph_exec_id=graph_exec_id,
             file=input_data.media_in,
-            user_id=user_id,
-            return_content=False,
+            execution_context=execution_context,
+            return_format="for_local_processing",
+        )
+        assert execution_context.graph_exec_id is not None
+        media_abspath = get_exec_file_path(
+            execution_context.graph_exec_id, local_media_path
         )
-        media_abspath = get_exec_file_path(graph_exec_id, local_media_path)

         # 2) Load the clip
         if input_data.is_video:
@@ -88,10 +90,6 @@ class LoopVideoBlock(Block):
             default=None,
             ge=1,
         )
-        output_return_type: Literal["file_path", "data_uri"] = SchemaField(
-            description="How to return the output video. Either a relative path or base64 data URI.",
-            default="file_path",
-        )

     class Output(BlockSchemaOutput):
         video_out: str = SchemaField(
@@ -111,17 +109,19 @@ class LoopVideoBlock(Block):
         self,
         input_data: Input,
         *,
-        node_exec_id: str,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
+        assert execution_context.graph_exec_id is not None
+        assert execution_context.node_exec_id is not None
+        graph_exec_id = execution_context.graph_exec_id
+        node_exec_id = execution_context.node_exec_id
+
         # 1) Store the input video locally
         local_video_path = await store_media_file(
-            graph_exec_id=graph_exec_id,
             file=input_data.video_in,
-            user_id=user_id,
-            return_content=False,
+            execution_context=execution_context,
+            return_format="for_local_processing",
         )
         input_abspath = get_exec_file_path(graph_exec_id, local_video_path)

@@ -149,12 +149,11 @@ class LoopVideoBlock(Block):
         looped_clip = looped_clip.with_audio(clip.audio)
         looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

-        # Return as data URI
+        # Return output - for_block_output returns workspace:// if available, else data URI
         video_out = await store_media_file(
-            graph_exec_id=graph_exec_id,
             file=output_filename,
-            user_id=user_id,
-            return_content=input_data.output_return_type == "data_uri",
+            execution_context=execution_context,
+            return_format="for_block_output",
         )

         yield "video_out", video_out
@@ -177,10 +176,6 @@ class AddAudioToVideoBlock(Block):
             description="Volume scale for the newly attached audio track (1.0 = original).",
             default=1.0,
         )
-        output_return_type: Literal["file_path", "data_uri"] = SchemaField(
-            description="Return the final output as a relative path or base64 data URI.",
-            default="file_path",
-        )

     class Output(BlockSchemaOutput):
         video_out: MediaFileType = SchemaField(
@@ -200,23 +195,24 @@ class AddAudioToVideoBlock(Block):
         self,
         input_data: Input,
         *,
-        node_exec_id: str,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
+        assert execution_context.graph_exec_id is not None
+        assert execution_context.node_exec_id is not None
+        graph_exec_id = execution_context.graph_exec_id
+        node_exec_id = execution_context.node_exec_id
+
         # 1) Store the inputs locally
         local_video_path = await store_media_file(
-            graph_exec_id=graph_exec_id,
             file=input_data.video_in,
-            user_id=user_id,
-            return_content=False,
+            execution_context=execution_context,
+            return_format="for_local_processing",
         )
         local_audio_path = await store_media_file(
-            graph_exec_id=graph_exec_id,
             file=input_data.audio_in,
-            user_id=user_id,
-            return_content=False,
+            execution_context=execution_context,
+            return_format="for_local_processing",
         )

         abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
@@ -240,12 +236,11 @@ class AddAudioToVideoBlock(Block):
         output_abspath = os.path.join(abs_temp_dir, output_filename)
         final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

-        # 5) Return either path or data URI
+        # 5) Return output - for_block_output returns workspace:// if available, else data URI
         video_out = await store_media_file(
-            graph_exec_id=graph_exec_id,
             file=output_filename,
-            user_id=user_id,
-            return_content=input_data.output_return_type == "data_uri",
+            execution_context=execution_context,
+            return_format="for_block_output",
        )

         yield "video_out", video_out
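The LoopVideoBlock hunks leave the moviepy work untouched; only the I/O plumbing moves to ExecutionContext. For orientation, a minimal sketch of the loop-and-remux step the block performs, assuming the moviepy 2.x effect API that the `from moviepy.video.fx.Loop import Loop` import indicates:

```python
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip


def loop_video(input_path: str, output_path: str, times: int) -> None:
    clip = VideoFileClip(input_path)
    # moviepy 2.x applies effects via with_effects(); Loop(n=...) repeats the clip.
    looped = clip.with_effects([Loop(n=times)])
    # Re-attach the original audio, matching what the block does.
    looped = looped.with_audio(clip.audio)
    looped.write_videofile(output_path, codec="libx264", audio_codec="aac")
```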
@@ -11,6 +11,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block):
     @staticmethod
     async def take_screenshot(
         credentials: APIKeyCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         url: str,
         viewport_width: int,
         viewport_height: int,
@@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block):

         return {
             "image": await store_media_file(
-                graph_exec_id=graph_exec_id,
                 file=MediaFileType(
                     f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
                 ),
-                user_id=user_id,
-                return_content=True,
+                execution_context=execution_context,
+                return_format="for_block_output",
             )
         }

@@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block):
         input_data: Input,
         *,
         credentials: APIKeyCredentials,
-        graph_exec_id: str,
-        user_id: str,
+        execution_context: ExecutionContext,
         **kwargs,
     ) -> BlockOutput:
         try:
             screenshot_data = await self.take_screenshot(
                 credentials=credentials,
-                graph_exec_id=graph_exec_id,
-                user_id=user_id,
+                execution_context=execution_context,
                 url=input_data.url,
                 viewport_width=input_data.viewport_width,
                 viewport_height=input_data.viewport_height,
@@ -7,6 +7,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import ContributorDetails, SchemaField
 from backend.util.file import get_exec_file_path, store_media_file
 from backend.util.type import MediaFileType
@@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block):
     )

     async def run(
-        self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
+        self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
     ) -> BlockOutput:
         import csv
         from io import StringIO
@@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block):
         # Determine data source - prefer file_input if provided, otherwise use contents
         if input_data.file_input:
             stored_file_path = await store_media_file(
-                user_id=user_id,
-                graph_exec_id=graph_exec_id,
                 file=input_data.file_input,
-                return_content=False,
+                execution_context=execution_context,
+                return_format="for_local_processing",
             )

             # Get full file path
-            file_path = get_exec_file_path(graph_exec_id, stored_file_path)
+            assert execution_context.graph_exec_id  # Validated by store_media_file
+            file_path = get_exec_file_path(
+                execution_context.graph_exec_id, stored_file_path
+            )
             if not Path(file_path).exists():
                 raise ValueError(f"File does not exist: {file_path}")
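ReadSpreadsheetBlock is the canonical `for_local_processing` consumer: it needs a real path on disk, then parses it. A compressed sketch of that read path under the assumptions visible in the hunk (store, resolve, check, parse):

```python
import csv
from pathlib import Path


def read_rows(file_path: str) -> list[dict]:
    # The block resolves file_path via get_exec_file_path(...) first; from
    # there it is ordinary CSV parsing.
    if not Path(file_path).exists():
        raise ValueError(f"File does not exist: {file_path}")
    with open(file_path, newline="") as f:
        return list(csv.DictReader(f))
```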
@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
     GPT41_MINI = "gpt-4.1-mini-2025-04-14"

     # Anthropic
-    CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
+    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"

     @property
     def provider_name(self) -> str:
@@ -137,7 +137,7 @@ class StagehandObserveBlock(Block):
         model: StagehandRecommendedLlmModel = SchemaField(
             title="LLM Model",
             description="LLM to use for Stagehand (provider is inferred)",
-            default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
+            default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
             advanced=False,
         )
         model_credentials: AICredentials = AICredentialsField()
@@ -230,7 +230,7 @@ class StagehandActBlock(Block):
         model: StagehandRecommendedLlmModel = SchemaField(
             title="LLM Model",
             description="LLM to use for Stagehand (provider is inferred)",
-            default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
+            default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
             advanced=False,
         )
         model_credentials: AICredentials = AICredentialsField()
@@ -330,7 +330,7 @@ class StagehandExtractBlock(Block):
         model: StagehandRecommendedLlmModel = SchemaField(
             title="LLM Model",
             description="LLM to use for Stagehand (provider is inferred)",
-            default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
+            default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
             advanced=False,
         )
         model_credentials: AICredentials = AICredentialsField()
@@ -10,6 +10,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -17,7 +18,9 @@ from backend.data.model import (
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
+from backend.util.file import store_media_file
 from backend.util.request import Requests
+from backend.util.type import MediaFileType

 TEST_CREDENTIALS = APIKeyCredentials(
     id="01234567-89ab-cdef-0123-456789abcdef",
@@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block):
             test_output=[
                 (
                     "video_url",
-                    "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
+                    lambda x: x.startswith(("workspace://", "data:")),
                 ),
             ],
             test_mock={
@@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block):
                     "id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
                     "status": "created",
                 },
+                # Use data URI to avoid HTTP requests during tests
                 "get_clip_status": lambda *args, **kwargs: {
                     "status": "done",
-                    "result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
+                    "result_url": "data:video/mp4;base64,AAAA",
                 },
             },
             test_credentials=TEST_CREDENTIALS,
@@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block):
         return response.json()

     async def run(
-        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
+        self,
+        input_data: Input,
+        *,
+        credentials: APIKeyCredentials,
+        execution_context: ExecutionContext,
+        **kwargs,
     ) -> BlockOutput:
         # Create the clip
         payload = {
@@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block):
         for _ in range(input_data.max_polling_attempts):
             status_response = await self.get_clip_status(credentials.api_key, clip_id)
             if status_response["status"] == "done":
-                yield "video_url", status_response["result_url"]
+                # Store the generated video to the user's workspace for persistence
+                video_url = status_response["result_url"]
+                stored_url = await store_media_file(
+                    file=MediaFileType(video_url),
+                    execution_context=execution_context,
+                    return_format="for_block_output",
+                )
+                yield "video_url", stored_url
                 return
             elif status_response["status"] == "error":
                 raise RuntimeError(
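The updated test matcher accepts either destination because `"for_block_output"` picks the richest format available. A hedged sketch of the two outcomes, per the `store_media_file` changes later in this diff (the wrapper function is illustrative):

```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


async def persist_result(video_url: str, execution_context: ExecutionContext) -> str:
    # "for_block_output" picks the richest destination available:
    #   CoPilot session (workspace_id set)  -> "workspace://<file-id>"
    #   plain graph run (no workspace)      -> "data:video/mp4;base64,..."
    return await store_media_file(
        file=MediaFileType(video_url),
        execution_context=execution_context,
        return_format="for_block_output",
    )
```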
@@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock
 from backend.blocks.llm import AITextSummarizerBlock
 from backend.blocks.text import ExtractTextInformationBlock
 from backend.blocks.xml_parser import XMLParserBlock
+from backend.data.execution import ExecutionContext
 from backend.util.file import store_media_file
 from backend.util.type import MediaFileType

@@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity:

         with pytest.raises(ValueError, match="File too large"):
             await store_media_file(
-                graph_exec_id="test",
                 file=MediaFileType(large_data_uri),
-                user_id="test_user",
+                execution_context=ExecutionContext(
+                    user_id="test_user",
+                    graph_exec_id="test",
+                ),
+                return_format="for_local_processing",
             )

     @patch("backend.util.file.Path")
@@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity:
         # Should raise an error when directory size exceeds limit
         with pytest.raises(ValueError, match="Disk usage limit exceeded"):
             await store_media_file(
-                graph_exec_id="test",
                 file=MediaFileType(
                     "data:text/plain;base64,dGVzdA=="
                 ),  # Small test file
-                user_id="test_user",
+                execution_context=ExecutionContext(
+                    user_id="test_user",
+                    graph_exec_id="test",
+                ),
+                return_format="for_local_processing",
             )
@@ -11,10 +11,22 @@ from backend.blocks.http import (
     HttpMethod,
     SendAuthenticatedWebRequestBlock,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import HostScopedCredentials
 from backend.util.request import Response


+def make_test_context(
+    graph_exec_id: str = "test-exec-id",
+    user_id: str = "test-user-id",
+) -> ExecutionContext:
+    """Helper to create test ExecutionContext."""
+    return ExecutionContext(
+        user_id=user_id,
+        graph_exec_id=graph_exec_id,
+    )
+
+
 class TestHttpBlockWithHostScopedCredentials:
     """Test suite for HTTP block integration with HostScopedCredentials."""

@@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials:
         async for output_name, output_data in http_block.run(
             input_data,
             credentials=exact_match_credentials,
-            graph_exec_id="test-exec-id",
-            user_id="test-user-id",
+            execution_context=make_test_context(),
         ):
             result.append((output_name, output_data))

@@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials:
         async for output_name, output_data in http_block.run(
             input_data,
             credentials=wildcard_credentials,
-            graph_exec_id="test-exec-id",
-            user_id="test-user-id",
+            execution_context=make_test_context(),
         ):
             result.append((output_name, output_data))

@@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials:
         async for output_name, output_data in http_block.run(
             input_data,
             credentials=non_matching_credentials,
-            graph_exec_id="test-exec-id",
-            user_id="test-user-id",
+            execution_context=make_test_context(),
         ):
             result.append((output_name, output_data))

@@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials:
         async for output_name, output_data in http_block.run(
             input_data,
             credentials=exact_match_credentials,
-            graph_exec_id="test-exec-id",
-            user_id="test-user-id",
+            execution_context=make_test_context(),
         ):
             result.append((output_name, output_data))

@@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials:
         async for output_name, output_data in http_block.run(
             input_data,
             credentials=auto_discovered_creds,  # Execution manager found these
-            graph_exec_id="test-exec-id",
-            user_id="test-user-id",
+            execution_context=make_test_context(),
         ):
             result.append((output_name, output_data))

@@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials:
         async for output_name, output_data in http_block.run(
             input_data,
             credentials=multi_header_creds,
-            graph_exec_id="test-exec-id",
-            user_id="test-user-id",
+            execution_context=make_test_context(),
         ):
             result.append((output_name, output_data))

@@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials:
         async for output_name, output_data in http_block.run(
             input_data,
             credentials=test_creds,
-            graph_exec_id="test-exec-id",
-            user_id="test-user-id",
+            execution_context=make_test_context(),
         ):
             result.append((output_name, output_data))
@@ -11,6 +11,7 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util import json, text
 from backend.util.file import get_exec_file_path, store_media_file
@@ -444,18 +445,21 @@ class FileReadBlock(Block):
     )

     async def run(
-        self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
+        self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
     ) -> BlockOutput:
         # Store the media file properly (handles URLs, data URIs, etc.)
         stored_file_path = await store_media_file(
-            user_id=user_id,
-            graph_exec_id=graph_exec_id,
             file=input_data.file_input,
-            return_content=False,
+            execution_context=execution_context,
+            return_format="for_local_processing",
        )

-        # Get full file path
-        file_path = get_exec_file_path(graph_exec_id, stored_file_path)
+        # Get full file path (graph_exec_id validated by store_media_file above)
+        if not execution_context.graph_exec_id:
+            raise ValueError("execution_context.graph_exec_id is required")
+        file_path = get_exec_file_path(
+            execution_context.graph_exec_id, stored_file_path
+        )

         if not Path(file_path).exists():
             raise ValueError(f"File does not exist: {file_path}")
@@ -81,7 +81,6 @@ MODEL_COST: dict[LlmModel, int] = {
     LlmModel.CLAUDE_4_5_HAIKU: 4,
     LlmModel.CLAUDE_4_5_OPUS: 14,
     LlmModel.CLAUDE_4_5_SONNET: 9,
-    LlmModel.CLAUDE_3_7_SONNET: 5,
     LlmModel.CLAUDE_3_HAIKU: 1,
     LlmModel.AIML_API_QWEN2_5_72B: 1,
     LlmModel.AIML_API_LLAMA3_1_70B: 1,
@@ -83,12 +83,29 @@ class ExecutionContext(BaseModel):

     model_config = {"extra": "ignore"}

+    # Execution identity
+    user_id: Optional[str] = None
+    graph_id: Optional[str] = None
+    graph_exec_id: Optional[str] = None
+    graph_version: Optional[int] = None
+    node_id: Optional[str] = None
+    node_exec_id: Optional[str] = None
+
+    # Safety settings
     human_in_the_loop_safe_mode: bool = True
     sensitive_action_safe_mode: bool = False
+
+    # User settings
     user_timezone: str = "UTC"
+
+    # Execution hierarchy
     root_execution_id: Optional[str] = None
     parent_execution_id: Optional[str] = None
+
+    # Workspace
+    workspace_id: Optional[str] = None
+    session_id: Optional[str] = None


 # -------------------------- Models -------------------------- #
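Since every identity field is optional with a sensible default, a context can be constructed with only what a given code path needs. A quick sketch (all values are made up):

```python
from backend.data.execution import ExecutionContext

ctx = ExecutionContext(
    user_id="user-123",
    graph_id="graph-abc",
    graph_exec_id="exec-xyz",
    graph_version=3,
)
assert ctx.user_timezone == "UTC"                # default
assert ctx.human_in_the_loop_safe_mode is True   # default
assert ctx.workspace_id is None                  # only set in CoPilot sessions
```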
@@ -666,10 +666,16 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
         if not (self.discriminator and self.discriminator_mapping):
             return self

+        try:
+            provider = self.discriminator_mapping[discriminator_value]
+        except KeyError:
+            raise ValueError(
+                f"Model '{discriminator_value}' is not supported. "
+                "It may have been deprecated. Please update your agent configuration."
+            )
+
         return CredentialsFieldInfo(
-            credentials_provider=frozenset(
-                [self.discriminator_mapping[discriminator_value]]
-            ),
+            credentials_provider=frozenset([provider]),
             credentials_types=self.supported_types,
             credentials_scopes=self.required_scopes,
             discriminator=self.discriminator,
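This pairs with the Claude 3.7 Sonnet removals above: an agent still configured with a removed model now gets an actionable error instead of a bare `KeyError`. A self-contained demo of the pattern (not the real class; only the error message is taken from the diff):

```python
# A missing discriminator key becomes a user-facing ValueError, not a KeyError.
discriminator_mapping = {"claude-sonnet-4-5-20250929": "anthropic"}


def resolve_provider(model: str) -> str:
    try:
        return discriminator_mapping[model]
    except KeyError:
        raise ValueError(
            f"Model '{model}' is not supported. "
            "It may have been deprecated. Please update your agent configuration."
        )


resolve_provider("claude-3-7-sonnet-20250219")  # raises ValueError, not KeyError
```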
@@ -41,6 +41,7 @@ FrontendOnboardingStep = Literal[
     OnboardingStep.AGENT_NEW_RUN,
     OnboardingStep.AGENT_INPUT,
     OnboardingStep.CONGRATS,
+    OnboardingStep.VISIT_COPILOT,
     OnboardingStep.MARKETPLACE_VISIT,
     OnboardingStep.BUILDER_OPEN,
 ]
@@ -122,6 +123,9 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
 async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
     reward = 0
     match step:
+        # Welcome bonus for visiting copilot ($5 = 500 credits)
+        case OnboardingStep.VISIT_COPILOT:
+            reward = 500
         # Reward user when they clicked New Run during onboarding
         # This is because they need credits before scheduling a run (next step)
         # This is seen as a reward for the GET_RESULTS step in the wallet
276  autogpt_platform/backend/backend/data/workspace.py  Normal file
@@ -0,0 +1,276 @@
+"""
+Database CRUD operations for User Workspace.
+
+This module provides functions for managing user workspaces and workspace files.
+"""
+
+import logging
+from datetime import datetime, timezone
+from typing import Optional
+
+from prisma.models import UserWorkspace, UserWorkspaceFile
+from prisma.types import UserWorkspaceFileWhereInput
+
+from backend.util.json import SafeJson
+
+logger = logging.getLogger(__name__)
+
+
+async def get_or_create_workspace(user_id: str) -> UserWorkspace:
+    """
+    Get user's workspace, creating one if it doesn't exist.
+
+    Uses upsert to handle race conditions when multiple concurrent requests
+    attempt to create a workspace for the same user.
+
+    Args:
+        user_id: The user's ID
+
+    Returns:
+        UserWorkspace instance
+    """
+    workspace = await UserWorkspace.prisma().upsert(
+        where={"userId": user_id},
+        data={
+            "create": {"userId": user_id},
+            "update": {},  # No updates needed if exists
+        },
+    )
+
+    return workspace
+
+
+async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
+    """
+    Get user's workspace if it exists.
+
+    Args:
+        user_id: The user's ID
+
+    Returns:
+        UserWorkspace instance or None
+    """
+    return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
+
+
+async def create_workspace_file(
+    workspace_id: str,
+    file_id: str,
+    name: str,
+    path: str,
+    storage_path: str,
+    mime_type: str,
+    size_bytes: int,
+    checksum: Optional[str] = None,
+    metadata: Optional[dict] = None,
+) -> UserWorkspaceFile:
+    """
+    Create a new workspace file record.
+
+    Args:
+        workspace_id: The workspace ID
+        file_id: The file ID (same as used in storage path for consistency)
+        name: User-visible filename
+        path: Virtual path (e.g., "/documents/report.pdf")
+        storage_path: Actual storage path (GCS or local)
+        mime_type: MIME type of the file
+        size_bytes: File size in bytes
+        checksum: Optional SHA256 checksum
+        metadata: Optional additional metadata
+
+    Returns:
+        Created UserWorkspaceFile instance
+    """
+    # Normalize path to start with /
+    if not path.startswith("/"):
+        path = f"/{path}"
+
+    file = await UserWorkspaceFile.prisma().create(
+        data={
+            "id": file_id,
+            "workspaceId": workspace_id,
+            "name": name,
+            "path": path,
+            "storagePath": storage_path,
+            "mimeType": mime_type,
+            "sizeBytes": size_bytes,
+            "checksum": checksum,
+            "metadata": SafeJson(metadata or {}),
+        }
+    )
+
+    logger.info(
+        f"Created workspace file {file.id} at path {path} "
+        f"in workspace {workspace_id}"
+    )
+    return file
+
+
+async def get_workspace_file(
+    file_id: str,
+    workspace_id: Optional[str] = None,
+) -> Optional[UserWorkspaceFile]:
+    """
+    Get a workspace file by ID.
+
+    Args:
+        file_id: The file ID
+        workspace_id: Optional workspace ID for validation
+
+    Returns:
+        UserWorkspaceFile instance or None
+    """
+    where_clause: dict = {"id": file_id, "isDeleted": False}
+    if workspace_id:
+        where_clause["workspaceId"] = workspace_id
+
+    return await UserWorkspaceFile.prisma().find_first(where=where_clause)
+
+
+async def get_workspace_file_by_path(
+    workspace_id: str,
+    path: str,
+) -> Optional[UserWorkspaceFile]:
+    """
+    Get a workspace file by its virtual path.
+
+    Args:
+        workspace_id: The workspace ID
+        path: Virtual path
+
+    Returns:
+        UserWorkspaceFile instance or None
+    """
+    # Normalize path
+    if not path.startswith("/"):
+        path = f"/{path}"
+
+    return await UserWorkspaceFile.prisma().find_first(
+        where={
+            "workspaceId": workspace_id,
+            "path": path,
+            "isDeleted": False,
+        }
+    )
+
+
+async def list_workspace_files(
+    workspace_id: str,
+    path_prefix: Optional[str] = None,
+    include_deleted: bool = False,
+    limit: Optional[int] = None,
+    offset: int = 0,
+) -> list[UserWorkspaceFile]:
+    """
+    List files in a workspace.
+
+    Args:
+        workspace_id: The workspace ID
+        path_prefix: Optional path prefix to filter (e.g., "/documents/")
+        include_deleted: Whether to include soft-deleted files
+        limit: Maximum number of files to return
+        offset: Number of files to skip
+
+    Returns:
+        List of UserWorkspaceFile instances
+    """
+    where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
+
+    if not include_deleted:
+        where_clause["isDeleted"] = False
+
+    if path_prefix:
+        # Normalize prefix
+        if not path_prefix.startswith("/"):
+            path_prefix = f"/{path_prefix}"
+        where_clause["path"] = {"startswith": path_prefix}
+
+    return await UserWorkspaceFile.prisma().find_many(
+        where=where_clause,
+        order={"createdAt": "desc"},
+        take=limit,
+        skip=offset,
+    )
+
+
+async def count_workspace_files(
+    workspace_id: str,
+    path_prefix: Optional[str] = None,
+    include_deleted: bool = False,
+) -> int:
+    """
+    Count files in a workspace.
+
+    Args:
+        workspace_id: The workspace ID
+        path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
+        include_deleted: Whether to include soft-deleted files
+
+    Returns:
+        Number of files
+    """
+    where_clause: dict = {"workspaceId": workspace_id}
+    if not include_deleted:
+        where_clause["isDeleted"] = False
+
+    if path_prefix:
+        # Normalize prefix
+        if not path_prefix.startswith("/"):
+            path_prefix = f"/{path_prefix}"
+        where_clause["path"] = {"startswith": path_prefix}
+
+    return await UserWorkspaceFile.prisma().count(where=where_clause)
+
+
+async def soft_delete_workspace_file(
+    file_id: str,
+    workspace_id: Optional[str] = None,
+) -> Optional[UserWorkspaceFile]:
+    """
+    Soft-delete a workspace file.
+
+    The path is modified to include a deletion timestamp to free up the original
+    path for new files while preserving the record for potential recovery.
+
+    Args:
+        file_id: The file ID
+        workspace_id: Optional workspace ID for validation
+
+    Returns:
+        Updated UserWorkspaceFile instance or None if not found
+    """
+    # First verify the file exists and belongs to workspace
+    file = await get_workspace_file(file_id, workspace_id)
+    if file is None:
+        return None
+
+    deleted_at = datetime.now(timezone.utc)
+    # Modify path to free up the unique constraint for new files at original path
+    # Format: {original_path}__deleted__{timestamp}
+    deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
+
+    updated = await UserWorkspaceFile.prisma().update(
+        where={"id": file_id},
+        data={
+            "isDeleted": True,
+            "deletedAt": deleted_at,
+            "path": deleted_path,
+        },
+    )
+
+    logger.info(f"Soft-deleted workspace file {file_id}")
+    return updated
+
+
+async def get_workspace_total_size(workspace_id: str) -> int:
+    """
+    Get the total size of all files in a workspace.
+
+    Args:
+        workspace_id: The workspace ID
+
+    Returns:
+        Total size in bytes
+    """
+    files = await list_workspace_files(workspace_id)
+    return sum(file.sizeBytes for file in files)
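A short usage sketch of the new CRUD layer, assuming a connected Prisma client; the IDs, names, and paths are invented for illustration:

```python
from backend.data.workspace import (
    create_workspace_file,
    get_or_create_workspace,
    get_workspace_total_size,
    list_workspace_files,
)


async def demo(user_id: str) -> None:
    # Idempotent: the upsert makes this safe under concurrent requests.
    ws = await get_or_create_workspace(user_id)

    await create_workspace_file(
        workspace_id=ws.id,
        file_id="file-123",
        name="report.pdf",
        path="documents/report.pdf",  # normalized to "/documents/report.pdf"
        storage_path="gcs://bucket/workspaces/file-123",
        mime_type="application/pdf",
        size_bytes=1024,
    )

    docs = await list_workspace_files(ws.id, path_prefix="/documents/")
    total_bytes = await get_workspace_total_size(ws.id)
```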
@@ -236,7 +236,14 @@ async def execute_node(
     input_size = len(input_data_str)
     log_metadata.debug("Executed node with input", input=input_data_str)

+    # Create node-specific execution context to avoid race conditions
+    # (multiple nodes can execute concurrently and would otherwise mutate shared state)
+    execution_context = execution_context.model_copy(
+        update={"node_id": node_id, "node_exec_id": node_exec_id}
+    )
+
     # Inject extra execution arguments for the blocks via kwargs
+    # Keep individual kwargs for backwards compatibility with existing blocks
     extra_exec_kwargs: dict = {
         "graph_id": graph_id,
         "graph_version": graph_version,
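`model_copy(update=...)` returns a new Pydantic model rather than mutating in place, which is what makes the per-node copy safe. A minimal illustration (IDs are made up):

```python
from backend.data.execution import ExecutionContext

parent = ExecutionContext(user_id="u1", graph_exec_id="e1")
ctx_a = parent.model_copy(update={"node_id": "a", "node_exec_id": "a-1"})
ctx_b = parent.model_copy(update={"node_id": "b", "node_exec_id": "b-1"})
assert parent.node_id is None                        # shared parent untouched
assert (ctx_a.node_id, ctx_b.node_id) == ("a", "b")  # each node isolated
```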
@@ -892,11 +892,19 @@ async def add_graph_execution(
     settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)

     execution_context = ExecutionContext(
+        # Execution identity
+        user_id=user_id,
+        graph_id=graph_id,
+        graph_exec_id=graph_exec.id,
+        graph_version=graph_exec.graph_version,
+        # Safety settings
         human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
         sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
+        # User settings
         user_timezone=(
             user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
         ),
+        # Execution hierarchy
         root_execution_id=graph_exec.id,
     )
@@ -348,6 +348,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
     mock_graph_exec.id = "execution-id-123"
     mock_graph_exec.node_executions = []  # Add this to avoid AttributeError
     mock_graph_exec.status = ExecutionStatus.QUEUED  # Required for race condition check
+    mock_graph_exec.graph_version = graph_version
     mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()

     # Mock the queue and event bus
@@ -434,6 +435,9 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
     # Create a second mock execution for the sanity check
     mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
     mock_graph_exec_2.id = "execution-id-456"
+    mock_graph_exec_2.node_executions = []
+    mock_graph_exec_2.status = ExecutionStatus.QUEUED
+    mock_graph_exec_2.graph_version = graph_version
     mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()

     # Reset mocks and set up for second call
@@ -614,6 +618,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
     mock_graph_exec.id = "execution-id-123"
     mock_graph_exec.node_executions = []
     mock_graph_exec.status = ExecutionStatus.QUEUED  # Required for race condition check
+    mock_graph_exec.graph_version = graph_version

     # Track what's passed to to_graph_execution_entry
     captured_kwargs = {}
@@ -13,6 +13,7 @@ import aiohttp
 from gcloud.aio import storage as async_gcs_storage
 from google.cloud import storage as gcs_storage

+from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url
 from backend.util.settings import Config

 logger = logging.getLogger(__name__)
@@ -251,7 +252,7 @@ class CloudStorageHandler:
             f"in_task: {current_task is not None}"
         )

-        # Parse bucket and blob name from path
+        # Parse bucket and blob name from path (path already has gcs:// prefix removed)
         parts = path.split("/", 1)
         if len(parts) != 2:
             raise ValueError(f"Invalid GCS path: {path}")
@@ -261,50 +262,19 @@ class CloudStorageHandler:
         # Authorization check
         self._validate_file_access(blob_name, user_id, graph_exec_id)

-        # Use a fresh client for each download to avoid session issues
-        # This is less efficient but more reliable with the executor's event loop
-        logger.info("[CloudStorage] Creating fresh GCS client for download")
-
-        # Create a new session specifically for this download
-        session = aiohttp.ClientSession(
-            connector=aiohttp.TCPConnector(limit=10, force_close=True)
+        logger.info(
+            f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
         )
-
-        async_client = None
         try:
-            # Create a new GCS client with the fresh session
-            async_client = async_gcs_storage.Storage(session=session)
-
-            logger.info(
-                f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
-            )
-
-            # Download content using the fresh client
-            content = await async_client.download(bucket_name, blob_name)
+            content = await download_with_fresh_session(bucket_name, blob_name)
             logger.info(
                 f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
            )
-
-            # Clean up
-            await async_client.close()
-            await session.close()
-
             return content
+        except FileNotFoundError:
+            raise
         except Exception as e:
-            # Always try to clean up
-            if async_client is not None:
-                try:
-                    await async_client.close()
-                except Exception as cleanup_error:
-                    logger.warning(
-                        f"[CloudStorage] Error closing GCS client: {cleanup_error}"
-                    )
-            try:
-                await session.close()
-            except Exception as cleanup_error:
-                logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}")
-
             # Log the specific error for debugging
             logger.error(
                 f"[CloudStorage] GCS download failed - error: {str(e)}, "
@@ -319,10 +289,6 @@ class CloudStorageHandler:
                 f"current_task: {current_task}, "
                 f"bucket: {bucket_name}, blob: redacted for privacy"
             )
-
-            # Convert gcloud-aio exceptions to standard ones
-            if "404" in str(e) or "Not Found" in str(e):
-                raise FileNotFoundError(f"File not found: gcs://{path}")
             raise

     def _validate_file_access(
@@ -445,8 +411,7 @@ class CloudStorageHandler:
         graph_exec_id: str | None = None,
     ) -> str:
         """Generate signed URL for GCS with authorization."""
-        # Parse bucket and blob name from path
+        # Parse bucket and blob name from path (path already has gcs:// prefix removed)
         parts = path.split("/", 1)
         if len(parts) != 2:
             raise ValueError(f"Invalid GCS path: {path}")
@@ -456,21 +421,11 @@ class CloudStorageHandler:
         # Authorization check
         self._validate_file_access(blob_name, user_id, graph_exec_id)

-        # Use sync client for signed URLs since gcloud-aio doesn't support them
         sync_client = self._get_sync_gcs_client()
-        bucket = sync_client.bucket(bucket_name)
-        blob = bucket.blob(blob_name)
-
-        # Generate signed URL asynchronously using sync client
-        url = await asyncio.to_thread(
-            blob.generate_signed_url,
-            version="v4",
-            expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours),
-            method="GET",
+        return await generate_signed_url(
+            sync_client, bucket_name, blob_name, expiration_hours * 3600
         )
-
-        return url

     async def delete_expired_files(self, provider: str = "gcs") -> int:
         """
         Delete files that have passed their expiration time.
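The download and signing logic now lives in `backend.util.gcs_utils` (new file at the end of this diff). A hedged usage sketch: `generate_signed_url`'s signature is inferred from the call site above (it appears to take a sync client, bucket, blob, and an expiration in seconds), and the bucket/blob names are invented:

```python
from google.cloud import storage as gcs_storage

from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url


async def fetch_and_sign(bucket: str, blob: str) -> tuple[bytes, str]:
    # Fresh-session download sidesteps event-loop reuse issues in executors.
    content = await download_with_fresh_session(bucket, blob)
    # Assumption: any configured sync client works here, matching the
    # handler's _get_sync_gcs_client() usage above.
    sync_client = gcs_storage.Client()
    url = await generate_signed_url(sync_client, bucket, blob, 24 * 3600)
    return content, url
```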
@@ -135,6 +135,12 @@ class GraphValidationError(ValueError):
     )


+class InvalidInputError(ValueError):
+    """Raised when user input validation fails (e.g., search term too long)"""
+
+    pass
+
+
 class DatabaseError(Exception):
     """Raised when there is an error interacting with the database"""
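Because it subclasses `ValueError`, existing `except ValueError` handlers keep working while new code can catch the narrower type. A one-function sketch; the module path and the 100-character limit are assumptions for illustration:

```python
from backend.util.exceptions import InvalidInputError  # assumed module path


def validate_search_term(search_term: str) -> None:
    # Illustrative limit only; pick whatever the endpoint actually enforces.
    if len(search_term) > 100:
        raise InvalidInputError("Search term too long (max 100 characters)")
```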
@@ -5,13 +5,26 @@ import shutil
 import tempfile
 import uuid
 from pathlib import Path
+from typing import TYPE_CHECKING, Literal
 from urllib.parse import urlparse

 from backend.util.cloud_storage import get_cloud_storage_handler
 from backend.util.request import Requests
+from backend.util.settings import Config
 from backend.util.type import MediaFileType
 from backend.util.virus_scanner import scan_content_safe

+if TYPE_CHECKING:
+    from backend.data.execution import ExecutionContext
+
+# Return format options for store_media_file
+# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
+# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
+# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
+MediaReturnFormat = Literal[
+    "for_local_processing", "for_external_api", "for_block_output"
+]
+
 TEMP_DIR = Path(tempfile.gettempdir()).resolve()

 # Maximum filename length (conservative limit for most filesystems)
@@ -67,42 +80,56 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:


 async def store_media_file(
-    graph_exec_id: str,
     file: MediaFileType,
-    user_id: str,
-    return_content: bool = False,
+    execution_context: "ExecutionContext",
+    *,
+    return_format: MediaReturnFormat,
 ) -> MediaFileType:
     """
-    Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
-    placing or verifying it under:
+    Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path
+    relative to {temp}/exec_file/{exec_id}), placing or verifying it under:
         {tempdir}/exec_file/{exec_id}/...

-    If 'return_content=True', return a data URI (data:<mime>;base64,<content>).
-    Otherwise, returns the file media path relative to the exec_id folder.
-
-    For each MediaFileType type:
-      - Data URI:
-         -> decode and store in a new random file in that folder
-      - URL:
-         -> download and store in that folder
-      - Local path:
-         -> interpret as relative to that folder; verify it exists
-           (no copying, as it's presumably already there).
-    We realpath-check so no symlink or '..' can escape the folder.
-
-    :param graph_exec_id: The unique ID of the graph execution.
-    :param file: Data URI, URL, or local (relative) path.
-    :param return_content: If True, return a data URI of the file content.
-                           If False, return the *relative* path inside the exec_id folder.
-    :return: The requested result: data URI or relative path of the media.
+    For each MediaFileType input:
+      - Data URI: decode and store locally
+      - URL: download and store locally
+      - workspace:// reference: read from workspace, store locally
+      - Local path: verify it exists in exec_file directory
+
+    Return format options:
+      - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
+      - "for_external_api": Returns data URI (base64) - use when sending to external APIs
+      - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
+
+    :param file: Data URI, URL, workspace://, or local (relative) path.
+    :param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id.
+    :param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output".
+    :return: The requested result based on return_format.
     """
+    # Extract values from execution_context
+    graph_exec_id = execution_context.graph_exec_id
+    user_id = execution_context.user_id
+
+    if not graph_exec_id:
+        raise ValueError("execution_context.graph_exec_id is required")
+    if not user_id:
+        raise ValueError("execution_context.user_id is required")
+
+    # Create workspace_manager if we have workspace_id (with session scoping)
+    # Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py)
+    from backend.util.workspace import WorkspaceManager
+
+    workspace_manager: WorkspaceManager | None = None
+    if execution_context.workspace_id:
+        workspace_manager = WorkspaceManager(
+            user_id, execution_context.workspace_id, execution_context.session_id
+        )
+
     # Build base path
     base_path = Path(get_exec_file_path(graph_exec_id, ""))
     base_path.mkdir(parents=True, exist_ok=True)

     # Security fix: Add disk space limits to prevent DoS
-    MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB per file
+    MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024
     MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024  # 1GB total per execution directory

     # Check total disk usage in base_path
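The three formats in one place — a hedged sketch assuming `ctx` is a valid `ExecutionContext` and `file` any accepted input; the comments show the shape of each result per the docstring above:

```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


async def demo(file: MediaFileType, ctx: ExecutionContext) -> None:
    rel_path = await store_media_file(
        file=file, execution_context=ctx, return_format="for_local_processing"
    )  # e.g. "image.png" - relative path under the exec_file directory

    data_uri = await store_media_file(
        file=file, execution_context=ctx, return_format="for_external_api"
    )  # e.g. "data:image/png;base64,iVBORw0..."

    output = await store_media_file(
        file=file, execution_context=ctx, return_format="for_block_output"
    )  # "workspace://<id>" in CoPilot sessions, otherwise a data URI
```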
@@ -142,9 +169,57 @@ async def store_media_file(
         """
         return str(absolute_path.relative_to(base))

-    # Check if this is a cloud storage path
+    # Get cloud storage handler for checking cloud paths
     cloud_storage = await get_cloud_storage_handler()
-    if cloud_storage.is_cloud_path(file):
+
+    # Track if the input came from workspace (don't re-save it)
+    is_from_workspace = file.startswith("workspace://")
+
+    # Check if this is a workspace file reference
+    if is_from_workspace:
+        if workspace_manager is None:
+            raise ValueError(
+                "Workspace file reference requires workspace context. "
+                "This file type is only available in CoPilot sessions."
+            )
+
+        # Parse workspace reference
+        # workspace://abc123 - by file ID
+        # workspace:///path/to/file.txt - by virtual path
+        file_ref = file[12:]  # Remove "workspace://"
+
+        if file_ref.startswith("/"):
+            # Path reference
+            workspace_content = await workspace_manager.read_file(file_ref)
+            file_info = await workspace_manager.get_file_info_by_path(file_ref)
+            filename = sanitize_filename(
+                file_info.name if file_info else f"{uuid.uuid4()}.bin"
+            )
+        else:
+            # ID reference
+            workspace_content = await workspace_manager.read_file_by_id(file_ref)
+            file_info = await workspace_manager.get_file_info(file_ref)
+            filename = sanitize_filename(
+                file_info.name if file_info else f"{uuid.uuid4()}.bin"
+            )
+
+        try:
+            target_path = _ensure_inside_base(base_path / filename, base_path)
+        except OSError as e:
+            raise ValueError(f"Invalid file path '{filename}': {e}") from e
+
+        # Check file size limit
+        if len(workspace_content) > MAX_FILE_SIZE_BYTES:
+            raise ValueError(
+                f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
+            )
+
+        # Virus scan the workspace content before writing locally
+        await scan_content_safe(workspace_content, filename=filename)
+        target_path.write_bytes(workspace_content)
+
+    # Check if this is a cloud storage path
+    elif cloud_storage.is_cloud_path(file):
         # Download from cloud storage and store locally
         cloud_content = await cloud_storage.retrieve_file(
             file, user_id=user_id, graph_exec_id=graph_exec_id
@@ -159,9 +234,9 @@ async def store_media_file(
             raise ValueError(f"Invalid file path '{filename}': {e}") from e

         # Check file size limit
-        if len(cloud_content) > MAX_FILE_SIZE:
+        if len(cloud_content) > MAX_FILE_SIZE_BYTES:
             raise ValueError(
-                f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
+                f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
             )

         # Virus scan the cloud content before writing locally
@@ -189,9 +264,9 @@ async def store_media_file(
         content = base64.b64decode(b64_content)

         # Check file size limit
-        if len(content) > MAX_FILE_SIZE:
+        if len(content) > MAX_FILE_SIZE_BYTES:
             raise ValueError(
-                f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
+                f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
             )

         # Virus scan the base64 content before writing
@@ -199,23 +274,31 @@ async def store_media_file(
         target_path.write_bytes(content)

     elif file.startswith(("http://", "https://")):
-        # URL
+        # URL - download first to get Content-Type header
+        resp = await Requests().get(file)
+
+        # Check file size limit
+        if len(resp.content) > MAX_FILE_SIZE_BYTES:
+            raise ValueError(
+                f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
+            )
+
+        # Extract filename from URL path
         parsed_url = urlparse(file)
         filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
+
+        # If filename lacks extension, add one from Content-Type header
+        if "." not in filename:
+            content_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
+            if content_type:
+                ext = _extension_from_mime(content_type)
+                filename = f"{filename}{ext}"
+
         try:
             target_path = _ensure_inside_base(base_path / filename, base_path)
         except OSError as e:
             raise ValueError(f"Invalid file path '{filename}': {e}") from e

-        # Download and save
-        resp = await Requests().get(file)
-
-        # Check file size limit
-        if len(resp.content) > MAX_FILE_SIZE:
-            raise ValueError(
-                f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
-            )
-
         # Virus scan the downloaded content before writing
         await scan_content_safe(resp.content, filename=filename)
         target_path.write_bytes(resp.content)
@@ -230,12 +313,44 @@ async def store_media_file(
         if not target_path.is_file():
             raise ValueError(f"Local file does not exist: {target_path}")

-    # Return result
-    if return_content:
-        return MediaFileType(_file_to_data_uri(target_path))
-    else:
+    # Return based on requested format
+    if return_format == "for_local_processing":
+        # Use when processing files locally with tools like ffmpeg, MoviePy, PIL
+        # Returns: relative path in exec_file directory (e.g., "image.png")
         return MediaFileType(_strip_base_prefix(target_path, base_path))

+    elif return_format == "for_external_api":
+        # Use when sending content to external APIs that need base64
+        # Returns: data URI (e.g., "data:image/png;base64,iVBORw0...")
+        return MediaFileType(_file_to_data_uri(target_path))
+
+    elif return_format == "for_block_output":
+        # Use when returning output from a block to user/next block
+        # Returns: workspace:// ref (CoPilot) or data URI (graph execution)
+        if workspace_manager is None:
+            # No workspace available (graph execution without CoPilot)
+            # Fallback to data URI so the content can still be used/displayed
+            return MediaFileType(_file_to_data_uri(target_path))
+
+        # Don't re-save if input was already from workspace
+        if is_from_workspace:
+            # Return original workspace reference
+            return MediaFileType(file)
+
+        # Save new content to workspace
+        content = target_path.read_bytes()
+        filename = target_path.name
+
+        file_record = await workspace_manager.write_file(
+            content=content,
+            filename=filename,
+            overwrite=True,
+        )
+        return MediaFileType(f"workspace://{file_record.id}")
+
+    else:
+        raise ValueError(f"Invalid return_format: {return_format}")


 def get_dir_size(path: Path) -> int:
     """Get total size of directory."""
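One subtlety worth noting: a `workspace://` input combined with `"for_block_output"` short-circuits, so the file is never duplicated in the workspace. A sketch, assuming `copilot_ctx` has `workspace_id` set and the file ID is made up:

```python
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


async def roundtrip(copilot_ctx: ExecutionContext) -> None:
    ref = MediaFileType("workspace://abc123")
    out = await store_media_file(
        file=ref, execution_context=copilot_ctx, return_format="for_block_output"
    )
    assert out == ref  # original reference returned unchanged, no re-save
```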
@@ -7,10 +7,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_context(
|
||||||
|
graph_exec_id: str = "test-exec-123",
|
||||||
|
user_id: str = "test-user-123",
|
||||||
|
) -> ExecutionContext:
|
||||||
|
"""Helper to create test ExecutionContext."""
|
||||||
|
return ExecutionContext(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestFileCloudIntegration:
|
class TestFileCloudIntegration:
|
||||||
"""Test cases for cloud storage integration in file utilities."""
|
"""Test cases for cloud storage integration in file utilities."""
|
||||||
|
|
||||||
@@ -70,10 +82,9 @@ class TestFileCloudIntegration:
|
|||||||
mock_path_class.side_effect = path_constructor
|
mock_path_class.side_effect = path_constructor
|
||||||
|
|
||||||
result = await store_media_file(
|
result = await store_media_file(
|
||||||
graph_exec_id,
|
file=MediaFileType(cloud_path),
|
||||||
MediaFileType(cloud_path),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
"test-user-123",
|
return_format="for_local_processing",
|
||||||
return_content=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify cloud storage operations
|
# Verify cloud storage operations
|
||||||
@@ -144,10 +155,9 @@ class TestFileCloudIntegration:
|
|||||||
mock_path_obj.name = "image.png"
|
mock_path_obj.name = "image.png"
|
||||||
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
||||||
result = await store_media_file(
|
result = await store_media_file(
|
||||||
graph_exec_id,
|
file=MediaFileType(cloud_path),
|
||||||
MediaFileType(cloud_path),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
"test-user-123",
|
return_format="for_external_api",
|
||||||
return_content=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify result is a data URI
|
# Verify result is a data URI
|
||||||
@@ -198,10 +208,9 @@ class TestFileCloudIntegration:
|
|||||||
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
||||||
|
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id,
|
file=MediaFileType(data_uri),
|
||||||
MediaFileType(data_uri),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
"test-user-123",
|
return_format="for_local_processing",
|
||||||
return_content=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify cloud handler was checked but not used for retrieval
|
# Verify cloud handler was checked but not used for retrieval
|
||||||
@@ -234,5 +243,7 @@ class TestFileCloudIntegration:
|
|||||||
FileNotFoundError, match="File not found in cloud storage"
|
FileNotFoundError, match="File not found in cloud storage"
|
||||||
):
|
):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
|
file=MediaFileType(cloud_path),
|
||||||
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
autogpt_platform/backend/backend/util/gcs_utils.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""
Shared GCS utilities for workspace and cloud storage backends.

This module provides common functionality for working with Google Cloud Storage,
including path parsing, client management, and signed URL generation.
"""

import asyncio
import logging
from datetime import datetime, timedelta, timezone

import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage

logger = logging.getLogger(__name__)


def parse_gcs_path(path: str) -> tuple[str, str]:
    """
    Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob).

    Args:
        path: GCS path string (e.g., "gcs://my-bucket/path/to/file")

    Returns:
        Tuple of (bucket_name, blob_name)

    Raises:
        ValueError: If the path format is invalid
    """
    if not path.startswith("gcs://"):
        raise ValueError(f"Invalid GCS path: {path}")

    path_without_prefix = path[6:]  # Remove "gcs://"
    parts = path_without_prefix.split("/", 1)
    if len(parts) != 2:
        raise ValueError(f"Invalid GCS path format: {path}")

    return parts[0], parts[1]


async def download_with_fresh_session(bucket: str, blob: str) -> bytes:
    """
    Download file content using a fresh session.

    This approach avoids event loop issues that can occur when reusing
    sessions across different async contexts (e.g., in executors).

    Args:
        bucket: GCS bucket name
        blob: Blob path within the bucket

    Returns:
        File content as bytes

    Raises:
        FileNotFoundError: If the file doesn't exist
    """
    session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(limit=10, force_close=True)
    )
    client: async_gcs_storage.Storage | None = None
    try:
        client = async_gcs_storage.Storage(session=session)
        content = await client.download(bucket, blob)
        return content
    except Exception as e:
        if "404" in str(e) or "Not Found" in str(e):
            raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}")
        raise
    finally:
        if client:
            try:
                await client.close()
            except Exception:
                pass  # Best-effort cleanup
        await session.close()


async def generate_signed_url(
    sync_client: gcs_storage.Client,
    bucket_name: str,
    blob_name: str,
    expires_in: int,
) -> str:
    """
    Generate a signed URL for temporary access to a GCS file.

    Uses asyncio.to_thread() to run the sync operation without blocking.

    Args:
        sync_client: Sync GCS client with service account credentials
        bucket_name: GCS bucket name
        blob_name: Blob path within the bucket
        expires_in: URL expiration time in seconds

    Returns:
        Signed URL string
    """
    bucket = sync_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    return await asyncio.to_thread(
        blob.generate_signed_url,
        version="v4",
        expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in),
        method="GET",
    )
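
For reference, a small usage sketch of these helpers; the bucket and blob names are illustrative only.

import asyncio

from backend.util.gcs_utils import download_with_fresh_session, parse_gcs_path


async def fetch(storage_path: str) -> bytes:
    # e.g. storage_path = "gcs://example-bucket/workspaces/ws-1/file-1/report.pdf"
    bucket, blob = parse_gcs_path(storage_path)
    # A fresh session per download sidesteps cross-event-loop session reuse.
    return await download_with_fresh_session(bucket, blob)


# asyncio.run(fetch("gcs://example-bucket/path/to/file"))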
@@ -263,6 +263,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         description="The name of the Google Cloud Storage bucket for media files",
     )

+    workspace_storage_dir: str = Field(
+        default="",
+        description="Local directory for workspace file storage when GCS is not configured. "
+        "If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.",
+    )
+
     reddit_user_agent: str = Field(
         default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
         description="The user agent for the Reddit API",
@@ -359,8 +365,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         description="The port for the Agent Generator service",
     )
     agentgenerator_timeout: int = Field(
-        default=120,
-        description="The timeout in seconds for Agent Generator service requests",
+        default=600,
+        description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
     )

     enable_example_blocks: bool = Field(
@@ -389,6 +395,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         description="Maximum file size in MB for file uploads (1-1024 MB)",
     )

+    max_file_size_mb: int = Field(
+        default=100,
+        ge=1,
+        le=1024,
+        description="Maximum file size in MB for workspace files (1-1024 MB)",
+    )
+
     # AutoMod configuration
     automod_enabled: bool = Field(
         default=False,
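
Because `Config` extends pydantic's `BaseSettings`, the new fields should also be settable from the environment. A sketch, assuming pydantic's default field-name-to-env-var mapping with no custom prefix (an assumption; the platform may configure this differently):

import os

# ASSUMPTION: fields map to same-named (case-insensitive) env vars.
os.environ["WORKSPACE_STORAGE_DIR"] = "/var/lib/autogpt/workspaces"
os.environ["MAX_FILE_SIZE_MB"] = "250"

from backend.util.settings import Config

config = Config()
print(config.workspace_storage_dir)  # /var/lib/autogpt/workspaces
print(config.max_file_size_mb)       # 250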
@@ -140,14 +140,29 @@ async def execute_block_test(block: Block):
         setattr(block, mock_name, mock_obj)

     # Populate credentials argument(s)
+    # Generate IDs for execution context
+    graph_id = str(uuid.uuid4())
+    node_id = str(uuid.uuid4())
+    graph_exec_id = str(uuid.uuid4())
+    node_exec_id = str(uuid.uuid4())
+    user_id = str(uuid.uuid4())
+    graph_version = 1  # Default version for tests
+
     extra_exec_kwargs: dict = {
-        "graph_id": str(uuid.uuid4()),
-        "node_id": str(uuid.uuid4()),
-        "graph_exec_id": str(uuid.uuid4()),
-        "node_exec_id": str(uuid.uuid4()),
-        "user_id": str(uuid.uuid4()),
-        "graph_version": 1,  # Default version for tests
-        "execution_context": ExecutionContext(),
+        "graph_id": graph_id,
+        "node_id": node_id,
+        "graph_exec_id": graph_exec_id,
+        "node_exec_id": node_exec_id,
+        "user_id": user_id,
+        "graph_version": graph_version,
+        "execution_context": ExecutionContext(
+            user_id=user_id,
+            graph_id=graph_id,
+            graph_exec_id=graph_exec_id,
+            graph_version=graph_version,
+            node_id=node_id,
+            node_exec_id=node_exec_id,
+        ),
     }
     input_model = cast(type[BlockSchema], block.input_schema)
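
The substance of this change is that the kwargs and the `ExecutionContext` now share one set of IDs instead of generating them independently. A property the refactor guarantees (field names taken from the diff above):

ctx = extra_exec_kwargs["execution_context"]
assert ctx.user_id == extra_exec_kwargs["user_id"]
assert ctx.graph_exec_id == extra_exec_kwargs["graph_exec_id"]
assert ctx.node_exec_id == extra_exec_kwargs["node_exec_id"]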
autogpt_platform/backend/backend/util/workspace.py (new file, 419 lines)
@@ -0,0 +1,419 @@
"""
WorkspaceManager for managing user workspace file operations.

This module provides a high-level interface for workspace file operations,
combining the storage backend and database layer.
"""

import logging
import mimetypes
import uuid
from typing import Optional

from prisma.errors import UniqueViolationError
from prisma.models import UserWorkspaceFile

from backend.data.workspace import (
    count_workspace_files,
    create_workspace_file,
    get_workspace_file,
    get_workspace_file_by_path,
    list_workspace_files,
    soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage

logger = logging.getLogger(__name__)


class WorkspaceManager:
    """
    Manages workspace file operations.

    Combines storage backend operations with database record management.
    Supports session-scoped file segmentation where files are stored in
    session-specific virtual paths: /sessions/{session_id}/{filename}
    """

    def __init__(
        self, user_id: str, workspace_id: str, session_id: Optional[str] = None
    ):
        """
        Initialize WorkspaceManager.

        Args:
            user_id: The user's ID
            workspace_id: The workspace ID
            session_id: Optional session ID for session-scoped file access
        """
        self.user_id = user_id
        self.workspace_id = workspace_id
        self.session_id = session_id
        # Session path prefix for file isolation
        self.session_path = f"/sessions/{session_id}" if session_id else ""

    def _resolve_path(self, path: str) -> str:
        """
        Resolve a path, defaulting to session folder if session_id is set.

        Cross-session access is allowed by explicitly using /sessions/other-session-id/...

        Args:
            path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt")

        Returns:
            Resolved path with session prefix if applicable
        """
        # If path explicitly references a session folder, use it as-is
        if path.startswith("/sessions/"):
            return path

        # If we have a session context, prepend session path
        if self.session_path:
            # Normalize the path
            if not path.startswith("/"):
                path = f"/{path}"
            return f"{self.session_path}{path}"

        # No session context, use path as-is
        return path if path.startswith("/") else f"/{path}"

    def _get_effective_path(
        self, path: Optional[str], include_all_sessions: bool
    ) -> Optional[str]:
        """
        Get effective path for list/count operations based on session context.

        Args:
            path: Optional path prefix to filter
            include_all_sessions: If True, don't apply session scoping

        Returns:
            Effective path prefix for database query
        """
        if include_all_sessions:
            # Normalize path to ensure leading slash (stored paths are normalized)
            if path is not None and not path.startswith("/"):
                return f"/{path}"
            return path
        elif path is not None:
            # Resolve the provided path with session scoping
            return self._resolve_path(path)
        elif self.session_path:
            # Default to session folder with trailing slash to prevent prefix collisions
            # e.g., "/sessions/abc" should not match "/sessions/abc123"
            return self.session_path.rstrip("/") + "/"
        else:
            # No session context, use path as-is
            return path

    async def read_file(self, path: str) -> bytes:
        """
        Read file from workspace by virtual path.

        When session_id is set, paths are resolved relative to the session folder
        unless they explicitly reference /sessions/...

        Args:
            path: Virtual path (e.g., "/documents/report.pdf")

        Returns:
            File content as bytes

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        resolved_path = self._resolve_path(path)
        file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
        if file is None:
            raise FileNotFoundError(f"File not found at path: {resolved_path}")

        storage = await get_workspace_storage()
        return await storage.retrieve(file.storagePath)

    async def read_file_by_id(self, file_id: str) -> bytes:
        """
        Read file from workspace by file ID.

        Args:
            file_id: The file's ID

        Returns:
            File content as bytes

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        file = await get_workspace_file(file_id, self.workspace_id)
        if file is None:
            raise FileNotFoundError(f"File not found: {file_id}")

        storage = await get_workspace_storage()
        return await storage.retrieve(file.storagePath)

    async def write_file(
        self,
        content: bytes,
        filename: str,
        path: Optional[str] = None,
        mime_type: Optional[str] = None,
        overwrite: bool = False,
    ) -> UserWorkspaceFile:
        """
        Write file to workspace.

        When session_id is set, files are written to /sessions/{session_id}/...
        by default. Use explicit /sessions/... paths for cross-session access.

        Args:
            content: File content as bytes
            filename: Filename for the file
            path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
            mime_type: MIME type (auto-detected if not provided)
            overwrite: Whether to overwrite existing file at path

        Returns:
            Created UserWorkspaceFile instance

        Raises:
            ValueError: If file exceeds size limit or path already exists
        """
        # Enforce file size limit
        max_file_size = Config().max_file_size_mb * 1024 * 1024
        if len(content) > max_file_size:
            raise ValueError(
                f"File too large: {len(content)} bytes exceeds "
                f"{Config().max_file_size_mb}MB limit"
            )

        # Determine path with session scoping
        if path is None:
            path = f"/{filename}"
        elif not path.startswith("/"):
            path = f"/{path}"

        # Resolve path with session prefix
        path = self._resolve_path(path)

        # Check if file exists at path (only error for non-overwrite case)
        # For overwrite=True, we let the write proceed and handle via UniqueViolationError
        # This ensures the new file is written to storage BEFORE the old one is deleted,
        # preventing data loss if the new write fails
        if not overwrite:
            existing = await get_workspace_file_by_path(self.workspace_id, path)
            if existing is not None:
                raise ValueError(f"File already exists at path: {path}")

        # Auto-detect MIME type if not provided
        if mime_type is None:
            mime_type, _ = mimetypes.guess_type(filename)
            mime_type = mime_type or "application/octet-stream"

        # Compute checksum
        checksum = compute_file_checksum(content)

        # Generate unique file ID for storage
        file_id = str(uuid.uuid4())

        # Store file in storage backend
        storage = await get_workspace_storage()
        storage_path = await storage.store(
            workspace_id=self.workspace_id,
            file_id=file_id,
            filename=filename,
            content=content,
        )

        # Create database record - handle race condition where another request
        # created a file at the same path between our check and create
        try:
            file = await create_workspace_file(
                workspace_id=self.workspace_id,
                file_id=file_id,
                name=filename,
                path=path,
                storage_path=storage_path,
                mime_type=mime_type,
                size_bytes=len(content),
                checksum=checksum,
            )
        except UniqueViolationError:
            # Race condition: another request created a file at this path
            if overwrite:
                # Re-fetch and delete the conflicting file, then retry
                existing = await get_workspace_file_by_path(self.workspace_id, path)
                if existing:
                    await self.delete_file(existing.id)
                # Retry the create - if this also fails, clean up storage file
                try:
                    file = await create_workspace_file(
                        workspace_id=self.workspace_id,
                        file_id=file_id,
                        name=filename,
                        path=path,
                        storage_path=storage_path,
                        mime_type=mime_type,
                        size_bytes=len(content),
                        checksum=checksum,
                    )
                except Exception:
                    # Clean up orphaned storage file on retry failure
                    try:
                        await storage.delete(storage_path)
                    except Exception as e:
                        logger.warning(f"Failed to clean up orphaned storage file: {e}")
                    raise
            else:
                # Clean up the orphaned storage file before raising
                try:
                    await storage.delete(storage_path)
                except Exception as e:
                    logger.warning(f"Failed to clean up orphaned storage file: {e}")
                raise ValueError(f"File already exists at path: {path}")
        except Exception:
            # Any other database error (connection, validation, etc.) - clean up storage
            try:
                await storage.delete(storage_path)
            except Exception as e:
                logger.warning(f"Failed to clean up orphaned storage file: {e}")
            raise

        logger.info(
            f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} "
            f"at path {path}, size={len(content)} bytes"
        )

        return file

    async def list_files(
        self,
        path: Optional[str] = None,
        limit: Optional[int] = None,
        offset: int = 0,
        include_all_sessions: bool = False,
    ) -> list[UserWorkspaceFile]:
        """
        List files in workspace.

        When session_id is set and include_all_sessions is False (default),
        only files in the current session's folder are listed.

        Args:
            path: Optional path prefix to filter (e.g., "/documents/")
            limit: Maximum number of files to return
            offset: Number of files to skip
            include_all_sessions: If True, list files from all sessions.
                If False (default), only list current session's files.

        Returns:
            List of UserWorkspaceFile instances
        """
        effective_path = self._get_effective_path(path, include_all_sessions)

        return await list_workspace_files(
            workspace_id=self.workspace_id,
            path_prefix=effective_path,
            limit=limit,
            offset=offset,
        )

    async def delete_file(self, file_id: str) -> bool:
        """
        Delete a file (soft-delete).

        Args:
            file_id: The file's ID

        Returns:
            True if deleted, False if not found
        """
        file = await get_workspace_file(file_id, self.workspace_id)
        if file is None:
            return False

        # Delete from storage
        storage = await get_workspace_storage()
        try:
            await storage.delete(file.storagePath)
        except Exception as e:
            logger.warning(f"Failed to delete file from storage: {e}")
            # Continue with database soft-delete even if storage delete fails

        # Soft-delete database record
        result = await soft_delete_workspace_file(file_id, self.workspace_id)
        return result is not None

    async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
        """
        Get download URL for a file.

        Args:
            file_id: The file's ID
            expires_in: URL expiration in seconds (default 1 hour)

        Returns:
            Download URL (signed URL for GCS, API endpoint for local)

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        file = await get_workspace_file(file_id, self.workspace_id)
        if file is None:
            raise FileNotFoundError(f"File not found: {file_id}")

        storage = await get_workspace_storage()
        return await storage.get_download_url(file.storagePath, expires_in)

    async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
        """
        Get file metadata.

        Args:
            file_id: The file's ID

        Returns:
            UserWorkspaceFile instance or None
        """
        return await get_workspace_file(file_id, self.workspace_id)

    async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
        """
        Get file metadata by path.

        When session_id is set, paths are resolved relative to the session folder
        unless they explicitly reference /sessions/...

        Args:
            path: Virtual path

        Returns:
            UserWorkspaceFile instance or None
        """
        resolved_path = self._resolve_path(path)
        return await get_workspace_file_by_path(self.workspace_id, resolved_path)

    async def get_file_count(
        self,
        path: Optional[str] = None,
        include_all_sessions: bool = False,
    ) -> int:
        """
        Get number of files in workspace.

        When session_id is set and include_all_sessions is False (default),
        only counts files in the current session's folder.

        Args:
            path: Optional path prefix to filter (e.g., "/documents/")
            include_all_sessions: If True, count all files in workspace.
                If False (default), only count current session's files.

        Returns:
            Number of files
        """
        effective_path = self._get_effective_path(path, include_all_sessions)

        return await count_workspace_files(
            self.workspace_id, path_prefix=effective_path
        )
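
A usage sketch tying the pieces together; the IDs are illustrative and error handling is omitted:

from backend.util.workspace import WorkspaceManager


async def demo() -> None:
    wm = WorkspaceManager(
        user_id="user-123",
        workspace_id="ws-456",
        session_id="sess-789",  # optional; scopes paths to /sessions/sess-789
    )

    # Written to /sessions/sess-789/notes.txt because a session_id is set.
    record = await wm.write_file(b"hello", filename="notes.txt")

    data = await wm.read_file("/notes.txt")  # session-scoped read
    assert data == b"hello"

    # Explicit /sessions/... paths bypass the default scoping.
    same = await wm.read_file("/sessions/sess-789/notes.txt")
    assert same == data

    await wm.delete_file(record.id)  # soft-delete + best-effort storage delete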
autogpt_platform/backend/backend/util/workspace_storage.py (new file, 398 lines)
@@ -0,0 +1,398 @@
"""
Workspace storage backend abstraction for supporting both cloud and local deployments.

This module provides a unified interface for storing workspace files, with implementations
for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments).
"""

import asyncio
import hashlib
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import aiofiles
import aiohttp
from gcloud.aio import storage as async_gcs_storage
from google.cloud import storage as gcs_storage

from backend.util.data import get_data_path
from backend.util.gcs_utils import (
    download_with_fresh_session,
    generate_signed_url,
    parse_gcs_path,
)
from backend.util.settings import Config

logger = logging.getLogger(__name__)


class WorkspaceStorageBackend(ABC):
    """Abstract interface for workspace file storage."""

    @abstractmethod
    async def store(
        self,
        workspace_id: str,
        file_id: str,
        filename: str,
        content: bytes,
    ) -> str:
        """
        Store file content, return storage path.

        Args:
            workspace_id: The workspace ID
            file_id: Unique file ID for storage
            filename: Original filename
            content: File content as bytes

        Returns:
            Storage path string (cloud path or local path)
        """
        pass

    @abstractmethod
    async def retrieve(self, storage_path: str) -> bytes:
        """
        Retrieve file content from storage.

        Args:
            storage_path: The storage path returned from store()

        Returns:
            File content as bytes
        """
        pass

    @abstractmethod
    async def delete(self, storage_path: str) -> None:
        """
        Delete file from storage.

        Args:
            storage_path: The storage path to delete
        """
        pass

    @abstractmethod
    async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
        """
        Get URL for downloading the file.

        Args:
            storage_path: The storage path
            expires_in: URL expiration time in seconds (default 1 hour)

        Returns:
            Download URL (signed URL for GCS, direct API path for local)
        """
        pass


class GCSWorkspaceStorage(WorkspaceStorageBackend):
    """Google Cloud Storage implementation for workspace storage."""

    def __init__(self, bucket_name: str):
        self.bucket_name = bucket_name
        self._async_client: Optional[async_gcs_storage.Storage] = None
        self._sync_client: Optional[gcs_storage.Client] = None
        self._session: Optional[aiohttp.ClientSession] = None

    async def _get_async_client(self) -> async_gcs_storage.Storage:
        """Get or create async GCS client."""
        if self._async_client is None:
            self._session = aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(limit=100, force_close=False)
            )
            self._async_client = async_gcs_storage.Storage(session=self._session)
        return self._async_client

    def _get_sync_client(self) -> gcs_storage.Client:
        """Get or create sync GCS client (for signed URLs)."""
        if self._sync_client is None:
            self._sync_client = gcs_storage.Client()
        return self._sync_client

    async def close(self) -> None:
        """Close all client connections."""
        if self._async_client is not None:
            try:
                await self._async_client.close()
            except Exception as e:
                logger.warning(f"Error closing GCS client: {e}")
            self._async_client = None

        if self._session is not None:
            try:
                await self._session.close()
            except Exception as e:
                logger.warning(f"Error closing session: {e}")
            self._session = None

    def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str:
        """Build the blob path for workspace files."""
        return f"workspaces/{workspace_id}/{file_id}/{filename}"

    async def store(
        self,
        workspace_id: str,
        file_id: str,
        filename: str,
        content: bytes,
    ) -> str:
        """Store file in GCS."""
        client = await self._get_async_client()
        blob_name = self._build_blob_name(workspace_id, file_id, filename)

        # Upload with metadata
        upload_time = datetime.now(timezone.utc)
        await client.upload(
            self.bucket_name,
            blob_name,
            content,
            metadata={
                "uploaded_at": upload_time.isoformat(),
                "workspace_id": workspace_id,
                "file_id": file_id,
            },
        )

        return f"gcs://{self.bucket_name}/{blob_name}"

    async def retrieve(self, storage_path: str) -> bytes:
        """Retrieve file from GCS."""
        bucket_name, blob_name = parse_gcs_path(storage_path)
        return await download_with_fresh_session(bucket_name, blob_name)

    async def delete(self, storage_path: str) -> None:
        """Delete file from GCS."""
        bucket_name, blob_name = parse_gcs_path(storage_path)
        client = await self._get_async_client()

        try:
            await client.delete(bucket_name, blob_name)
        except Exception as e:
            if "404" not in str(e) and "Not Found" not in str(e):
                raise
            # File already deleted, that's fine

    async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
        """
        Generate download URL for GCS file.

        Attempts to generate a signed URL if running with service account credentials.
        Falls back to an API proxy endpoint if signed URL generation fails
        (e.g., when running locally with user OAuth credentials).
        """
        bucket_name, blob_name = parse_gcs_path(storage_path)

        # Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename}
        blob_parts = blob_name.split("/")
        file_id = blob_parts[2] if len(blob_parts) >= 3 else None

        # Try to generate signed URL (requires service account credentials)
        try:
            sync_client = self._get_sync_client()
            return await generate_signed_url(
                sync_client, bucket_name, blob_name, expires_in
            )
        except AttributeError as e:
            # Signed URL generation requires service account with private key.
            # When running with user OAuth credentials, fall back to API proxy.
            if "private key" in str(e) and file_id:
                logger.debug(
                    "Cannot generate signed URL (no service account credentials), "
                    "falling back to API proxy endpoint"
                )
                return f"/api/workspace/files/{file_id}/download"
            raise


class LocalWorkspaceStorage(WorkspaceStorageBackend):
    """Local filesystem implementation for workspace storage (self-hosted deployments)."""

    def __init__(self, base_dir: Optional[str] = None):
        """
        Initialize local storage backend.

        Args:
            base_dir: Base directory for workspace storage.
                If None, defaults to {app_data}/workspaces
        """
        if base_dir:
            self.base_dir = Path(base_dir)
        else:
            self.base_dir = Path(get_data_path()) / "workspaces"

        # Ensure base directory exists
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path:
        """Build the local file path with path traversal protection."""
        # Import here to avoid circular import
        # (file.py imports workspace.py which imports workspace_storage.py)
        from backend.util.file import sanitize_filename

        # Sanitize filename to prevent path traversal (removes / and \ among others)
        safe_filename = sanitize_filename(filename)
        file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve()

        # Verify the resolved path is still under base_dir
        if not file_path.is_relative_to(self.base_dir.resolve()):
            raise ValueError("Invalid filename: path traversal detected")

        return file_path

    def _parse_storage_path(self, storage_path: str) -> Path:
        """Parse local storage path to filesystem path."""
        if storage_path.startswith("local://"):
            relative_path = storage_path[8:]  # Remove "local://"
        else:
            relative_path = storage_path

        full_path = (self.base_dir / relative_path).resolve()

        # Security check: ensure path is under base_dir
        # Use is_relative_to() for robust path containment check
        # (handles case-insensitive filesystems and edge cases)
        if not full_path.is_relative_to(self.base_dir.resolve()):
            raise ValueError("Invalid storage path: path traversal detected")

        return full_path

    async def store(
        self,
        workspace_id: str,
        file_id: str,
        filename: str,
        content: bytes,
    ) -> str:
        """Store file locally."""
        file_path = self._build_file_path(workspace_id, file_id, filename)

        # Create parent directories
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Write file asynchronously
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)

        # Return relative path as storage path
        relative_path = file_path.relative_to(self.base_dir)
        return f"local://{relative_path}"

    async def retrieve(self, storage_path: str) -> bytes:
        """Retrieve file from local storage."""
        file_path = self._parse_storage_path(storage_path)

        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {storage_path}")

        async with aiofiles.open(file_path, "rb") as f:
            return await f.read()

    async def delete(self, storage_path: str) -> None:
        """Delete file from local storage."""
        file_path = self._parse_storage_path(storage_path)

        if file_path.exists():
            # Remove file
            file_path.unlink()

            # Clean up empty parent directories
            parent = file_path.parent
            while parent != self.base_dir:
                try:
                    if parent.exists() and not any(parent.iterdir()):
                        parent.rmdir()
                    else:
                        break
                except OSError:
                    break
                parent = parent.parent

    async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
        """
        Get download URL for local file.

        For local storage, this returns an API endpoint path.
        The actual serving is handled by the API layer.
        """
        # Parse the storage path to get the components
        if storage_path.startswith("local://"):
            relative_path = storage_path[8:]
        else:
            relative_path = storage_path

        # Return the API endpoint for downloading
        # The file_id is extracted from the path: {workspace_id}/{file_id}/{filename}
        parts = relative_path.split("/")
        if len(parts) >= 2:
            file_id = parts[1]  # Second component is file_id
            return f"/api/workspace/files/{file_id}/download"
        else:
            raise ValueError(f"Invalid storage path format: {storage_path}")


# Global storage backend instance
_workspace_storage: Optional[WorkspaceStorageBackend] = None
_storage_lock = asyncio.Lock()


async def get_workspace_storage() -> WorkspaceStorageBackend:
    """
    Get the workspace storage backend instance.

    Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
    """
    global _workspace_storage

    if _workspace_storage is None:
        async with _storage_lock:
            if _workspace_storage is None:
                config = Config()

                if config.media_gcs_bucket_name:
                    logger.info(
                        f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
                    )
                    _workspace_storage = GCSWorkspaceStorage(
                        config.media_gcs_bucket_name
                    )
                else:
                    storage_dir = (
                        config.workspace_storage_dir
                        if config.workspace_storage_dir
                        else None
                    )
                    logger.info(
                        f"Using local workspace storage: {storage_dir or 'default'}"
                    )
                    _workspace_storage = LocalWorkspaceStorage(storage_dir)

    return _workspace_storage


async def shutdown_workspace_storage() -> None:
    """
    Properly shutdown the global workspace storage backend.

    Closes aiohttp sessions and other resources for GCS backend.
    Should be called during application shutdown.
    """
    global _workspace_storage

    if _workspace_storage is not None:
        async with _storage_lock:
            if _workspace_storage is not None:
                if isinstance(_workspace_storage, GCSWorkspaceStorage):
                    await _workspace_storage.close()
                _workspace_storage = None


def compute_file_checksum(content: bytes) -> str:
    """Compute SHA256 checksum of file content."""
    return hashlib.sha256(content).hexdigest()
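
And a sketch of the backend selection in practice; the IDs are illustrative. With no media_gcs_bucket_name configured, files land under the local base dir and get local:// storage paths.

from backend.util.workspace_storage import (
    compute_file_checksum,
    get_workspace_storage,
)


async def demo() -> None:
    storage = await get_workspace_storage()  # GCS or local, chosen via Config

    content = b"example bytes"
    checksum = compute_file_checksum(content)  # SHA256 hex digest

    storage_path = await storage.store(
        workspace_id="ws-456",
        file_id="file-123",
        filename="report.pdf",
        content=content,
    )
    # -> "gcs://<bucket>/workspaces/ws-456/file-123/report.pdf" on GCS,
    #    "local://ws-456/file-123/report.pdf" on the local backend.

    assert await storage.retrieve(storage_path) == content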
@@ -0,0 +1,22 @@
-- Migrate Claude 3.7 Sonnet to Claude 4.5 Sonnet
-- This updates all AgentNode blocks that use the deprecated Claude 3.7 Sonnet model
-- Anthropic is retiring claude-3-7-sonnet-20250219 on February 19, 2026

-- Update AgentNode constant inputs
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
    "constantInput"::jsonb,
    '{model}',
    '"claude-sonnet-4-5-20250929"'::jsonb
)
WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';

-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
UPDATE "AgentNodeExecutionInputOutput"
SET "data" = JSONB_SET(
    "data"::jsonb,
    '{model}',
    '"claude-sonnet-4-5-20250929"'::jsonb
)
WHERE "agentPresetId" IS NOT NULL
  AND "data"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';
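
A hedged way to sanity-check the migration afterwards from the backend, using prisma-client-py's raw query API; the connected `db` client is an assumption for illustration:

# ASSUMPTION: `db` is a connected prisma.Prisma() instance.
remaining = await db.query_raw(
    """
    SELECT COUNT(*) AS n
    FROM "AgentNode"
    WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219'
    """
)
# remaining[0]["n"] should be 0 once the migration has run.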
@@ -0,0 +1,2 @@
-- AlterEnum
ALTER TYPE "OnboardingStep" ADD VALUE 'VISIT_COPILOT';
@@ -0,0 +1,52 @@
-- CreateEnum
CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT');

-- CreateTable
CREATE TABLE "UserWorkspace" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "userId" TEXT NOT NULL,

    CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "UserWorkspaceFile" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "workspaceId" TEXT NOT NULL,
    "name" TEXT NOT NULL,
    "path" TEXT NOT NULL,
    "storagePath" TEXT NOT NULL,
    "mimeType" TEXT NOT NULL,
    "sizeBytes" BIGINT NOT NULL,
    "checksum" TEXT,
    "isDeleted" BOOLEAN NOT NULL DEFAULT false,
    "deletedAt" TIMESTAMP(3),
    "source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD',
    "sourceExecId" TEXT,
    "sourceSessionId" TEXT,
    "metadata" JSONB NOT NULL DEFAULT '{}',

    CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId");

-- CreateIndex
CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId");

-- CreateIndex
CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted");

-- CreateIndex
CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path");

-- AddForeignKey
ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE;
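
The unique index on (workspaceId, path) is what backs the overwrite/race handling in WorkspaceManager.write_file above. A sketch of the resulting behavior (`wm` is a WorkspaceManager from the module above; the content is illustrative):

async def demo(wm) -> None:
    await wm.write_file(b"v1", filename="a.txt", path="/a.txt")
    try:
        # Same (workspaceId, path) without overwrite=True -> rejected.
        await wm.write_file(b"v2", filename="a.txt", path="/a.txt")
    except ValueError as e:
        print(e)  # "File already exists at path: /a.txt"

    # overwrite=True stores the new bytes first, then swaps the DB record.
    await wm.write_file(b"v2", filename="a.txt", path="/a.txt", overwrite=True)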
@@ -0,0 +1,16 @@
/*
  Warnings:

  - You are about to drop the column `source` on the `UserWorkspaceFile` table. All the data in the column will be lost.
  - You are about to drop the column `sourceExecId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
  - You are about to drop the column `sourceSessionId` on the `UserWorkspaceFile` table. All the data in the column will be lost.

*/
-- AlterTable
ALTER TABLE "UserWorkspaceFile" DROP COLUMN "source",
DROP COLUMN "sourceExecId",
DROP COLUMN "sourceSessionId";

-- DropEnum
DROP TYPE "WorkspaceFileSource";
@@ -63,6 +63,7 @@ model User {
   IntegrationWebhooks IntegrationWebhook[]
   NotificationBatches UserNotificationBatch[]
   PendingHumanReviews PendingHumanReview[]
+  Workspace           UserWorkspace?

   // OAuth Provider relations
   OAuthApplications OAuthApplication[]
@@ -81,6 +82,7 @@ enum OnboardingStep {
   AGENT_INPUT
   CONGRATS
   // First Wins
+  VISIT_COPILOT
   GET_RESULTS
   MARKETPLACE_VISIT
   MARKETPLACE_ADD_AGENT
@@ -136,6 +138,53 @@ model CoPilotUnderstanding {
   @@index([userId])
 }

+////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////
+////////////////  USER WORKSPACE TABLES  /////////////////
+////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////
+
+// User's persistent file storage workspace
+model UserWorkspace {
+  id        String   @id @default(uuid())
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+
+  userId String @unique
+  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
+
+  Files UserWorkspaceFile[]
+
+  @@index([userId])
+}
+
+// Individual files in a user's workspace
+model UserWorkspaceFile {
+  id        String   @id @default(uuid())
+  createdAt DateTime @default(now())
+  updatedAt DateTime @updatedAt
+
+  workspaceId String
+  Workspace   UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
+
+  // File metadata
+  name        String // User-visible filename
+  path        String // Virtual path (e.g., "/documents/report.pdf")
+  storagePath String // Actual GCS or local storage path
+  mimeType    String
+  sizeBytes   BigInt
+  checksum    String? // SHA256 for integrity
+
+  // File state
+  isDeleted Boolean   @default(false)
+  deletedAt DateTime?
+
+  metadata Json @default("{}")
+
+  @@unique([workspaceId, path])
+  @@index([workspaceId, isDeleted])
+}
+
 model BuilderSearchHistory {
   id        String   @id @default(uuid())
   createdAt DateTime @default(now())
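
For orientation, a sketch of what these models look like through the generated prisma-client-py API; the exact generated accessor casing depends on codegen, so treat the attribute names as assumptions:

from prisma import Prisma


async def demo(user_id: str) -> None:
    db = Prisma()
    await db.connect()

    # One workspace per user (userId is @unique).
    ws = await db.userworkspace.create(data={"userId": user_id})

    file = await db.userworkspacefile.create(
        data={
            "workspaceId": ws.id,
            "name": "report.pdf",
            "path": "/documents/report.pdf",
            "storagePath": "local://ws/file/report.pdf",
            "mimeType": "application/pdf",
            "sizeBytes": 1024,
        }
    )
    print(file.id)

    await db.disconnect()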
@@ -57,7 +57,8 @@ class TestDecomposeGoal:

         result = await core.decompose_goal("Build a chatbot")

-        mock_external.assert_called_once_with("Build a chatbot", "")
+        # library_agents defaults to None
+        mock_external.assert_called_once_with("Build a chatbot", "", None)
         assert result == expected_result

     @pytest.mark.asyncio
@@ -74,7 +75,8 @@ class TestDecomposeGoal:

         await core.decompose_goal("Build a chatbot", "Use Python")

-        mock_external.assert_called_once_with("Build a chatbot", "Use Python")
+        # library_agents defaults to None
+        mock_external.assert_called_once_with("Build a chatbot", "Use Python", None)

     @pytest.mark.asyncio
     async def test_returns_none_on_service_failure(self):
@@ -109,7 +111,8 @@ class TestGenerateAgent:
         instructions = {"type": "instructions", "steps": ["Step 1"]}
         result = await core.generate_agent(instructions)

-        mock_external.assert_called_once_with(instructions)
+        # library_agents defaults to None
+        mock_external.assert_called_once_with(instructions, None)
         # Result should have id, version, is_active added if not present
         assert result is not None
         assert result["name"] == "Test Agent"
@@ -174,7 +177,8 @@ class TestGenerateAgentPatch:
         current_agent = {"nodes": [], "links": []}
         result = await core.generate_agent_patch("Add a node", current_agent)

-        mock_external.assert_called_once_with("Add a node", current_agent)
+        # library_agents defaults to None
+        mock_external.assert_called_once_with("Add a node", current_agent, None)
         assert result == expected_result

     @pytest.mark.asyncio
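
These assertion updates imply the underlying helpers gained a trailing `library_agents` argument that defaults to None. A hedged sketch of the inferred signatures; these are deduced from the assertions, not copied from the PR:

# Inferred from the updated assert_called_once_with(...) calls above.
async def decompose_goal(
    goal: str,
    context: str = "",
    library_agents: list[dict] | None = None,
): ...


async def generate_agent(
    instructions: dict,
    library_agents: list[dict] | None = None,
): ...


async def generate_agent_patch(
    request: str,
    current_agent: dict,
    library_agents: list[dict] | None = None,
): ...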
@@ -0,0 +1,838 @@
"""
Tests for library agent fetching functionality in agent generator.

This test suite verifies the search-based library agent fetching,
including the combination of library and marketplace agents.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.api.features.chat.tools.agent_generator import core


class TestGetLibraryAgentsForGeneration:
    """Test get_library_agents_for_generation function."""

    @pytest.mark.asyncio
    async def test_fetches_agents_with_search_term(self):
        """Test that search_term is passed to the library db."""
        # Create a mock agent with proper attribute values
        mock_agent = MagicMock()
        mock_agent.graph_id = "agent-123"
        mock_agent.graph_version = 1
        mock_agent.name = "Email Agent"
        mock_agent.description = "Sends emails"
        mock_agent.input_schema = {"properties": {}}
        mock_agent.output_schema = {"properties": {}}

        mock_response = MagicMock()
        mock_response.agents = [mock_agent]

        with patch.object(
            core.library_db,
            "list_library_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_list:
            result = await core.get_library_agents_for_generation(
                user_id="user-123",
                search_query="send email",
            )

        # Verify search_term was passed
        mock_list.assert_called_once_with(
            user_id="user-123",
            search_term="send email",
            page=1,
            page_size=15,
        )

        # Verify result format
        assert len(result) == 1
        assert result[0]["graph_id"] == "agent-123"
        assert result[0]["name"] == "Email Agent"

    @pytest.mark.asyncio
    async def test_excludes_specified_graph_id(self):
        """Test that agents with excluded graph_id are filtered out."""
        mock_response = MagicMock()
        mock_response.agents = [
            MagicMock(
                graph_id="agent-123",
                graph_version=1,
                name="Agent 1",
                description="First agent",
                input_schema={},
                output_schema={},
            ),
            MagicMock(
                graph_id="agent-456",
                graph_version=1,
                name="Agent 2",
                description="Second agent",
                input_schema={},
                output_schema={},
            ),
        ]

        with patch.object(
            core.library_db,
            "list_library_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await core.get_library_agents_for_generation(
                user_id="user-123",
                exclude_graph_id="agent-123",
            )

        # Verify the excluded agent is not in results
        assert len(result) == 1
        assert result[0]["graph_id"] == "agent-456"

    @pytest.mark.asyncio
    async def test_respects_max_results(self):
        """Test that max_results parameter limits the page_size."""
        mock_response = MagicMock()
        mock_response.agents = []

        with patch.object(
            core.library_db,
            "list_library_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_list:
            await core.get_library_agents_for_generation(
                user_id="user-123",
                max_results=5,
            )

        # Verify page_size was set to max_results
        mock_list.assert_called_once_with(
            user_id="user-123",
            search_term=None,
            page=1,
            page_size=5,
        )


class TestSearchMarketplaceAgentsForGeneration:
    """Test search_marketplace_agents_for_generation function."""

    @pytest.mark.asyncio
    async def test_searches_marketplace_with_query(self):
        """Test that marketplace is searched with the query."""
        mock_response = MagicMock()
        mock_response.agents = [
            MagicMock(
                agent_name="Public Agent",
                description="A public agent",
                sub_heading="Does something useful",
                creator="creator-1",
            )
        ]

        # The store_db is dynamically imported, so patch the import path
        with patch(
            "backend.api.features.store.db.get_store_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_search:
            result = await core.search_marketplace_agents_for_generation(
                search_query="automation",
                max_results=10,
            )

        mock_search.assert_called_once_with(
            search_query="automation",
            page=1,
            page_size=10,
        )

        assert len(result) == 1
        assert result[0]["name"] == "Public Agent"
        assert result[0]["is_marketplace_agent"] is True

    @pytest.mark.asyncio
    async def test_handles_marketplace_error_gracefully(self):
        """Test that marketplace errors don't crash the function."""
        with patch(
            "backend.api.features.store.db.get_store_agents",
            new_callable=AsyncMock,
            side_effect=Exception("Marketplace unavailable"),
        ):
            result = await core.search_marketplace_agents_for_generation(
                search_query="test"
            )

        # Should return empty list, not raise exception
        assert result == []


class TestGetAllRelevantAgentsForGeneration:
    """Test get_all_relevant_agents_for_generation function."""

    @pytest.mark.asyncio
    async def test_combines_library_and_marketplace_agents(self):
        """Test that agents from both sources are combined."""
        library_agents = [
            {
                "graph_id": "lib-123",
                "graph_version": 1,
                "name": "Library Agent",
                "description": "From library",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        marketplace_agents = [
            {
                "name": "Market Agent",
                "description": "From marketplace",
                "sub_heading": "Sub heading",
                "creator": "creator-1",
                "is_marketplace_agent": True,
            }
        ]

        with patch.object(
            core,
            "get_library_agents_for_generation",
            new_callable=AsyncMock,
            return_value=library_agents,
        ):
            with patch.object(
                core,
                "search_marketplace_agents_for_generation",
                new_callable=AsyncMock,
                return_value=marketplace_agents,
            ):
                result = await core.get_all_relevant_agents_for_generation(
                    user_id="user-123",
                    search_query="test query",
                    include_marketplace=True,
                )

        # Library agents should come first
        assert len(result) == 2
        assert result[0]["name"] == "Library Agent"
        assert result[1]["name"] == "Market Agent"

    @pytest.mark.asyncio
    async def test_deduplicates_by_name(self):
        """Test that marketplace agents with same name as library are excluded."""
        library_agents = [
            {
                "graph_id": "lib-123",
                "graph_version": 1,
                "name": "Shared Agent",
                "description": "From library",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        marketplace_agents = [
            {
                "name": "Shared Agent",  # Same name, should be deduplicated
                "description": "From marketplace",
                "sub_heading": "Sub heading",
                "creator": "creator-1",
                "is_marketplace_agent": True,
            },
            {
                "name": "Unique Agent",
                "description": "Only in marketplace",
                "sub_heading": "Sub heading",
                "creator": "creator-2",
                "is_marketplace_agent": True,
            },
        ]

        with patch.object(
            core,
            "get_library_agents_for_generation",
            new_callable=AsyncMock,
            return_value=library_agents,
        ):
            with patch.object(
                core,
                "search_marketplace_agents_for_generation",
                new_callable=AsyncMock,
                return_value=marketplace_agents,
            ):
                result = await core.get_all_relevant_agents_for_generation(
                    user_id="user-123",
                    search_query="test",
                    include_marketplace=True,
                )

        # Shared Agent from marketplace should be excluded
        assert len(result) == 2
        names = [a["name"] for a in result]
        assert "Shared Agent" in names
        assert "Unique Agent" in names

    @pytest.mark.asyncio
    async def test_skips_marketplace_when_disabled(self):
        """Test that marketplace is not searched when include_marketplace=False."""
        library_agents = [
            {
                "graph_id": "lib-123",
                "graph_version": 1,
                "name": "Library Agent",
                "description": "From library",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        with patch.object(
            core,
            "get_library_agents_for_generation",
            new_callable=AsyncMock,
            return_value=library_agents,
        ):
            with patch.object(
                core,
                "search_marketplace_agents_for_generation",
                new_callable=AsyncMock,
            ) as mock_marketplace:
                result = await core.get_all_relevant_agents_for_generation(
                    user_id="user-123",
                    search_query="test",
|
||||||
|
include_marketplace=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Marketplace should not be called
|
||||||
|
mock_marketplace.assert_not_called()
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_marketplace_when_no_search_query(self):
|
||||||
|
"""Test that marketplace is not searched without a search query."""
|
||||||
|
library_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "lib-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Library Agent",
|
||||||
|
"description": "From library",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_library_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=library_agents,
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"search_marketplace_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_marketplace:
|
||||||
|
result = await core.get_all_relevant_agents_for_generation(
|
||||||
|
user_id="user-123",
|
||||||
|
search_query=None, # No search query
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Marketplace should not be called without search query
|
||||||
|
mock_marketplace.assert_not_called()
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractSearchTermsFromSteps:
|
||||||
|
"""Test extract_search_terms_from_steps function."""
|
||||||
|
|
||||||
|
def test_extracts_terms_from_instructions_type(self):
|
||||||
|
"""Test extraction from valid instructions decomposition result."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{
|
||||||
|
"description": "Send an email notification",
|
||||||
|
"block_name": "GmailSendBlock",
|
||||||
|
},
|
||||||
|
{"description": "Fetch weather data", "action": "Get weather API"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
assert "Send an email notification" in result
|
||||||
|
assert "GmailSendBlock" in result
|
||||||
|
assert "Fetch weather data" in result
|
||||||
|
assert "Get weather API" in result
|
||||||
|
|
||||||
|
def test_returns_empty_for_non_instructions_type(self):
|
||||||
|
"""Test that non-instructions types return empty list."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": [{"question": "What email?"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_deduplicates_terms_case_insensitively(self):
|
||||||
|
"""Test that duplicate terms are removed (case-insensitive)."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{"description": "Send Email", "name": "send email"},
|
||||||
|
{"description": "Other task"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
# Should only have one "send email" variant
|
||||||
|
email_terms = [t for t in result if "email" in t.lower()]
|
||||||
|
assert len(email_terms) == 1
|
||||||
|
|
||||||
|
def test_filters_short_terms(self):
|
||||||
|
"""Test that terms with 3 or fewer characters are filtered out."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{"description": "ab", "action": "xyz"}, # Both too short
|
||||||
|
{"description": "Valid term here"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
assert "ab" not in result
|
||||||
|
assert "xyz" not in result
|
||||||
|
assert "Valid term here" in result
|
||||||
|
|
||||||
|
def test_handles_empty_steps(self):
|
||||||
|
"""Test handling of empty steps list."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnrichLibraryAgentsFromSteps:
|
||||||
|
"""Test enrich_library_agents_from_steps function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enriches_with_additional_agents(self):
|
||||||
|
"""Test that additional agents are found based on steps."""
|
||||||
|
existing_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "existing-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Existing Agent",
|
||||||
|
"description": "Already fetched",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
additional_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "new-456",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Email Agent",
|
||||||
|
"description": "For sending emails",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{"description": "Send email notification"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_all_relevant_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=additional_agents,
|
||||||
|
):
|
||||||
|
result = await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have both existing and new agents
|
||||||
|
assert len(result) == 2
|
||||||
|
names = [a["name"] for a in result]
|
||||||
|
assert "Existing Agent" in names
|
||||||
|
assert "Email Agent" in names
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduplicates_by_graph_id(self):
|
||||||
|
"""Test that agents with same graph_id are not duplicated."""
|
||||||
|
existing_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Existing Agent",
|
||||||
|
"description": "Already fetched",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Additional search returns same agent
|
||||||
|
additional_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-123", # Same ID
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Existing Agent Copy",
|
||||||
|
"description": "Same agent different name",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [{"description": "Some action"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_all_relevant_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=additional_agents,
|
||||||
|
):
|
||||||
|
result = await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not duplicate
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduplicates_by_name(self):
|
||||||
|
"""Test that agents with same name are not duplicated."""
|
||||||
|
existing_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Email Agent",
|
||||||
|
"description": "Already fetched",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Additional search returns agent with same name but different ID
|
||||||
|
additional_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-456", # Different ID
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Email Agent", # Same name
|
||||||
|
"description": "Different agent same name",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [{"description": "Send email"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_all_relevant_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=additional_agents,
|
||||||
|
):
|
||||||
|
result = await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not duplicate by name
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].get("graph_id") == "agent-123" # Original kept
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_existing_when_no_steps(self):
|
||||||
|
"""Test that existing agents are returned when no search terms extracted."""
|
||||||
|
existing_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "existing-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Existing Agent",
|
||||||
|
"description": "Already fetched",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "clarifying_questions", # Not instructions type
|
||||||
|
"questions": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return existing unchanged
|
||||||
|
assert result == existing_agents
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_limits_search_terms_to_three(self):
|
||||||
|
"""Test that only first 3 search terms are used."""
|
||||||
|
existing_agents = []
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{"description": "First action"},
|
||||||
|
{"description": "Second action"},
|
||||||
|
{"description": "Third action"},
|
||||||
|
{"description": "Fourth action"},
|
||||||
|
{"description": "Fifth action"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def mock_get_agents(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_all_relevant_agents_for_generation",
|
||||||
|
side_effect=mock_get_agents,
|
||||||
|
):
|
||||||
|
await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only make 3 calls (limited to first 3 terms)
|
||||||
|
assert call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
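The enrichment tests above pin down three behaviors: dedup by `graph_id`, dedup by `name` with the existing entry winning, and a cap of three search terms. A minimal sketch of a function satisfying that contract — an illustration only, written against the `core` helpers these tests patch; the real implementation may differ:

```python
async def enrich_library_agents_from_steps_sketch(
    user_id: str,
    decomposition_result: dict,
    existing_agents: list[dict],
) -> list[dict]:
    """Illustrative sketch of the contract encoded by the tests above."""
    terms = extract_search_terms_from_steps(decomposition_result)
    if not terms:
        return existing_agents  # non-instructions input: return unchanged

    seen_ids = {a.get("graph_id") for a in existing_agents}
    seen_names = {a.get("name") for a in existing_agents}
    enriched = list(existing_agents)

    for term in terms[:3]:  # only the first three terms trigger a search
        found = await get_all_relevant_agents_for_generation(
            user_id=user_id, search_query=term
        )
        for agent in found:
            # Dedupe by graph_id and by name; the existing entry is kept
            if agent.get("graph_id") in seen_ids or agent.get("name") in seen_names:
                continue
            seen_ids.add(agent.get("graph_id"))
            seen_names.add(agent.get("name"))
            enriched.append(agent)
    return enriched
```
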
class TestExtractUuidsFromText:
    """Test extract_uuids_from_text function."""

    def test_extracts_single_uuid(self):
        """Test extraction of a single UUID from text."""
        text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task"
        result = core.extract_uuids_from_text(text)
        assert len(result) == 1
        assert "46631191-e8a8-486f-ad90-84f89738321d" in result

    def test_extracts_multiple_uuids(self):
        """Test extraction of multiple UUIDs from text."""
        text = (
            "Combine agents 11111111-1111-4111-8111-111111111111 "
            "and 22222222-2222-4222-9222-222222222222"
        )
        result = core.extract_uuids_from_text(text)
        assert len(result) == 2
        assert "11111111-1111-4111-8111-111111111111" in result
        assert "22222222-2222-4222-9222-222222222222" in result

    def test_deduplicates_uuids(self):
        """Test that duplicate UUIDs are deduplicated."""
        text = (
            "Use 46631191-e8a8-486f-ad90-84f89738321d twice: "
            "46631191-e8a8-486f-ad90-84f89738321d"
        )
        result = core.extract_uuids_from_text(text)
        assert len(result) == 1

    def test_normalizes_to_lowercase(self):
        """Test that UUIDs are normalized to lowercase."""
        text = "Use 46631191-E8A8-486F-AD90-84F89738321D"
        result = core.extract_uuids_from_text(text)
        assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d"

    def test_returns_empty_for_no_uuids(self):
        """Test that empty list is returned when no UUIDs found."""
        text = "Create an email agent that sends notifications"
        result = core.extract_uuids_from_text(text)
        assert result == []

    def test_ignores_invalid_uuids(self):
        """Test that invalid UUID-like strings are ignored."""
        text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc"
        result = core.extract_uuids_from_text(text)
        # UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth)
        assert len(result) == 0

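Taken together, these assertions effectively specify `extract_uuids_from_text`: v4-only matching, lowercase normalization, and order-preserving dedup. A regex-based sketch that passes the cases above — the exact pattern is an assumption; the actual implementation in `core` is not shown here:

```python
import re

# Assumed v4-only pattern: third group starts with "4", fourth with 8/9/a/b.
UUID_V4_PATTERN = re.compile(
    r"\b[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}\b",
    re.IGNORECASE,
)


def extract_uuids_from_text_sketch(text: str) -> list[str]:
    """Return v4 UUIDs found in text, lowercased and deduplicated in order."""
    seen: set[str] = set()
    uuids: list[str] = []
    for match in UUID_V4_PATTERN.finditer(text):
        uuid = match.group(0).lower()  # normalize case
        if uuid not in seen:  # keep first occurrence only
            seen.add(uuid)
            uuids.append(uuid)
    return uuids
```
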
class TestGetLibraryAgentById:
    """Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""

    @pytest.mark.asyncio
    async def test_returns_agent_when_found_by_graph_id(self):
        """Test that agent is returned when found by graph_id."""
        mock_agent = MagicMock()
        mock_agent.graph_id = "agent-123"
        mock_agent.graph_version = 1
        mock_agent.name = "Test Agent"
        mock_agent.description = "Test description"
        mock_agent.input_schema = {"properties": {}}
        mock_agent.output_schema = {"properties": {}}

        with patch.object(
            core.library_db,
            "get_library_agent_by_graph_id",
            new_callable=AsyncMock,
            return_value=mock_agent,
        ):
            result = await core.get_library_agent_by_id("user-123", "agent-123")

        assert result is not None
        assert result["graph_id"] == "agent-123"
        assert result["name"] == "Test Agent"

    @pytest.mark.asyncio
    async def test_falls_back_to_library_agent_id(self):
        """Test that lookup falls back to library agent ID when graph_id not found."""
        mock_agent = MagicMock()
        mock_agent.graph_id = "graph-456"  # Different from the lookup ID
        mock_agent.graph_version = 1
        mock_agent.name = "Library Agent"
        mock_agent.description = "Found by library ID"
        mock_agent.input_schema = {"properties": {}}
        mock_agent.output_schema = {"properties": {}}

        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                return_value=None,  # Not found by graph_id
            ),
            patch.object(
                core.library_db,
                "get_library_agent",
                new_callable=AsyncMock,
                return_value=mock_agent,  # Found by library ID
            ),
        ):
            result = await core.get_library_agent_by_id("user-123", "library-id-123")

        assert result is not None
        assert result["graph_id"] == "graph-456"
        assert result["name"] == "Library Agent"

    @pytest.mark.asyncio
    async def test_returns_none_when_not_found_by_either_method(self):
        """Test that None is returned when agent not found by either method."""
        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                return_value=None,
            ),
            patch.object(
                core.library_db,
                "get_library_agent",
                new_callable=AsyncMock,
                side_effect=core.NotFoundError("Not found"),
            ),
        ):
            result = await core.get_library_agent_by_id("user-123", "nonexistent")

        assert result is None

    @pytest.mark.asyncio
    async def test_returns_none_on_exception(self):
        """Test that None is returned when exception occurs in both lookups."""
        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                side_effect=Exception("Database error"),
            ),
            patch.object(
                core.library_db,
                "get_library_agent",
                new_callable=AsyncMock,
                side_effect=Exception("Database error"),
            ),
        ):
            result = await core.get_library_agent_by_id("user-123", "agent-123")

        assert result is None

    @pytest.mark.asyncio
    async def test_alias_works(self):
        """Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
        assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id


class TestGetAllRelevantAgentsWithUuids:
    """Test UUID extraction in get_all_relevant_agents_for_generation."""

    @pytest.mark.asyncio
    async def test_fetches_explicitly_mentioned_agents(self):
        """Test that agents mentioned by UUID are fetched directly."""
        mock_agent = MagicMock()
        mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d"
        mock_agent.graph_version = 1
        mock_agent.name = "Mentioned Agent"
        mock_agent.description = "Explicitly mentioned"
        mock_agent.input_schema = {}
        mock_agent.output_schema = {}

        mock_response = MagicMock()
        mock_response.agents = []

        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                return_value=mock_agent,
            ),
            patch.object(
                core.library_db,
                "list_library_agents",
                new_callable=AsyncMock,
                return_value=mock_response,
            ),
        ):
            result = await core.get_all_relevant_agents_for_generation(
                user_id="user-123",
                search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
                include_marketplace=False,
            )

        assert len(result) == 1
        assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d"
        assert result[0].get("name") == "Mentioned Agent"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
@@ -151,15 +151,20 @@ class TestDecomposeGoalExternal:
     @pytest.mark.asyncio
     async def test_decompose_goal_handles_http_error(self):
         """Test decomposition handles HTTP errors gracefully."""
+        mock_response = MagicMock()
+        mock_response.status_code = 500
         mock_client = AsyncMock()
         mock_client.post.side_effect = httpx.HTTPStatusError(
-            "Server error", request=MagicMock(), response=MagicMock()
+            "Server error", request=MagicMock(), response=mock_response
         )
 
         with patch.object(service, "_get_client", return_value=mock_client):
             result = await service.decompose_goal_external("Build a chatbot")
 
-        assert result is None
+        assert result is not None
+        assert result.get("type") == "error"
+        assert result.get("error_type") == "http_error"
+        assert "Server error" in result.get("error", "")
 
     @pytest.mark.asyncio
     async def test_decompose_goal_handles_request_error(self):
@@ -170,7 +175,10 @@ class TestDecomposeGoalExternal:
         with patch.object(service, "_get_client", return_value=mock_client):
             result = await service.decompose_goal_external("Build a chatbot")
 
-        assert result is None
+        assert result is not None
+        assert result.get("type") == "error"
+        assert result.get("error_type") == "connection_error"
+        assert "Connection failed" in result.get("error", "")
 
     @pytest.mark.asyncio
     async def test_decompose_goal_handles_service_error(self):
@@ -179,6 +187,7 @@ class TestDecomposeGoalExternal:
         mock_response.json.return_value = {
             "success": False,
             "error": "Internal error",
+            "error_type": "internal_error",
         }
         mock_response.raise_for_status = MagicMock()
 
@@ -188,7 +197,10 @@ class TestDecomposeGoalExternal:
         with patch.object(service, "_get_client", return_value=mock_client):
             result = await service.decompose_goal_external("Build a chatbot")
 
-        assert result is None
+        assert result is not None
+        assert result.get("type") == "error"
+        assert result.get("error") == "Internal error"
+        assert result.get("error_type") == "internal_error"
 
 
 class TestGenerateAgentExternal:
@@ -236,7 +248,10 @@ class TestGenerateAgentExternal:
         with patch.object(service, "_get_client", return_value=mock_client):
             result = await service.generate_agent_external({"steps": []})
 
-        assert result is None
+        assert result is not None
+        assert result.get("type") == "error"
+        assert result.get("error_type") == "connection_error"
+        assert "Connection failed" in result.get("error", "")
 
 
 class TestGenerateAgentPatchExternal:
@@ -418,5 +433,139 @@ class TestGetBlocksExternal:
         assert result is None
 
 
+class TestLibraryAgentsPassthrough:
+    """Test that library_agents are passed correctly in all requests."""
+
+    def setup_method(self):
+        """Reset client singleton before each test."""
+        service._settings = None
+        service._client = None
+
+    @pytest.mark.asyncio
+    async def test_decompose_goal_passes_library_agents(self):
+        """Test that library_agents are included in decompose goal payload."""
+        library_agents = [
+            {
+                "graph_id": "agent-123",
+                "graph_version": 1,
+                "name": "Email Sender",
+                "description": "Sends emails",
+                "input_schema": {"properties": {"to": {"type": "string"}}},
+                "output_schema": {"properties": {"sent": {"type": "boolean"}}},
+            },
+        ]
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "success": True,
+            "type": "instructions",
+            "steps": ["Step 1"],
+        }
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post.return_value = mock_response
+
+        with patch.object(service, "_get_client", return_value=mock_client):
+            await service.decompose_goal_external(
+                "Send an email",
+                library_agents=library_agents,
+            )
+
+        # Verify library_agents was passed in the payload
+        call_args = mock_client.post.call_args
+        assert call_args[1]["json"]["library_agents"] == library_agents
+
+    @pytest.mark.asyncio
+    async def test_generate_agent_passes_library_agents(self):
+        """Test that library_agents are included in generate agent payload."""
+        library_agents = [
+            {
+                "graph_id": "agent-456",
+                "graph_version": 2,
+                "name": "Data Fetcher",
+                "description": "Fetches data from API",
+                "input_schema": {"properties": {"url": {"type": "string"}}},
+                "output_schema": {"properties": {"data": {"type": "object"}}},
+            },
+        ]
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "success": True,
+            "agent_json": {"name": "Test Agent", "nodes": []},
+        }
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post.return_value = mock_response
+
+        with patch.object(service, "_get_client", return_value=mock_client):
+            await service.generate_agent_external(
+                {"steps": ["Step 1"]},
+                library_agents=library_agents,
+            )
+
+        # Verify library_agents was passed in the payload
+        call_args = mock_client.post.call_args
+        assert call_args[1]["json"]["library_agents"] == library_agents
+
+    @pytest.mark.asyncio
+    async def test_generate_agent_patch_passes_library_agents(self):
+        """Test that library_agents are included in patch generation payload."""
+        library_agents = [
+            {
+                "graph_id": "agent-789",
+                "graph_version": 1,
+                "name": "Slack Notifier",
+                "description": "Sends Slack messages",
+                "input_schema": {"properties": {"message": {"type": "string"}}},
+                "output_schema": {"properties": {"success": {"type": "boolean"}}},
+            },
+        ]
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "success": True,
+            "agent_json": {"name": "Updated Agent", "nodes": []},
+        }
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post.return_value = mock_response
+
+        with patch.object(service, "_get_client", return_value=mock_client):
+            await service.generate_agent_patch_external(
+                "Add error handling",
+                {"name": "Original Agent", "nodes": []},
+                library_agents=library_agents,
+            )
+
+        # Verify library_agents was passed in the payload
+        call_args = mock_client.post.call_args
+        assert call_args[1]["json"]["library_agents"] == library_agents
+
+    @pytest.mark.asyncio
+    async def test_decompose_goal_without_library_agents(self):
+        """Test that decompose goal works without library_agents."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "success": True,
+            "type": "instructions",
+            "steps": ["Step 1"],
+        }
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post.return_value = mock_response
+
+        with patch.object(service, "_get_client", return_value=mock_client):
+            await service.decompose_goal_external("Build a workflow")
+
+        # Verify library_agents was NOT passed when not provided
+        call_args = mock_client.post.call_args
+        assert "library_agents" not in call_args[1]["json"]
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
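The hunks above move the external-service helpers from returning `None` on failure to returning a structured payload with `type`, `error`, and `error_type` fields. A sketch of the wrapper shape those assertions imply — the parameter list and endpoint path are illustrative assumptions, not the actual service code:

```python
import httpx


async def call_external_sketch(client: httpx.AsyncClient, payload: dict) -> dict:
    """Fold transport and service failures into the error dict the tests expect."""
    try:
        response = await client.post("/decompose", json=payload)  # endpoint assumed
        response.raise_for_status()
        data = response.json()
        if not data.get("success"):
            # Service-level failure: pass through the service's own error fields
            return {
                "type": "error",
                "error": data.get("error", "Unknown error"),
                "error_type": data.get("error_type", "internal_error"),
            }
        return data
    except httpx.HTTPStatusError as e:
        return {"type": "error", "error": str(e), "error_type": "http_error"}
    except httpx.RequestError as e:
        return {"type": "error", "error": str(e), "error_type": "connection_error"}
```

The `str(e)` fallbacks are what let the tests check for substrings like "Server error" and "Connection failed" in `result.get("error", "")`.
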
@@ -43,19 +43,24 @@ faker = Faker()
 # Constants for data generation limits (reduced for E2E tests)
 NUM_USERS = 15
 NUM_AGENT_BLOCKS = 30
-MIN_GRAPHS_PER_USER = 15
-MAX_GRAPHS_PER_USER = 15
+MIN_GRAPHS_PER_USER = 25
+MAX_GRAPHS_PER_USER = 25
 MIN_NODES_PER_GRAPH = 3
 MAX_NODES_PER_GRAPH = 6
 MIN_PRESETS_PER_USER = 2
 MAX_PRESETS_PER_USER = 3
-MIN_AGENTS_PER_USER = 15
-MAX_AGENTS_PER_USER = 15
+MIN_AGENTS_PER_USER = 25
+MAX_AGENTS_PER_USER = 25
 MIN_EXECUTIONS_PER_GRAPH = 2
 MAX_EXECUTIONS_PER_GRAPH = 8
 MIN_REVIEWS_PER_VERSION = 2
 MAX_REVIEWS_PER_VERSION = 5
 
+# Guaranteed minimums for marketplace tests (deterministic)
+GUARANTEED_FEATURED_AGENTS = 8
+GUARANTEED_FEATURED_CREATORS = 5
+GUARANTEED_TOP_AGENTS = 10
+
 
 def get_image():
     """Generate a consistent image URL using picsum.photos service."""
@@ -385,7 +390,7 @@ class TestDataCreator:
 
         library_agents = []
         for user in self.users:
-            num_agents = 10  # Create exactly 10 agents per user
+            num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
 
             # Get available graphs for this user
             user_graphs = [
@@ -507,14 +512,17 @@ class TestDataCreator:
             existing_profiles, min(num_creators, len(existing_profiles))
         )
 
-        # Mark about 50% of creators as featured (more for testing)
-        num_featured = max(2, int(num_creators * 0.5))
+        # Guarantee at least GUARANTEED_FEATURED_CREATORS featured creators
+        num_featured = max(GUARANTEED_FEATURED_CREATORS, int(num_creators * 0.5))
         num_featured = min(
             num_featured, len(selected_profiles)
         )  # Don't exceed available profiles
         featured_profile_ids = set(
            random.sample([p.id for p in selected_profiles], num_featured)
         )
+        print(
+            f"🎯 Creating {num_featured} featured creators (min: {GUARANTEED_FEATURED_CREATORS})"
+        )
 
         for profile in selected_profiles:
             try:
@@ -545,21 +553,25 @@ class TestDataCreator:
         return profiles
 
     async def create_test_store_submissions(self) -> List[Dict[str, Any]]:
-        """Create test store submissions using the API function."""
+        """Create test store submissions using the API function.
+
+        DETERMINISTIC: Guarantees minimum featured agents for E2E tests.
+        """
         print("Creating test store submissions...")
 
         submissions = []
         approved_submissions = []
+        featured_count = 0
+        submission_counter = 0
 
-        # Create a special test submission for test123@gmail.com
+        # Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
         test_user = next(
             (user for user in self.users if user["email"] == "test123@gmail.com"), None
         )
-        if test_user:
-            # Special test data for consistent testing
+        if test_user and self.agent_graphs:
            test_submission_data = {
                 "user_id": test_user["id"],
-                "agent_id": self.agent_graphs[0]["id"],  # Use first available graph
+                "agent_id": self.agent_graphs[0]["id"],
                 "agent_version": 1,
                 "slug": "test-agent-submission",
                 "name": "Test Agent Submission",
@@ -580,37 +592,24 @@ class TestDataCreator:
                 submissions.append(test_submission.model_dump())
                 print("✅ Created special test store submission for test123@gmail.com")
 
-                # Randomly approve, reject, or leave pending the test submission
+                # ALWAYS approve and feature the test submission
                 if test_submission.store_listing_version_id:
-                    random_value = random.random()
-                    if random_value < 0.4:  # 40% chance to approve
-                        approved_submission = await review_store_submission(
-                            store_listing_version_id=test_submission.store_listing_version_id,
-                            is_approved=True,
-                            external_comments="Test submission approved",
-                            internal_comments="Auto-approved test submission",
-                            reviewer_id=test_user["id"],
-                        )
-                        approved_submissions.append(approved_submission.model_dump())
-                        print("✅ Approved test store submission")
-
-                        # Mark approved submission as featured
-                        await prisma.storelistingversion.update(
-                            where={"id": test_submission.store_listing_version_id},
-                            data={"isFeatured": True},
-                        )
-                        print("🌟 Marked test agent as FEATURED")
-                    elif random_value < 0.7:  # 30% chance to reject (40% to 70%)
-                        await review_store_submission(
-                            store_listing_version_id=test_submission.store_listing_version_id,
-                            is_approved=False,
-                            external_comments="Test submission rejected - needs improvements",
-                            internal_comments="Auto-rejected test submission for E2E testing",
-                            reviewer_id=test_user["id"],
-                        )
-                        print("❌ Rejected test store submission")
-                    else:  # 30% chance to leave pending (70% to 100%)
-                        print("⏳ Left test submission pending for review")
+                    approved_submission = await review_store_submission(
+                        store_listing_version_id=test_submission.store_listing_version_id,
+                        is_approved=True,
+                        external_comments="Test submission approved",
+                        internal_comments="Auto-approved test submission",
+                        reviewer_id=test_user["id"],
+                    )
+                    approved_submissions.append(approved_submission.model_dump())
+                    print("✅ Approved test store submission")
+
+                    await prisma.storelistingversion.update(
+                        where={"id": test_submission.store_listing_version_id},
+                        data={"isFeatured": True},
+                    )
+                    featured_count += 1
+                    print("🌟 Marked test agent as FEATURED")
 
         except Exception as e:
             print(f"Error creating test store submission: {e}")
@@ -620,7 +619,6 @@ class TestDataCreator:
 
         # Create regular submissions for all users
         for user in self.users:
-            # Get available graphs for this specific user
             user_graphs = [
                 g for g in self.agent_graphs if g.get("userId") == user["id"]
             ]
@@ -631,18 +629,17 @@ class TestDataCreator:
             )
             continue
 
-        # Create exactly 4 store submissions per user
         for submission_index in range(4):
             graph = random.choice(user_graphs)
+            submission_counter += 1
 
             try:
                 print(
-                    f"Creating store submission for user {user['id']} with graph {graph['id']} (owner: {graph.get('userId')})"
+                    f"Creating store submission for user {user['id']} with graph {graph['id']}"
                 )
 
-                # Use the API function to create store submission with correct parameters
                 submission = await create_store_submission(
-                    user_id=user["id"],  # Must match graph's userId
+                    user_id=user["id"],
                     agent_id=graph["id"],
                     agent_version=graph.get("version", 1),
                     slug=faker.slug(),
@@ -651,22 +648,24 @@ class TestDataCreator:
                     video_url=get_video_url() if random.random() < 0.3 else None,
                     image_urls=[get_image() for _ in range(3)],
                     description=faker.text(),
-                    categories=[
-                        get_category()
-                    ],  # Single category from predefined list
+                    categories=[get_category()],
                     changes_summary="Initial E2E test submission",
                 )
                 submissions.append(submission.model_dump())
                 print(f"✅ Created store submission: {submission.name}")
 
-                # Randomly approve, reject, or leave pending the submission
                 if submission.store_listing_version_id:
-                    random_value = random.random()
-                    if random_value < 0.4:  # 40% chance to approve
-                        try:
-                            # Pick a random user as the reviewer (admin)
-                            reviewer_id = random.choice(self.users)["id"]
+                    # DETERMINISTIC: First N submissions are always approved
+                    # First GUARANTEED_FEATURED_AGENTS of those are always featured
+                    should_approve = (
+                        submission_counter <= GUARANTEED_TOP_AGENTS
+                        or random.random() < 0.4
+                    )
+                    should_feature = featured_count < GUARANTEED_FEATURED_AGENTS
+
+                    if should_approve:
+                        try:
+                            reviewer_id = random.choice(self.users)["id"]
                             approved_submission = await review_store_submission(
                                 store_listing_version_id=submission.store_listing_version_id,
                                 is_approved=True,
@@ -681,16 +680,7 @@ class TestDataCreator:
                                     f"✅ Approved store submission: {submission.name}"
                                 )
 
-                                # Mark some agents as featured during creation (30% chance)
-                                # More likely for creators and first submissions
-                                is_creator = user["id"] in [
-                                    p.get("userId") for p in self.profiles
-                                ]
-                                feature_chance = (
-                                    0.5 if is_creator else 0.2
-                                )  # 50% for creators, 20% for others
-
-                                if random.random() < feature_chance:
+                                if should_feature:
                                     try:
                                         await prisma.storelistingversion.update(
                                             where={
@@ -698,8 +688,25 @@ class TestDataCreator:
                                             },
                                             data={"isFeatured": True},
                                         )
+                                        featured_count += 1
                                         print(
-                                            f"🌟 Marked agent as FEATURED: {submission.name}"
+                                            f"🌟 Marked agent as FEATURED ({featured_count}/{GUARANTEED_FEATURED_AGENTS}): {submission.name}"
+                                        )
+                                    except Exception as e:
+                                        print(
+                                            f"Warning: Could not mark submission as featured: {e}"
+                                        )
+                                elif random.random() < 0.2:
+                                    try:
+                                        await prisma.storelistingversion.update(
+                                            where={
+                                                "id": submission.store_listing_version_id
+                                            },
+                                            data={"isFeatured": True},
+                                        )
+                                        featured_count += 1
+                                        print(
+                                            f"🌟 Marked agent as FEATURED (bonus): {submission.name}"
                                         )
                                     except Exception as e:
                                         print(
@@ -710,11 +717,9 @@ class TestDataCreator:
                             print(
                                 f"Warning: Could not approve submission {submission.name}: {e}"
                             )
-                    elif random_value < 0.7:  # 30% chance to reject (40% to 70%)
+                    elif random.random() < 0.5:
                         try:
-                            # Pick a random user as the reviewer (admin)
                             reviewer_id = random.choice(self.users)["id"]
-
                             await review_store_submission(
                                 store_listing_version_id=submission.store_listing_version_id,
                                 is_approved=False,
@@ -729,7 +734,7 @@ class TestDataCreator:
                             print(
                                 f"Warning: Could not reject submission {submission.name}: {e}"
                             )
-                    else:  # 30% chance to leave pending (70% to 100%)
+                    else:
                         print(
                             f"⏳ Left submission pending for review: {submission.name}"
                         )
@@ -743,9 +748,13 @@ class TestDataCreator:
                 traceback.print_exc()
                 continue
 
+        print("\n📊 Store Submissions Summary:")
+        print(f"   Created: {len(submissions)}")
+        print(f"   Approved: {len(approved_submissions)}")
         print(
-            f"Created {len(submissions)} store submissions, approved {len(approved_submissions)}"
+            f"   Featured: {featured_count} (guaranteed min: {GUARANTEED_FEATURED_AGENTS})"
        )
 
         self.store_submissions = submissions
         return submissions
@@ -825,12 +834,15 @@ class TestDataCreator:
         print(f"✅ Agent blocks available: {len(self.agent_blocks)}")
         print(f"✅ Agent graphs created: {len(self.agent_graphs)}")
         print(f"✅ Library agents created: {len(self.library_agents)}")
-        print(f"✅ Creator profiles updated: {len(self.profiles)} (some featured)")
-        print(
-            f"✅ Store submissions created: {len(self.store_submissions)} (some marked as featured during creation)"
-        )
+        print(f"✅ Creator profiles updated: {len(self.profiles)}")
+        print(f"✅ Store submissions created: {len(self.store_submissions)}")
         print(f"✅ API keys created: {len(self.api_keys)}")
         print(f"✅ Presets created: {len(self.presets)}")
+        print("\n🎯 Deterministic Guarantees:")
+        print(f"   • Featured agents: >= {GUARANTEED_FEATURED_AGENTS}")
+        print(f"   • Featured creators: >= {GUARANTEED_FEATURED_CREATORS}")
+        print(f"   • Top agents (approved): >= {GUARANTEED_TOP_AGENTS}")
+        print(f"   • Library agents per user: >= {MIN_AGENTS_PER_USER}")
 
         print("\n🚀 Your E2E test database is ready to use!")
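The seeding rewrite above replaces purely random approve/feature rolls with counters and guaranteed floors. Condensed, the per-submission decision reduces to roughly the following — a restatement of the diff for clarity, assuming the review and update calls succeed:

```python
import random

GUARANTEED_TOP_AGENTS = 10      # first N submissions are always approved
GUARANTEED_FEATURED_AGENTS = 8  # first N approvals are always featured


def review_decision(submission_counter: int, featured_count: int) -> tuple[bool, bool]:
    should_approve = (
        submission_counter <= GUARANTEED_TOP_AGENTS or random.random() < 0.4
    )
    # Featuring only takes effect on the approval path in the script above
    should_feature = should_approve and featured_count < GUARANTEED_FEATURED_AGENTS
    return should_approve, should_feature
```

With NUM_USERS = 15 and four submissions per user, the counter climbs to 60, so the first ten submissions are approved deterministically and the featured floor of eight is reached before the random draws matter.
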
@@ -34,3 +34,6 @@ NEXT_PUBLIC_PREVIEW_STEALING_DEV=
 # PostHog Analytics
 NEXT_PUBLIC_POSTHOG_KEY=
 NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com
+
+# OpenAI (for voice transcription)
+OPENAI_API_KEY=
76
autogpt_platform/frontend/CLAUDE.md
Normal file
@@ -0,0 +1,76 @@
+# CLAUDE.md - Frontend
+
+This file provides guidance to Claude Code when working with the frontend.
+
+## Essential Commands
+
+```bash
+# Install dependencies
+pnpm i
+
+# Generate API client from OpenAPI spec
+pnpm generate:api
+
+# Start development server
+pnpm dev
+
+# Run E2E tests
+pnpm test
+
+# Run Storybook for component development
+pnpm storybook
+
+# Build production
+pnpm build
+
+# Format and lint
+pnpm format
+
+# Type checking
+pnpm types
+```
+
+### Code Style
+
+- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI`
+- Use function declarations (not arrow functions) for components/handlers
+
+## Architecture
+
+- **Framework**: Next.js 15 App Router (client-first approach)
+- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
+- **State Management**: React Query for server state, co-located UI state in components/hooks
+- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
+- **Workflow Builder**: Visual graph editor using @xyflow/react
+- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
+- **Icons**: Phosphor Icons only
+- **Feature Flags**: LaunchDarkly integration
+- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
+- **Testing**: Playwright for E2E, Storybook for component development
+
+## Environment Configuration
+
+`.env.default` (defaults) → `.env` (user overrides)
+
+## Feature Development
+
+See @CONTRIBUTING.md for complete patterns. Quick reference:
+
+1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
+   - Extract component logic into custom hooks grouped by concern, not by component. Each hook should represent a cohesive domain of functionality (e.g., useSearch, useFilters, usePagination) rather than bundling all state into one useComponentState hook.
+   - Put each hook in its own `.ts` file
+   - Put sub-components in local `components/` folder
+   - Component props should be `type Props = { ... }` (not exported) unless it needs to be used outside the component
+2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
+   - Use design system components from `src/components/` (atoms, molecules, organisms)
+   - Never use `src/components/__legacy__/*`
+3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
+   - Regenerate with `pnpm generate:api`
+   - Pattern: `use{Method}{Version}{OperationName}`
+4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
+5. **Testing**: Add Storybook stories for new components, Playwright for E2E
+6. **Code conventions**:
+   - Use function declarations (not arrow functions) for components/handlers
+   - Do not use `useCallback` or `useMemo` unless asked to optimise a given function
+   - Do not type hook returns, let Typescript infer as much as possible
+   - Never type with `any` unless a variable/attribute can ACTUALLY be of any type
@@ -2,8 +2,9 @@
 import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
 import { useRouter } from "next/navigation";
 import { useEffect } from "react";
-import { resolveResponse, shouldShowOnboarding } from "@/app/api/helpers";
+import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
 import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
+import { getHomepageRoute } from "@/lib/constants";
 
 export default function OnboardingPage() {
   const router = useRouter();
@@ -11,10 +12,13 @@ export default function OnboardingPage() {
   useEffect(() => {
     async function redirectToStep() {
       try {
-        // Check if onboarding is enabled
-        const isEnabled = await shouldShowOnboarding();
-        if (!isEnabled) {
-          router.replace("/");
+        // Check if onboarding is enabled (also gets chat flag for redirect)
+        const { shouldShowOnboarding, isChatEnabled } =
+          await getOnboardingStatus();
+        const homepageRoute = getHomepageRoute(isChatEnabled);
+
+        if (!shouldShowOnboarding) {
+          router.replace(homepageRoute);
           return;
         }
 
@@ -22,7 +26,7 @@ export default function OnboardingPage() {
 
       // Handle completed onboarding
       if (onboarding.completedSteps.includes("GET_RESULTS")) {
-        router.replace("/");
+        router.replace(homepageRoute);
         return;
       }
 
@@ -1,8 +1,9 @@
 import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
+import { getHomepageRoute } from "@/lib/constants";
 import BackendAPI from "@/lib/autogpt-server-api";
 import { NextResponse } from "next/server";
 import { revalidatePath } from "next/cache";
-import { shouldShowOnboarding } from "@/app/api/helpers";
+import { getOnboardingStatus } from "@/app/api/helpers";
 
 // Handle the callback to complete the user session login
 export async function GET(request: Request) {
@@ -25,11 +26,15 @@ export async function GET(request: Request) {
       const api = new BackendAPI();
       await api.createUser();
 
-      if (await shouldShowOnboarding()) {
+      // Get onboarding status from backend (includes chat flag evaluated for this user)
+      const { shouldShowOnboarding, isChatEnabled } =
+        await getOnboardingStatus();
+      if (shouldShowOnboarding) {
         next = "/onboarding";
         revalidatePath("/onboarding", "layout");
       } else {
-        revalidatePath("/", "layout");
+        next = getHomepageRoute(isChatEnabled);
+        revalidatePath(next, "layout");
       }
     } catch (createUserError) {
       console.error("Error creating user:", createUserError);
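Note: both the onboarding page and this callback destructure the same two fields from getOnboardingStatus(). The helper itself lives in @/app/api/helpers and is not shown in this diff, so the following is only a sketch of the contract inferred from the call sites:

interface OnboardingStatus {
  shouldShowOnboarding: boolean;
  isChatEnabled: boolean;
}

// Inferred signature only; the real helper presumably asks the backend for
// the user's onboarding state and per-user chat feature flag.
declare function getOnboardingStatus(): Promise<OnboardingStatus>;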
@@ -1,12 +1,10 @@
 "use client";
 
 import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
+import { Text } from "@/components/atoms/Text/Text";
 import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
 import type { ReactNode } from "react";
-import { useEffect } from "react";
-import { useCopilotStore } from "../../copilot-page-store";
 import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
-import { LoadingState } from "./components/LoadingState/LoadingState";
 import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
 import { MobileHeader } from "./components/MobileHeader/MobileHeader";
 import { useCopilotShell } from "./useCopilotShell";
@@ -20,38 +18,21 @@ export function CopilotShell({ children }: Props) {
     isMobile,
     isDrawerOpen,
     isLoading,
+    isCreatingSession,
     isLoggedIn,
     hasActiveSession,
     sessions,
     currentSessionId,
-    handleSelectSession,
     handleOpenDrawer,
     handleCloseDrawer,
     handleDrawerOpenChange,
-    handleNewChat,
+    handleNewChatClick,
+    handleSessionClick,
     hasNextPage,
     isFetchingNextPage,
     fetchNextPage,
-    isReadyToShowContent,
   } = useCopilotShell();
 
-  const setNewChatHandler = useCopilotStore((s) => s.setNewChatHandler);
-  const requestNewChat = useCopilotStore((s) => s.requestNewChat);
-
-  useEffect(
-    function registerNewChatHandler() {
-      setNewChatHandler(handleNewChat);
-      return function cleanup() {
-        setNewChatHandler(null);
-      };
-    },
-    [handleNewChat],
-  );
-
-  function handleNewChatClick() {
-    requestNewChat();
-  }
-
   if (!isLoggedIn) {
     return (
       <div className="flex h-full items-center justify-center">
@@ -72,7 +53,7 @@ export function CopilotShell({ children }: Props) {
           isLoading={isLoading}
           hasNextPage={hasNextPage}
           isFetchingNextPage={isFetchingNextPage}
-          onSelectSession={handleSelectSession}
+          onSelectSession={handleSessionClick}
           onFetchNextPage={fetchNextPage}
           onNewChat={handleNewChatClick}
           hasActiveSession={Boolean(hasActiveSession)}
@@ -82,7 +63,18 @@ export function CopilotShell({ children }: Props) {
       <div className="relative flex min-h-0 flex-1 flex-col">
         {isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
         <div className="flex min-h-0 flex-1 flex-col">
-          {isReadyToShowContent ? children : <LoadingState />}
+          {isCreatingSession ? (
+            <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
+              <div className="flex flex-col items-center gap-4">
+                <ChatLoader />
+                <Text variant="body" className="text-zinc-500">
+                  Creating your chat...
+                </Text>
+              </div>
+            </div>
+          ) : (
+            children
+          )}
         </div>
       </div>
 
@@ -94,7 +86,7 @@ export function CopilotShell({ children }: Props) {
           isLoading={isLoading}
           hasNextPage={hasNextPage}
           isFetchingNextPage={isFetchingNextPage}
-          onSelectSession={handleSelectSession}
+          onSelectSession={handleSessionClick}
           onFetchNextPage={fetchNextPage}
           onNewChat={handleNewChatClick}
           onClose={handleCloseDrawer}
@@ -1,15 +0,0 @@
-import { Text } from "@/components/atoms/Text/Text";
-import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
-
-export function LoadingState() {
-  return (
-    <div className="flex flex-1 items-center justify-center">
-      <div className="flex flex-col items-center gap-4">
-        <ChatLoader />
-        <Text variant="body" className="text-zinc-500">
-          Loading your chats...
-        </Text>
-      </div>
-    </div>
-  );
-}
@@ -3,17 +3,17 @@ import { useState } from "react";
 export function useMobileDrawer() {
   const [isDrawerOpen, setIsDrawerOpen] = useState(false);
 
-  function handleOpenDrawer() {
+  const handleOpenDrawer = () => {
     setIsDrawerOpen(true);
-  }
+  };
 
-  function handleCloseDrawer() {
+  const handleCloseDrawer = () => {
     setIsDrawerOpen(false);
-  }
+  };
 
-  function handleDrawerOpenChange(open: boolean) {
+  const handleDrawerOpenChange = (open: boolean) => {
     setIsDrawerOpen(open);
-  }
+  };
 
   return {
     isDrawerOpen,
@@ -1,11 +1,6 @@
-import {
-  getGetV2ListSessionsQueryKey,
-  useGetV2ListSessions,
-} from "@/app/api/__generated__/endpoints/chat/chat";
+import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
 import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
 import { okData } from "@/app/api/helpers";
-import { useChatStore } from "@/components/contextual/Chat/chat-store";
-import { useQueryClient } from "@tanstack/react-query";
 import { useEffect, useState } from "react";
 
 const PAGE_SIZE = 50;
@@ -16,12 +11,12 @@ export interface UseSessionsPaginationArgs {
 
 export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
   const [offset, setOffset] = useState(0);
 
   const [accumulatedSessions, setAccumulatedSessions] = useState<
     SessionSummaryResponse[]
   >([]);
 
   const [totalCount, setTotalCount] = useState<number | null>(null);
-  const queryClient = useQueryClient();
-  const onStreamComplete = useChatStore((state) => state.onStreamComplete);
 
   const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
     { limit: PAGE_SIZE, offset },
@@ -32,38 +27,23 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
     },
   );
 
-  useEffect(function refreshOnStreamComplete() {
-    const unsubscribe = onStreamComplete(function handleStreamComplete() {
-      setOffset(0);
-      setAccumulatedSessions([]);
-      setTotalCount(null);
-      queryClient.invalidateQueries({
-        queryKey: getGetV2ListSessionsQueryKey(),
-      });
-    });
-    return unsubscribe;
-  }, []);
-
-  useEffect(
-    function updateSessionsFromResponse() {
-      const responseData = okData(data);
-      if (responseData) {
-        const newSessions = responseData.sessions;
-        const total = responseData.total;
-        setTotalCount(total);
-
-        if (offset === 0) {
-          setAccumulatedSessions(newSessions);
-        } else {
-          setAccumulatedSessions((prev) => [...prev, ...newSessions]);
-        }
-      } else if (!enabled) {
-        setAccumulatedSessions([]);
-        setTotalCount(null);
-      }
-    },
-    [data, offset, enabled],
-  );
+  useEffect(() => {
+    const responseData = okData(data);
+    if (responseData) {
+      const newSessions = responseData.sessions;
+      const total = responseData.total;
+      setTotalCount(total);
+
+      if (offset === 0) {
+        setAccumulatedSessions(newSessions);
+      } else {
+        setAccumulatedSessions((prev) => [...prev, ...newSessions]);
+      }
+    } else if (!enabled) {
+      setAccumulatedSessions([]);
+      setTotalCount(null);
+    }
+  }, [data, offset, enabled]);
 
   const hasNextPage =
     totalCount !== null && accumulatedSessions.length < totalCount;
@@ -86,17 +66,17 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
     }
   }, [hasNextPage, isFetching, isLoading, isError, totalCount]);
 
-  function fetchNextPage() {
+  const fetchNextPage = () => {
     if (hasNextPage && !isFetching) {
       setOffset((prev) => prev + PAGE_SIZE);
     }
-  }
+  };
 
-  function reset() {
-    setOffset(0);
-    setAccumulatedSessions([]);
-    setTotalCount(null);
-  }
+  const reset = () => {
+    // Only reset the offset - keep existing sessions visible during refetch
+    // The effect will replace sessions when new data arrives at offset 0
+    setOffset(0);
+  };
 
   return {
     sessions: accumulatedSessions,
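For context, a hypothetical consumer of the reworked hook; the component name is illustrative and not part of this diff:

import { useSessionsPagination } from "./useSessionsPagination";

function LoadMoreSessions() {
  const { sessions, hasNextPage, isFetching, fetchNextPage } =
    useSessionsPagination({ enabled: true });

  // Offset-based paging: each click appends the next PAGE_SIZE sessions to
  // the accumulated list until totalCount is reached.
  return (
    <button disabled={!hasNextPage || isFetching} onClick={fetchNextPage}>
      Loaded {sessions.length} sessions; load more
    </button>
  );
}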
@@ -104,76 +104,3 @@ export function mergeCurrentSessionIntoList(
 export function getCurrentSessionId(searchParams: URLSearchParams) {
   return searchParams.get("sessionId");
 }
-
-export function shouldAutoSelectSession(
-  areAllSessionsLoaded: boolean,
-  hasAutoSelectedSession: boolean,
-  paramSessionId: string | null,
-  visibleSessions: SessionSummaryResponse[],
-  accumulatedSessions: SessionSummaryResponse[],
-  isLoading: boolean,
-  totalCount: number | null,
-) {
-  if (!areAllSessionsLoaded || hasAutoSelectedSession) {
-    return {
-      shouldSelect: false,
-      sessionIdToSelect: null,
-      shouldCreate: false,
-    };
-  }
-
-  if (paramSessionId) {
-    return {
-      shouldSelect: false,
-      sessionIdToSelect: null,
-      shouldCreate: false,
-    };
-  }
-
-  if (visibleSessions.length > 0) {
-    return {
-      shouldSelect: true,
-      sessionIdToSelect: visibleSessions[0].id,
-      shouldCreate: false,
-    };
-  }
-
-  if (accumulatedSessions.length === 0 && !isLoading && totalCount === 0) {
-    return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: true };
-  }
-
-  if (totalCount === 0) {
-    return {
-      shouldSelect: false,
-      sessionIdToSelect: null,
-      shouldCreate: false,
-    };
-  }
-
-  return { shouldSelect: false, sessionIdToSelect: null, shouldCreate: false };
-}
-
-export function checkReadyToShowContent(
-  areAllSessionsLoaded: boolean,
-  paramSessionId: string | null,
-  accumulatedSessions: SessionSummaryResponse[],
-  isCurrentSessionLoading: boolean,
-  currentSessionData: SessionDetailResponse | null | undefined,
-  hasAutoSelectedSession: boolean,
-) {
-  if (!areAllSessionsLoaded) return false;
-
-  if (paramSessionId) {
-    const sessionFound = accumulatedSessions.some(
-      (s) => s.id === paramSessionId,
-    );
-    return (
-      sessionFound ||
-      (!isCurrentSessionLoading &&
-        currentSessionData !== undefined &&
-        currentSessionData !== null)
-    );
-  }
-
-  return hasAutoSelectedSession;
-}
@@ -1,26 +1,22 @@
 "use client";
 
 import {
+  getGetV2GetSessionQueryKey,
   getGetV2ListSessionsQueryKey,
   useGetV2GetSession,
 } from "@/app/api/__generated__/endpoints/chat/chat";
-import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
 import { okData } from "@/app/api/helpers";
+import { useChatStore } from "@/components/contextual/Chat/chat-store";
 import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { useQueryClient } from "@tanstack/react-query";
-import { parseAsString, useQueryState } from "nuqs";
 import { usePathname, useSearchParams } from "next/navigation";
-import { useEffect, useRef, useState } from "react";
+import { useRef } from "react";
+import { useCopilotStore } from "../../copilot-page-store";
+import { useCopilotSessionId } from "../../useCopilotSessionId";
 import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
-import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
-import {
-  checkReadyToShowContent,
-  convertSessionDetailToSummary,
-  filterVisibleSessions,
-  getCurrentSessionId,
-  mergeCurrentSessionIntoList,
-} from "./helpers";
+import { getCurrentSessionId } from "./helpers";
+import { useShellSessionList } from "./useShellSessionList";
 
 export function useCopilotShell() {
   const pathname = usePathname();
@@ -31,7 +27,7 @@ export function useCopilotShell() {
   const isMobile =
     breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
 
-  const [, setUrlSessionId] = useQueryState("sessionId", parseAsString);
+  const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
 
   const isOnHomepage = pathname === "/copilot";
   const paramSessionId = searchParams.get("sessionId");
@@ -45,123 +41,80 @@ export function useCopilotShell() {
 
   const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;
 
-  const {
-    sessions: accumulatedSessions,
-    isLoading: isSessionsLoading,
-    isFetching: isSessionsFetching,
-    hasNextPage,
-    areAllSessionsLoaded,
-    fetchNextPage,
-    reset: resetPagination,
-  } = useSessionsPagination({
-    enabled: paginationEnabled,
-  });
-
   const currentSessionId = getCurrentSessionId(searchParams);
 
-  const { data: currentSessionData, isLoading: isCurrentSessionLoading } =
-    useGetV2GetSession(currentSessionId || "", {
+  const { data: currentSessionData } = useGetV2GetSession(
+    currentSessionId || "",
+    {
       query: {
         enabled: !!currentSessionId,
         select: okData,
       },
-    });
-
-  const [hasAutoSelectedSession, setHasAutoSelectedSession] = useState(false);
-  const hasAutoSelectedRef = useRef(false);
-  const recentlyCreatedSessionsRef = useRef<
-    Map<string, SessionSummaryResponse>
-  >(new Map());
-
-  // Mark as auto-selected when sessionId is in URL
-  useEffect(() => {
-    if (paramSessionId && !hasAutoSelectedRef.current) {
-      hasAutoSelectedRef.current = true;
-      setHasAutoSelectedSession(true);
-    }
-  }, [paramSessionId]);
-
-  // On homepage without sessionId, mark as ready immediately
-  useEffect(() => {
-    if (isOnHomepage && !paramSessionId && !hasAutoSelectedRef.current) {
-      hasAutoSelectedRef.current = true;
-      setHasAutoSelectedSession(true);
-    }
-  }, [isOnHomepage, paramSessionId]);
-
-  // Invalidate sessions list when navigating to homepage (to show newly created sessions)
-  useEffect(() => {
-    if (isOnHomepage && !paramSessionId) {
-      queryClient.invalidateQueries({
-        queryKey: getGetV2ListSessionsQueryKey(),
-      });
-    }
-  }, [isOnHomepage, paramSessionId, queryClient]);
-
-  // Track newly created sessions to ensure they stay visible even when switching away
-  useEffect(() => {
-    if (currentSessionId && currentSessionData) {
-      const isNewSession =
-        currentSessionData.updated_at === currentSessionData.created_at;
-      const isNotInAccumulated = !accumulatedSessions.some(
-        (s) => s.id === currentSessionId,
-      );
-      if (isNewSession || isNotInAccumulated) {
-        const summary = convertSessionDetailToSummary(currentSessionData);
-        recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
-      }
-    }
-  }, [currentSessionId, currentSessionData, accumulatedSessions]);
-
-  // Clean up recently created sessions that are now in the accumulated list
-  useEffect(() => {
-    for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
-      if (accumulatedSessions.some((s) => s.id === sessionId)) {
-        recentlyCreatedSessionsRef.current.delete(sessionId);
-      }
-    }
-  }, [accumulatedSessions]);
-
-  // Reset pagination when query becomes disabled
-  const prevPaginationEnabledRef = useRef(paginationEnabled);
-  useEffect(() => {
-    if (prevPaginationEnabledRef.current && !paginationEnabled) {
-      resetPagination();
-      resetAutoSelect();
-    }
-    prevPaginationEnabledRef.current = paginationEnabled;
-  }, [paginationEnabled, resetPagination]);
-
-  const sessions = mergeCurrentSessionIntoList(
-    accumulatedSessions,
-    currentSessionId,
-    currentSessionData,
-    recentlyCreatedSessionsRef.current,
-  );
-
-  const visibleSessions = filterVisibleSessions(sessions);
-
-  const sidebarSelectedSessionId =
-    isOnHomepage && !paramSessionId ? null : currentSessionId;
-
-  const isReadyToShowContent = isOnHomepage
-    ? true
-    : checkReadyToShowContent(
-        areAllSessionsLoaded,
-        paramSessionId,
-        accumulatedSessions,
-        isCurrentSessionLoading,
-        currentSessionData,
-        hasAutoSelectedSession,
-      );
-
-  function handleSelectSession(sessionId: string) {
+    },
+  );
+
+  const {
+    sessions,
+    isLoading,
+    isSessionsFetching,
+    hasNextPage,
+    fetchNextPage,
+    resetPagination,
+    recentlyCreatedSessionsRef,
+  } = useShellSessionList({
+    paginationEnabled,
+    currentSessionId,
+    currentSessionData,
+    isOnHomepage,
+    paramSessionId,
+  });
+
+  const stopStream = useChatStore((s) => s.stopStream);
+  const onStreamComplete = useChatStore((s) => s.onStreamComplete);
+  const isStreaming = useCopilotStore((s) => s.isStreaming);
+  const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
+  const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
+  const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
+
+  const pendingActionRef = useRef<(() => void) | null>(null);
+
+  async function stopCurrentStream() {
+    if (!currentSessionId) return;
+
+    setIsSwitchingSession(true);
+    await new Promise<void>((resolve) => {
+      const unsubscribe = onStreamComplete((completedId) => {
+        if (completedId === currentSessionId) {
+          clearTimeout(timeout);
+          unsubscribe();
+          resolve();
+        }
+      });
+      const timeout = setTimeout(() => {
+        unsubscribe();
+        resolve();
+      }, 3000);
+      stopStream(currentSessionId);
+    });
+
+    queryClient.invalidateQueries({
+      queryKey: getGetV2GetSessionQueryKey(currentSessionId),
+    });
+    setIsSwitchingSession(false);
+  }
+
+  function selectSession(sessionId: string) {
+    if (sessionId === currentSessionId) return;
+    if (recentlyCreatedSessionsRef.current.has(sessionId)) {
+      queryClient.invalidateQueries({
+        queryKey: getGetV2GetSessionQueryKey(sessionId),
+      });
+    }
     setUrlSessionId(sessionId, { shallow: false });
     if (isMobile) handleCloseDrawer();
   }
 
-  function handleNewChat() {
-    resetAutoSelect();
+  function startNewChat() {
     resetPagination();
     queryClient.invalidateQueries({
       queryKey: getGetV2ListSessionsQueryKey(),
@@ -170,12 +123,31 @@ export function useCopilotShell() {
     if (isMobile) handleCloseDrawer();
   }
 
-  function resetAutoSelect() {
-    hasAutoSelectedRef.current = false;
-    setHasAutoSelectedSession(false);
+  function handleSessionClick(sessionId: string) {
+    if (sessionId === currentSessionId) return;
+    if (isStreaming) {
+      pendingActionRef.current = async () => {
+        await stopCurrentStream();
+        selectSession(sessionId);
+      };
+      openInterruptModal(pendingActionRef.current);
+    } else {
+      selectSession(sessionId);
+    }
   }
 
-  const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
+  function handleNewChatClick() {
+    if (isStreaming) {
+      pendingActionRef.current = async () => {
+        await stopCurrentStream();
+        startNewChat();
+      };
+      openInterruptModal(pendingActionRef.current);
+    } else {
+      startNewChat();
+    }
+  }
 
   return {
     isMobile,
@@ -183,17 +155,17 @@ export function useCopilotShell() {
     isLoggedIn,
     hasActiveSession:
      Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
-    isLoading,
-    sessions: visibleSessions,
-    currentSessionId: sidebarSelectedSessionId,
-    handleSelectSession,
+    isLoading: isLoading || isCreatingSession,
+    isCreatingSession,
+    sessions,
+    currentSessionId: urlSessionId,
     handleOpenDrawer,
     handleCloseDrawer,
     handleDrawerOpenChange,
-    handleNewChat,
+    handleNewChatClick,
+    handleSessionClick,
    hasNextPage,
     isFetchingNextPage: isSessionsFetching,
     fetchNextPage,
-    isReadyToShowContent,
   };
 }
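The stopCurrentStream helper above races the stream-complete callback against a 3-second timeout, so switching sessions can never hang on a stream that fails to report completion. The same pattern in isolation, as a generic sketch rather than code from this diff:

// Resolve when the completion callback fires for the expected id, or after
// a fallback timeout; either path unsubscribes so nothing leaks.
function waitForCompletion(
  subscribe: (cb: (id: string) => void) => () => void,
  expectedId: string,
  timeoutMs = 3000,
): Promise<void> {
  return new Promise((resolve) => {
    let timer: ReturnType<typeof setTimeout> | undefined;
    const unsubscribe = subscribe((id) => {
      if (id === expectedId) {
        clearTimeout(timer);
        unsubscribe();
        resolve();
      }
    });
    // Fallback: never let the caller hang if completion is never reported.
    timer = setTimeout(() => {
      unsubscribe();
      resolve();
    }, timeoutMs);
  });
}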
@@ -0,0 +1,113 @@
+import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
+import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
+import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
+import { useChatStore } from "@/components/contextual/Chat/chat-store";
+import { useQueryClient } from "@tanstack/react-query";
+import { useEffect, useMemo, useRef } from "react";
+import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
+import {
+  convertSessionDetailToSummary,
+  filterVisibleSessions,
+  mergeCurrentSessionIntoList,
+} from "./helpers";
+
+interface UseShellSessionListArgs {
+  paginationEnabled: boolean;
+  currentSessionId: string | null;
+  currentSessionData: SessionDetailResponse | null | undefined;
+  isOnHomepage: boolean;
+  paramSessionId: string | null;
+}
+
+export function useShellSessionList({
+  paginationEnabled,
+  currentSessionId,
+  currentSessionData,
+  isOnHomepage,
+  paramSessionId,
+}: UseShellSessionListArgs) {
+  const queryClient = useQueryClient();
+  const onStreamComplete = useChatStore((s) => s.onStreamComplete);
+
+  const {
+    sessions: accumulatedSessions,
+    isLoading: isSessionsLoading,
+    isFetching: isSessionsFetching,
+    hasNextPage,
+    fetchNextPage,
+    reset: resetPagination,
+  } = useSessionsPagination({
+    enabled: paginationEnabled,
+  });
+
+  const recentlyCreatedSessionsRef = useRef<
+    Map<string, SessionSummaryResponse>
+  >(new Map());
+
+  useEffect(() => {
+    if (isOnHomepage && !paramSessionId) {
+      queryClient.invalidateQueries({
+        queryKey: getGetV2ListSessionsQueryKey(),
+      });
+    }
+  }, [isOnHomepage, paramSessionId, queryClient]);
+
+  useEffect(() => {
+    if (currentSessionId && currentSessionData) {
+      const isNewSession =
+        currentSessionData.updated_at === currentSessionData.created_at;
+      const isNotInAccumulated = !accumulatedSessions.some(
+        (s) => s.id === currentSessionId,
+      );
+      if (isNewSession || isNotInAccumulated) {
+        const summary = convertSessionDetailToSummary(currentSessionData);
+        recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
+      }
+    }
+  }, [currentSessionId, currentSessionData, accumulatedSessions]);
+
+  useEffect(() => {
+    for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
+      if (accumulatedSessions.some((s) => s.id === sessionId)) {
+        recentlyCreatedSessionsRef.current.delete(sessionId);
+      }
+    }
+  }, [accumulatedSessions]);
+
+  useEffect(() => {
+    const unsubscribe = onStreamComplete(() => {
+      queryClient.invalidateQueries({
+        queryKey: getGetV2ListSessionsQueryKey(),
+      });
+    });
+    return unsubscribe;
+  }, [onStreamComplete, queryClient]);
+
+  const sessions = useMemo(
+    () =>
+      mergeCurrentSessionIntoList(
+        accumulatedSessions,
+        currentSessionId,
+        currentSessionData,
+        recentlyCreatedSessionsRef.current,
+      ),
+    [accumulatedSessions, currentSessionId, currentSessionData],
+  );
+
+  const visibleSessions = useMemo(
+    () => filterVisibleSessions(sessions),
+    [sessions],
+  );
+
+  const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
+
+  return {
+    sessions: visibleSessions,
+    isLoading,
+    isSessionsFetching,
+    hasNextPage,
+    fetchNextPage,
+    resetPagination,
+    recentlyCreatedSessionsRef,
+  };
+}
@@ -4,51 +4,53 @@ import { create } from "zustand";
 
 interface CopilotStoreState {
   isStreaming: boolean;
-  isNewChatModalOpen: boolean;
-  newChatHandler: (() => void) | null;
+  isSwitchingSession: boolean;
+  isCreatingSession: boolean;
+  isInterruptModalOpen: boolean;
+  pendingAction: (() => void) | null;
 }
 
 interface CopilotStoreActions {
   setIsStreaming: (isStreaming: boolean) => void;
-  setNewChatHandler: (handler: (() => void) | null) => void;
-  requestNewChat: () => void;
-  confirmNewChat: () => void;
-  cancelNewChat: () => void;
+  setIsSwitchingSession: (isSwitchingSession: boolean) => void;
+  setIsCreatingSession: (isCreating: boolean) => void;
+  openInterruptModal: (onConfirm: () => void) => void;
+  confirmInterrupt: () => void;
+  cancelInterrupt: () => void;
 }
 
 type CopilotStore = CopilotStoreState & CopilotStoreActions;
 
 export const useCopilotStore = create<CopilotStore>((set, get) => ({
   isStreaming: false,
-  isNewChatModalOpen: false,
-  newChatHandler: null,
+  isSwitchingSession: false,
+  isCreatingSession: false,
+  isInterruptModalOpen: false,
+  pendingAction: null,
 
   setIsStreaming(isStreaming) {
     set({ isStreaming });
   },
 
-  setNewChatHandler(handler) {
-    set({ newChatHandler: handler });
+  setIsSwitchingSession(isSwitchingSession) {
+    set({ isSwitchingSession });
   },
 
-  requestNewChat() {
-    const { isStreaming, newChatHandler } = get();
-    if (isStreaming) {
-      set({ isNewChatModalOpen: true });
-    } else if (newChatHandler) {
-      newChatHandler();
-    }
+  setIsCreatingSession(isCreatingSession) {
+    set({ isCreatingSession });
   },
 
-  confirmNewChat() {
-    const { newChatHandler } = get();
-    set({ isNewChatModalOpen: false });
-    if (newChatHandler) {
-      newChatHandler();
-    }
+  openInterruptModal(onConfirm) {
+    set({ isInterruptModalOpen: true, pendingAction: onConfirm });
   },
 
-  cancelNewChat() {
-    set({ isNewChatModalOpen: false });
+  confirmInterrupt() {
+    const { pendingAction } = get();
+    set({ isInterruptModalOpen: false, pendingAction: null });
+    if (pendingAction) pendingAction();
+  },
+
+  cancelInterrupt() {
+    set({ isInterruptModalOpen: false, pendingAction: null });
   },
 }));
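A hedged sketch of how the reworked store is meant to be driven: any action that would cut off an in-flight response is parked as pendingAction, and confirmInterrupt() runs it only after the user accepts the dialog. The wrapper function below is illustrative; only the store itself comes from this diff.

import { useCopilotStore } from "./copilot-page-store";

function guardedAction(action: () => void) {
  const { isStreaming, openInterruptModal } = useCopilotStore.getState();
  if (isStreaming) {
    // confirmInterrupt() will invoke the pending action later;
    // cancelInterrupt() simply clears it.
    openInterruptModal(action);
  } else {
    action();
  }
}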
@@ -1,28 +1,5 @@
 import type { User } from "@supabase/supabase-js";
 
-export type PageState =
-  | { type: "welcome" }
-  | { type: "newChat" }
-  | { type: "creating"; prompt: string }
-  | { type: "chat"; sessionId: string; initialPrompt?: string };
-
-export function getInitialPromptFromState(
-  pageState: PageState,
-  storedInitialPrompt: string | undefined,
-) {
-  if (storedInitialPrompt) return storedInitialPrompt;
-  if (pageState.type === "creating") return pageState.prompt;
-  if (pageState.type === "chat") return pageState.initialPrompt;
-}
-
-export function shouldResetToWelcome(pageState: PageState) {
-  return (
-    pageState.type !== "newChat" &&
-    pageState.type !== "creating" &&
-    pageState.type !== "welcome"
-  );
-}
-
 export function getGreetingName(user?: User | null): string {
   if (!user) return "there";
   const metadata = user.user_metadata as Record<string, unknown> | undefined;
@@ -1,25 +1,25 @@
 "use client";
 
 import { Button } from "@/components/atoms/Button/Button";
 
 import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
 import { Text } from "@/components/atoms/Text/Text";
 import { Chat } from "@/components/contextual/Chat/Chat";
 import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
-import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
 import { Dialog } from "@/components/molecules/Dialog/Dialog";
 import { useCopilotStore } from "./copilot-page-store";
 import { useCopilotPage } from "./useCopilotPage";
 
 export default function CopilotPage() {
   const { state, handlers } = useCopilotPage();
-  const confirmNewChat = useCopilotStore((s) => s.confirmNewChat);
+  const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
+  const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
+  const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
   const {
     greetingName,
     quickActions,
     isLoading,
-    pageState,
-    isNewChatModalOpen,
+    hasSession,
+    initialPrompt,
     isReady,
   } = state;
   const {
@@ -27,20 +27,16 @@ export default function CopilotPage() {
     startChatWithPrompt,
     handleSessionNotFound,
     handleStreamingChange,
-    handleCancelNewChat,
-    handleNewChatModalOpen,
   } = handlers;
 
   if (!isReady) return null;
 
-  if (pageState.type === "chat") {
+  if (hasSession) {
     return (
       <div className="flex h-full flex-col">
         <Chat
-          key={pageState.sessionId ?? "welcome"}
           className="flex-1"
-          urlSessionId={pageState.sessionId}
-          initialPrompt={pageState.initialPrompt}
+          initialPrompt={initialPrompt}
           onSessionNotFound={handleSessionNotFound}
           onStreamingChange={handleStreamingChange}
         />
@@ -48,31 +44,33 @@ export default function CopilotPage() {
           title="Interrupt current chat?"
           styling={{ maxWidth: 300, width: "100%" }}
           controlled={{
-            isOpen: isNewChatModalOpen,
-            set: handleNewChatModalOpen,
+            isOpen: isInterruptModalOpen,
+            set: (open) => {
+              if (!open) cancelInterrupt();
+            },
           }}
-          onClose={handleCancelNewChat}
+          onClose={cancelInterrupt}
         >
           <Dialog.Content>
             <div className="flex flex-col gap-4">
              <Text variant="body">
                 The current chat response will be interrupted. Are you sure you
-                want to start a new chat?
+                want to continue?
               </Text>
               <Dialog.Footer>
                 <Button
                   type="button"
                   variant="outline"
-                  onClick={handleCancelNewChat}
+                  onClick={cancelInterrupt}
                 >
                   Cancel
                 </Button>
                 <Button
                   type="button"
                   variant="primary"
-                  onClick={confirmNewChat}
+                  onClick={confirmInterrupt}
                 >
-                  Start new chat
+                  Continue
                 </Button>
               </Dialog.Footer>
             </div>
@@ -82,19 +80,6 @@ export default function CopilotPage() {
     );
   }
 
-  if (pageState.type === "newChat" || pageState.type === "creating") {
-    return (
-      <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
-        <div className="flex flex-col items-center gap-4">
-          <ChatLoader />
-          <Text variant="body" className="text-zinc-500">
-            Loading your chats...
-          </Text>
-        </div>
-      </div>
-    );
-  }
-
   return (
     <div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
       <div className="w-full text-center">
@@ -5,79 +5,40 @@ import {
 import { useToast } from "@/components/molecules/Toast/use-toast";
 import { getHomepageRoute } from "@/lib/constants";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
+import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
 import {
   Flag,
   type FlagValues,
   useGetFlag,
 } from "@/services/feature-flags/use-get-flag";
+import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
 import * as Sentry from "@sentry/nextjs";
 import { useQueryClient } from "@tanstack/react-query";
 import { useFlags } from "launchdarkly-react-client-sdk";
 import { useRouter } from "next/navigation";
-import { useEffect, useReducer } from "react";
+import { useEffect } from "react";
 import { useCopilotStore } from "./copilot-page-store";
-import { getGreetingName, getQuickActions, type PageState } from "./helpers";
-import { useCopilotURLState } from "./useCopilotURLState";
+import { getGreetingName, getQuickActions } from "./helpers";
+import { useCopilotSessionId } from "./useCopilotSessionId";
 
-type CopilotState = {
-  pageState: PageState;
-  initialPrompts: Record<string, string>;
-  previousSessionId: string | null;
-};
-
-type CopilotAction =
-  | { type: "setPageState"; pageState: PageState }
-  | { type: "setInitialPrompt"; sessionId: string; prompt: string }
-  | { type: "setPreviousSessionId"; sessionId: string | null };
-
-function isSamePageState(next: PageState, current: PageState) {
-  if (next.type !== current.type) return false;
-  if (next.type === "creating" && current.type === "creating") {
-    return next.prompt === current.prompt;
-  }
-  if (next.type === "chat" && current.type === "chat") {
-    return (
-      next.sessionId === current.sessionId &&
-      next.initialPrompt === current.initialPrompt
-    );
-  }
-  return true;
-}
-
-function copilotReducer(
-  state: CopilotState,
-  action: CopilotAction,
-): CopilotState {
-  if (action.type === "setPageState") {
-    if (isSamePageState(action.pageState, state.pageState)) return state;
-    return { ...state, pageState: action.pageState };
-  }
-  if (action.type === "setInitialPrompt") {
-    if (state.initialPrompts[action.sessionId] === action.prompt) return state;
-    return {
-      ...state,
-      initialPrompts: {
-        ...state.initialPrompts,
-        [action.sessionId]: action.prompt,
-      },
-    };
-  }
-  if (action.type === "setPreviousSessionId") {
-    if (state.previousSessionId === action.sessionId) return state;
-    return { ...state, previousSessionId: action.sessionId };
-  }
-  return state;
-}
-
 export function useCopilotPage() {
   const router = useRouter();
   const queryClient = useQueryClient();
   const { user, isLoggedIn, isUserLoading } = useSupabase();
   const { toast } = useToast();
+  const { completeStep } = useOnboarding();
 
-  const isNewChatModalOpen = useCopilotStore((s) => s.isNewChatModalOpen);
+  const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
   const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
-  const cancelNewChat = useCopilotStore((s) => s.cancelNewChat);
+  const isCreating = useCopilotStore((s) => s.isCreatingSession);
+  const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
+
+  // Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
+  useEffect(() => {
+    if (isLoggedIn) {
+      completeStep("VISIT_COPILOT");
+    }
+  }, [completeStep, isLoggedIn]);
 
   const isChatEnabled = useGetFlag(Flag.CHAT);
   const flags = useFlags<FlagValues>();
@@ -88,72 +49,27 @@ export function useCopilotPage() {
   const isFlagReady =
     !isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
 
-  const [state, dispatch] = useReducer(copilotReducer, {
-    pageState: { type: "welcome" },
-    initialPrompts: {},
-    previousSessionId: null,
-  });
-
   const greetingName = getGreetingName(user);
   const quickActions = getQuickActions();
 
-  function setPageState(pageState: PageState) {
-    dispatch({ type: "setPageState", pageState });
-  }
-
-  function setInitialPrompt(sessionId: string, prompt: string) {
-    dispatch({ type: "setInitialPrompt", sessionId, prompt });
-  }
-
-  function setPreviousSessionId(sessionId: string | null) {
-    dispatch({ type: "setPreviousSessionId", sessionId });
-  }
-
-  const { setUrlSessionId } = useCopilotURLState({
-    pageState: state.pageState,
-    initialPrompts: state.initialPrompts,
-    previousSessionId: state.previousSessionId,
-    setPageState,
-    setInitialPrompt,
-    setPreviousSessionId,
-  });
-
-  useEffect(
-    function transitionNewChatToWelcome() {
-      if (state.pageState.type === "newChat") {
-        function setWelcomeState() {
-          dispatch({ type: "setPageState", pageState: { type: "welcome" } });
-        }
-
-        const timer = setTimeout(setWelcomeState, 300);
-
-        return function cleanup() {
-          clearTimeout(timer);
-        };
-      }
-    },
-    [state.pageState.type],
-  );
-
-  useEffect(
-    function ensureAccess() {
-      if (!isFlagReady) return;
-      if (isChatEnabled === false) {
-        router.replace(homepageRoute);
-      }
-    },
-    [homepageRoute, isChatEnabled, isFlagReady, router],
-  );
+  const hasSession = Boolean(urlSessionId);
+  const initialPrompt = urlSessionId
+    ? getInitialPrompt(urlSessionId)
+    : undefined;
+
+  useEffect(() => {
+    if (!isFlagReady) return;
+    if (isChatEnabled === false) {
+      router.replace(homepageRoute);
+    }
+  }, [homepageRoute, isChatEnabled, isFlagReady, router]);
 
   async function startChatWithPrompt(prompt: string) {
     if (!prompt?.trim()) return;
-    if (state.pageState.type === "creating") return;
+    if (isCreating) return;
 
     const trimmedPrompt = prompt.trim();
-    dispatch({
-      type: "setPageState",
-      pageState: { type: "creating", prompt: trimmedPrompt },
-    });
+    setIsCreating(true);
 
     try {
       const sessionResponse = await postV2CreateSession({
@@ -165,27 +81,19 @@ export function useCopilotPage() {
       }
 
       const sessionId = sessionResponse.data.id;
-      dispatch({
-        type: "setInitialPrompt",
-        sessionId,
-        prompt: trimmedPrompt,
-      });
+      setInitialPrompt(sessionId, trimmedPrompt);
 
       await queryClient.invalidateQueries({
         queryKey: getGetV2ListSessionsQueryKey(),
       });
 
-      await setUrlSessionId(sessionId, { shallow: false });
-      dispatch({
-        type: "setPageState",
-        pageState: { type: "chat", sessionId, initialPrompt: trimmedPrompt },
-      });
+      await setUrlSessionId(sessionId, { shallow: true });
     } catch (error) {
       console.error("[CopilotPage] Failed to start chat:", error);
       toast({ title: "Failed to start chat", variant: "destructive" });
       Sentry.captureException(error);
-      dispatch({ type: "setPageState", pageState: { type: "welcome" } });
+    } finally {
+      setIsCreating(false);
     }
   }
 
@@ -201,21 +109,13 @@ export function useCopilotPage() {
     setIsStreaming(isStreamingValue);
   }
 
-  function handleCancelNewChat() {
-    cancelNewChat();
-  }
-
-  function handleNewChatModalOpen(isOpen: boolean) {
-    if (!isOpen) cancelNewChat();
-  }
-
   return {
     state: {
       greetingName,
       quickActions,
       isLoading: isUserLoading,
-      pageState: state.pageState,
-      isNewChatModalOpen,
+      hasSession,
+      initialPrompt,
       isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
     },
     handlers: {
@@ -223,8 +123,32 @@ export function useCopilotPage() {
       startChatWithPrompt,
       handleSessionNotFound,
       handleStreamingChange,
-      handleCancelNewChat,
-      handleNewChatModalOpen,
     },
   };
 }
+
+function getInitialPrompt(sessionId: string): string | undefined {
+  try {
+    const prompts = JSON.parse(
+      sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
+    );
+    return prompts[sessionId];
+  } catch {
+    return undefined;
+  }
+}
+
+function setInitialPrompt(sessionId: string, prompt: string): void {
+  try {
+    const prompts = JSON.parse(
+      sessionStorage.get(SessionKey.CHAT_INITIAL_PROMPTS) || "{}",
+    );
+    prompts[sessionId] = prompt;
+    sessionStorage.set(
+      SessionKey.CHAT_INITIAL_PROMPTS,
+      JSON.stringify(prompts),
+    );
+  } catch {
+    // Ignore storage errors
+  }
+}
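The two module-level helpers above persist the first prompt per session so it survives the URL-driven remount. An equivalent round-trip against the browser's native sessionStorage, for illustration (the app's SessionKey wrapper is assumed from the imports; the key name here is made up):

const PROMPTS_KEY = "chat-initial-prompts"; // illustrative key

function savePrompt(sessionId: string, prompt: string): void {
  const raw = window.sessionStorage.getItem(PROMPTS_KEY) ?? "{}";
  const prompts = JSON.parse(raw) as Record<string, string>;
  prompts[sessionId] = prompt;
  window.sessionStorage.setItem(PROMPTS_KEY, JSON.stringify(prompts));
}

function loadPrompt(sessionId: string): string | undefined {
  try {
    const raw = window.sessionStorage.getItem(PROMPTS_KEY) ?? "{}";
    return (JSON.parse(raw) as Record<string, string>)[sessionId];
  } catch {
    return undefined; // corrupt JSON: treat as no stored prompt
  }
}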
@@ -0,0 +1,10 @@
+import { parseAsString, useQueryState } from "nuqs";
+
+export function useCopilotSessionId() {
+  const [urlSessionId, setUrlSessionId] = useQueryState(
+    "sessionId",
+    parseAsString,
+  );
+
+  return { urlSessionId, setUrlSessionId };
+}
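Usage sketch for the new hook: every copilot component reads and writes the sessionId query parameter through the same nuqs-backed state, which keeps the URL as the single source of truth. The component below is illustrative, not part of this diff:

import { useCopilotSessionId } from "./useCopilotSessionId";

function OpenSessionButton({ id }: { id: string }) {
  const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
  // shallow: false asks nuqs for a full navigation, matching the diff's call sites
  function open() {
    setUrlSessionId(id, { shallow: false });
  }
  return (
    <button disabled={urlSessionId === id} onClick={open}>
      Open session
    </button>
  );
}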
@@ -1,80 +0,0 @@
-import { parseAsString, useQueryState } from "nuqs";
-import { useLayoutEffect } from "react";
-import {
-  getInitialPromptFromState,
-  type PageState,
-  shouldResetToWelcome,
-} from "./helpers";
-
-interface UseCopilotUrlStateArgs {
-  pageState: PageState;
-  initialPrompts: Record<string, string>;
-  previousSessionId: string | null;
-  setPageState: (pageState: PageState) => void;
-  setInitialPrompt: (sessionId: string, prompt: string) => void;
-  setPreviousSessionId: (sessionId: string | null) => void;
-}
-
-export function useCopilotURLState({
-  pageState,
-  initialPrompts,
-  previousSessionId,
-  setPageState,
-  setInitialPrompt,
-  setPreviousSessionId,
-}: UseCopilotUrlStateArgs) {
-  const [urlSessionId, setUrlSessionId] = useQueryState(
-    "sessionId",
-    parseAsString,
-  );
-
-  function syncSessionFromUrl() {
-    if (urlSessionId) {
-      if (pageState.type === "chat" && pageState.sessionId === urlSessionId) {
-        setPreviousSessionId(urlSessionId);
-        return;
-      }
-
-      const storedInitialPrompt = initialPrompts[urlSessionId];
-      const currentInitialPrompt = getInitialPromptFromState(
-        pageState,
-        storedInitialPrompt,
-      );
-
-      if (currentInitialPrompt) {
-        setInitialPrompt(urlSessionId, currentInitialPrompt);
-      }
-
-      setPageState({
-        type: "chat",
-        sessionId: urlSessionId,
-        initialPrompt: currentInitialPrompt,
-      });
-      setPreviousSessionId(urlSessionId);
-      return;
-    }
-
-    const wasInChat = previousSessionId !== null && pageState.type === "chat";
-    setPreviousSessionId(null);
-    if (wasInChat) {
-      setPageState({ type: "newChat" });
-      return;
-    }
-
-    if (shouldResetToWelcome(pageState)) {
-      setPageState({ type: "welcome" });
-    }
-  }
-
-  useLayoutEffect(syncSessionFromUrl, [
-    urlSessionId,
-    pageState.type,
-    previousSessionId,
-    initialPrompts,
-  ]);
-
-  return {
-    urlSessionId,
-    setUrlSessionId,
-  };
-}
@@ -1,10 +1,11 @@
 "use server";
 
+import { getHomepageRoute } from "@/lib/constants";
 import BackendAPI from "@/lib/autogpt-server-api";
 import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
 import { loginFormSchema } from "@/types/auth";
 import * as Sentry from "@sentry/nextjs";
-import { shouldShowOnboarding } from "../../api/helpers";
+import { getOnboardingStatus } from "../../api/helpers";
 
 export async function login(email: string, password: string) {
   try {
@@ -36,11 +37,15 @@ export async function login(email: string, password: string) {
     const api = new BackendAPI();
     await api.createUser();
 
-    const onboarding = await shouldShowOnboarding();
+    // Get onboarding status from backend (includes chat flag evaluated for this user)
+    const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
+    const next = shouldShowOnboarding
+      ? "/onboarding"
+      : getHomepageRoute(isChatEnabled);
 
     return {
       success: true,
-      onboarding,
+      next,
     };
   } catch (err) {
     Sentry.captureException(err);
@@ -97,13 +97,8 @@ export function useLoginPage() {
         throw new Error(result.error || "Login failed");
       }
 
-      if (nextUrl) {
-        router.replace(nextUrl);
-      } else if (result.onboarding) {
-        router.replace("/onboarding");
-      } else {
-        router.replace(homepageRoute);
-      }
+      // Prefer URL's next parameter, then use backend-determined route
+      router.replace(nextUrl || result.next || homepageRoute);
     } catch (error) {
       toast({
         title:
Some files were not shown because too many files have changed in this diff