mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-30 09:28:19 -05:00
Compare commits
32 Commits
master
...
feat/sub-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ad8fde75d | ||
|
|
aef705007b | ||
|
|
be7e1ad9b6 | ||
|
|
ce050abff9 | ||
|
|
79eb2889ab | ||
|
|
5bc5e02dcb | ||
|
|
f83366d08d | ||
|
|
350ad3591b | ||
|
|
de0ec3d388 | ||
|
|
7cb1e588b0 | ||
|
|
16ae8ddbe0 | ||
|
|
4b04ae2147 | ||
|
|
de71d6134a | ||
|
|
e6eb8a3f57 | ||
|
|
582c6cad36 | ||
|
|
0d1d275e8d | ||
|
|
dc92a7b520 | ||
|
|
d4047b5439 | ||
|
|
f00678fd1c | ||
|
|
aa175e0f4e | ||
|
|
9a8838c69a | ||
|
|
41beae1122 | ||
|
|
e810f7b0d7 | ||
|
|
9c3822fffe | ||
|
|
c039a2e3ad | ||
|
|
3b822cdaf7 | ||
|
|
a3fe1ede55 | ||
|
|
552d069a9d | ||
|
|
b2eb4831bd | ||
|
|
4cd5da678d | ||
|
|
b94c83aacc | ||
|
|
7668c17d9c |
@@ -29,8 +29,7 @@
|
|||||||
"postCreateCmd": [
|
"postCreateCmd": [
|
||||||
"cd autogpt_platform/autogpt_libs && poetry install",
|
"cd autogpt_platform/autogpt_libs && poetry install",
|
||||||
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
||||||
"cd autogpt_platform/frontend && pnpm install",
|
"cd autogpt_platform/frontend && pnpm install"
|
||||||
"cd docs && pip install -r requirements.txt"
|
|
||||||
],
|
],
|
||||||
"terminalCommand": "code .",
|
"terminalCommand": "code .",
|
||||||
"deleteBranchWithWorktree": false
|
"deleteBranchWithWorktree": false
|
||||||
|
|||||||
6
.github/copilot-instructions.md
vendored
6
.github/copilot-instructions.md
vendored
@@ -160,7 +160,7 @@ pnpm storybook # Start component development server
|
|||||||
|
|
||||||
**Backend Entry Points:**
|
**Backend Entry Points:**
|
||||||
|
|
||||||
- `backend/backend/server/server.py` - FastAPI application setup
|
- `backend/backend/api/rest_api.py` - FastAPI application setup
|
||||||
- `backend/backend/data/` - Database models and user management
|
- `backend/backend/data/` - Database models and user management
|
||||||
- `backend/blocks/` - Agent execution blocks and logic
|
- `backend/blocks/` - Agent execution blocks and logic
|
||||||
|
|
||||||
@@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s
|
|||||||
|
|
||||||
### API Development
|
### API Development
|
||||||
|
|
||||||
1. Update routes in `/backend/backend/server/routers/`
|
1. Update routes in `/backend/backend/api/features/`
|
||||||
2. Add/update Pydantic models in same directory
|
2. Add/update Pydantic models in same directory
|
||||||
3. Write tests alongside route files
|
3. Write tests alongside route files
|
||||||
4. For `data/*.py` changes, validate user ID checks
|
4. For `data/*.py` changes, validate user ID checks
|
||||||
@@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s
|
|||||||
|
|
||||||
### Security Guidelines
|
### Security Guidelines
|
||||||
|
|
||||||
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
|
**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`):
|
||||||
|
|
||||||
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||||
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
|
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -178,4 +178,5 @@ autogpt_platform/backend/settings.py
|
|||||||
*.ign.*
|
*.ign.*
|
||||||
.test-contents
|
.test-contents
|
||||||
.claude/settings.local.json
|
.claude/settings.local.json
|
||||||
|
CLAUDE.local.md
|
||||||
/autogpt_platform/backend/logs
|
/autogpt_platform/backend/logs
|
||||||
|
|||||||
24
AGENTS.md
24
AGENTS.md
@@ -16,7 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
|
|||||||
- Format Python code with `poetry run format`.
|
- Format Python code with `poetry run format`.
|
||||||
- Format frontend code using `pnpm format`.
|
- Format frontend code using `pnpm format`.
|
||||||
|
|
||||||
|
|
||||||
## Frontend guidelines:
|
## Frontend guidelines:
|
||||||
|
|
||||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||||
@@ -33,14 +32,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
|||||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||||
|
|
||||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||||
- Use function declarations for components, arrow functions only for callbacks
|
- Use function declarations for components, arrow functions only for callbacks
|
||||||
- No barrel files or `index.ts` re-exports
|
- No barrel files or `index.ts` re-exports
|
||||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
|
||||||
- Avoid comments at all times unless the code is very complex
|
- Avoid comments at all times unless the code is very complex
|
||||||
|
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
|
||||||
|
- Do not type hook returns, let Typescript infer as much as possible
|
||||||
|
- Never type with `any`, if not types available use `unknown`
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
@@ -49,22 +51,8 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
|||||||
|
|
||||||
Always run the relevant linters and tests before committing.
|
Always run the relevant linters and tests before committing.
|
||||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||||
Types:
|
Types: - feat - fix - refactor - ci - dx (developer experience)
|
||||||
- feat
|
Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
|
||||||
- fix
|
|
||||||
- refactor
|
|
||||||
- ci
|
|
||||||
- dx (developer experience)
|
|
||||||
Scopes:
|
|
||||||
- platform
|
|
||||||
- platform/library
|
|
||||||
- platform/marketplace
|
|
||||||
- backend
|
|
||||||
- backend/executor
|
|
||||||
- frontend
|
|
||||||
- frontend/library
|
|
||||||
- frontend/marketplace
|
|
||||||
- blocks
|
|
||||||
|
|
||||||
## Pull requests
|
## Pull requests
|
||||||
|
|
||||||
|
|||||||
@@ -6,152 +6,30 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|||||||
|
|
||||||
AutoGPT Platform is a monorepo containing:
|
AutoGPT Platform is a monorepo containing:
|
||||||
|
|
||||||
- **Backend** (`/backend`): Python FastAPI server with async support
|
- **Backend** (`backend`): Python FastAPI server with async support
|
||||||
- **Frontend** (`/frontend`): Next.js React application
|
- **Frontend** (`frontend`): Next.js React application
|
||||||
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||||
|
|
||||||
## Essential Commands
|
## Component Documentation
|
||||||
|
|
||||||
### Backend Development
|
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
||||||
|
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
||||||
|
|
||||||
```bash
|
## Key Concepts
|
||||||
# Install dependencies
|
|
||||||
cd backend && poetry install
|
|
||||||
|
|
||||||
# Run database migrations
|
|
||||||
poetry run prisma migrate dev
|
|
||||||
|
|
||||||
# Start all services (database, redis, rabbitmq, clamav)
|
|
||||||
docker compose up -d
|
|
||||||
|
|
||||||
# Run the backend server
|
|
||||||
poetry run serve
|
|
||||||
|
|
||||||
# Run tests
|
|
||||||
poetry run test
|
|
||||||
|
|
||||||
# Run specific test
|
|
||||||
poetry run pytest path/to/test_file.py::test_function_name
|
|
||||||
|
|
||||||
# Run block tests (tests that validate all blocks work correctly)
|
|
||||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
|
||||||
|
|
||||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
|
||||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
|
||||||
|
|
||||||
# Lint and format
|
|
||||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
|
||||||
poetry run format # Black + isort
|
|
||||||
poetry run lint # ruff
|
|
||||||
```
|
|
||||||
|
|
||||||
More details can be found in TESTING.md
|
|
||||||
|
|
||||||
#### Creating/Updating Snapshots
|
|
||||||
|
|
||||||
When you first write a test or when the expected output changes:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
poetry run pytest path/to/test.py --snapshot-update
|
|
||||||
```
|
|
||||||
|
|
||||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
|
||||||
|
|
||||||
### Frontend Development
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Install dependencies
|
|
||||||
cd frontend && pnpm i
|
|
||||||
|
|
||||||
# Generate API client from OpenAPI spec
|
|
||||||
pnpm generate:api
|
|
||||||
|
|
||||||
# Start development server
|
|
||||||
pnpm dev
|
|
||||||
|
|
||||||
# Run E2E tests
|
|
||||||
pnpm test
|
|
||||||
|
|
||||||
# Run Storybook for component development
|
|
||||||
pnpm storybook
|
|
||||||
|
|
||||||
# Build production
|
|
||||||
pnpm build
|
|
||||||
|
|
||||||
# Format and lint
|
|
||||||
pnpm format
|
|
||||||
|
|
||||||
# Type checking
|
|
||||||
pnpm types
|
|
||||||
```
|
|
||||||
|
|
||||||
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
|
|
||||||
|
|
||||||
**Key Frontend Conventions:**
|
|
||||||
|
|
||||||
- Separate render logic from data/behavior in components
|
|
||||||
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
|
||||||
- Use function declarations (not arrow functions) for components/handlers
|
|
||||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
|
||||||
- Only use Phosphor Icons
|
|
||||||
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
|
|
||||||
|
|
||||||
## Architecture Overview
|
|
||||||
|
|
||||||
### Backend Architecture
|
|
||||||
|
|
||||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
|
||||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
|
||||||
- **Queue System**: RabbitMQ for async task processing
|
|
||||||
- **Execution Engine**: Separate executor service processes agent workflows
|
|
||||||
- **Authentication**: JWT-based with Supabase integration
|
|
||||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
|
||||||
|
|
||||||
### Frontend Architecture
|
|
||||||
|
|
||||||
- **Framework**: Next.js 15 App Router (client-first approach)
|
|
||||||
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
|
||||||
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
|
||||||
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
|
||||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
|
||||||
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
|
||||||
- **Icons**: Phosphor Icons only
|
|
||||||
- **Feature Flags**: LaunchDarkly integration
|
|
||||||
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
|
||||||
- **Testing**: Playwright for E2E, Storybook for component development
|
|
||||||
|
|
||||||
### Key Concepts
|
|
||||||
|
|
||||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||||
3. **Integrations**: OAuth and API connections stored per user
|
3. **Integrations**: OAuth and API connections stored per user
|
||||||
4. **Store**: Marketplace for sharing agent templates
|
4. **Store**: Marketplace for sharing agent templates
|
||||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||||
|
|
||||||
### Testing Approach
|
|
||||||
|
|
||||||
- Backend uses pytest with snapshot testing for API responses
|
|
||||||
- Test files are colocated with source files (`*_test.py`)
|
|
||||||
- Frontend uses Playwright for E2E tests
|
|
||||||
- Component testing via Storybook
|
|
||||||
|
|
||||||
### Database Schema
|
|
||||||
|
|
||||||
Key models (defined in `/backend/schema.prisma`):
|
|
||||||
|
|
||||||
- `User`: Authentication and profile data
|
|
||||||
- `AgentGraph`: Workflow definitions with version control
|
|
||||||
- `AgentGraphExecution`: Execution history and results
|
|
||||||
- `AgentNode`: Individual nodes in a workflow
|
|
||||||
- `StoreListing`: Marketplace listings for sharing agents
|
|
||||||
|
|
||||||
### Environment Configuration
|
### Environment Configuration
|
||||||
|
|
||||||
#### Configuration Files
|
#### Configuration Files
|
||||||
|
|
||||||
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||||
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||||
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||||
|
|
||||||
#### Docker Environment Loading Order
|
#### Docker Environment Loading Order
|
||||||
|
|
||||||
@@ -167,83 +45,12 @@ Key models (defined in `/backend/schema.prisma`):
|
|||||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||||
|
|
||||||
### Common Development Tasks
|
|
||||||
|
|
||||||
**Adding a new block:**
|
|
||||||
|
|
||||||
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
|
|
||||||
|
|
||||||
- Provider configuration with `ProviderBuilder`
|
|
||||||
- Block schema definition
|
|
||||||
- Authentication (API keys, OAuth, webhooks)
|
|
||||||
- Testing and validation
|
|
||||||
- File organization
|
|
||||||
|
|
||||||
Quick steps:
|
|
||||||
|
|
||||||
1. Create new file in `/backend/backend/blocks/`
|
|
||||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
|
||||||
3. Inherit from `Block` base class
|
|
||||||
4. Define input/output schemas using `BlockSchema`
|
|
||||||
5. Implement async `run` method
|
|
||||||
6. Generate unique block ID using `uuid.uuid4()`
|
|
||||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
|
||||||
|
|
||||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
|
||||||
ex: do the inputs and outputs tie well together?
|
|
||||||
|
|
||||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
|
||||||
|
|
||||||
**Modifying the API:**
|
|
||||||
|
|
||||||
1. Update route in `/backend/backend/server/routers/`
|
|
||||||
2. Add/update Pydantic models in same directory
|
|
||||||
3. Write tests alongside the route file
|
|
||||||
4. Run `poetry run test` to verify
|
|
||||||
|
|
||||||
### Frontend guidelines:
|
|
||||||
|
|
||||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
|
||||||
|
|
||||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
|
||||||
- Add `usePageName.ts` hook for logic
|
|
||||||
- Put sub-components in local `components/` folder
|
|
||||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
|
||||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
|
||||||
- Never use `src/components/__legacy__/*`
|
|
||||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
|
||||||
- Regenerate with `pnpm generate:api`
|
|
||||||
- Pattern: `use{Method}{Version}{OperationName}`
|
|
||||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
|
||||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
|
||||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
|
||||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
|
||||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
|
||||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
|
||||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
|
||||||
- Use function declarations for components, arrow functions only for callbacks
|
|
||||||
- No barrel files or `index.ts` re-exports
|
|
||||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
|
||||||
- Avoid comments at all times unless the code is very complex
|
|
||||||
|
|
||||||
### Security Implementation
|
|
||||||
|
|
||||||
**Cache Protection Middleware:**
|
|
||||||
|
|
||||||
- Located in `/backend/backend/server/middleware/security.py`
|
|
||||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
|
||||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
|
||||||
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
|
||||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
|
||||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
|
||||||
- Applied to both main API server and external API applications
|
|
||||||
|
|
||||||
### Creating Pull Requests
|
### Creating Pull Requests
|
||||||
|
|
||||||
- Create the PR aginst the `dev` branch of the repository.
|
- Create the PR against the `dev` branch of the repository.
|
||||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
|
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||||
- Use conventional commit messages (see below)/
|
- Use conventional commit messages (see below)
|
||||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
|
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||||
- Run the github pre-commit hooks to ensure code quality.
|
- Run the github pre-commit hooks to ensure code quality.
|
||||||
|
|
||||||
### Reviewing/Revising Pull Requests
|
### Reviewing/Revising Pull Requests
|
||||||
|
|||||||
170
autogpt_platform/backend/CLAUDE.md
Normal file
170
autogpt_platform/backend/CLAUDE.md
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
# CLAUDE.md - Backend
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code when working with the backend.
|
||||||
|
|
||||||
|
## Essential Commands
|
||||||
|
|
||||||
|
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies
|
||||||
|
poetry install
|
||||||
|
|
||||||
|
# Run database migrations
|
||||||
|
poetry run prisma migrate dev
|
||||||
|
|
||||||
|
# Start all services (database, redis, rabbitmq, clamav)
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Run the backend as a whole
|
||||||
|
poetry run app
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
poetry run test
|
||||||
|
|
||||||
|
# Run specific test
|
||||||
|
poetry run pytest path/to/test_file.py::test_function_name
|
||||||
|
|
||||||
|
# Run block tests (tests that validate all blocks work correctly)
|
||||||
|
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||||
|
|
||||||
|
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||||
|
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||||
|
|
||||||
|
# Lint and format
|
||||||
|
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||||
|
poetry run format # Black + isort
|
||||||
|
poetry run lint # ruff
|
||||||
|
```
|
||||||
|
|
||||||
|
More details can be found in @TESTING.md
|
||||||
|
|
||||||
|
### Creating/Updating Snapshots
|
||||||
|
|
||||||
|
When you first write a test or when the expected output changes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry run pytest path/to/test.py --snapshot-update
|
||||||
|
```
|
||||||
|
|
||||||
|
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||||
|
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||||
|
- **Queue System**: RabbitMQ for async task processing
|
||||||
|
- **Execution Engine**: Separate executor service processes agent workflows
|
||||||
|
- **Authentication**: JWT-based with Supabase integration
|
||||||
|
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||||
|
|
||||||
|
## Testing Approach
|
||||||
|
|
||||||
|
- Uses pytest with snapshot testing for API responses
|
||||||
|
- Test files are colocated with source files (`*_test.py`)
|
||||||
|
|
||||||
|
## Database Schema
|
||||||
|
|
||||||
|
Key models (defined in `schema.prisma`):
|
||||||
|
|
||||||
|
- `User`: Authentication and profile data
|
||||||
|
- `AgentGraph`: Workflow definitions with version control
|
||||||
|
- `AgentGraphExecution`: Execution history and results
|
||||||
|
- `AgentNode`: Individual nodes in a workflow
|
||||||
|
- `StoreListing`: Marketplace listings for sharing agents
|
||||||
|
|
||||||
|
## Environment Configuration
|
||||||
|
|
||||||
|
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||||
|
|
||||||
|
## Common Development Tasks
|
||||||
|
|
||||||
|
### Adding a new block
|
||||||
|
|
||||||
|
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||||
|
|
||||||
|
- Provider configuration with `ProviderBuilder`
|
||||||
|
- Block schema definition
|
||||||
|
- Authentication (API keys, OAuth, webhooks)
|
||||||
|
- Testing and validation
|
||||||
|
- File organization
|
||||||
|
|
||||||
|
Quick steps:
|
||||||
|
|
||||||
|
1. Create new file in `backend/blocks/`
|
||||||
|
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||||
|
3. Inherit from `Block` base class
|
||||||
|
4. Define input/output schemas using `BlockSchema`
|
||||||
|
5. Implement async `run` method
|
||||||
|
6. Generate unique block ID using `uuid.uuid4()`
|
||||||
|
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||||
|
|
||||||
|
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||||
|
ex: do the inputs and outputs tie well together?
|
||||||
|
|
||||||
|
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||||
|
|
||||||
|
#### Handling files in blocks with `store_media_file()`
|
||||||
|
|
||||||
|
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||||
|
|
||||||
|
| Format | Use When | Returns |
|
||||||
|
|--------|----------|---------|
|
||||||
|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||||
|
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||||
|
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# INPUT: Need to process file locally with ffmpeg
|
||||||
|
local_path = await store_media_file(
|
||||||
|
file=input_data.video,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||||
|
|
||||||
|
# INPUT: Need to send to external API like Replicate
|
||||||
|
image_b64 = await store_media_file(
|
||||||
|
file=input_data.image,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_external_api",
|
||||||
|
)
|
||||||
|
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||||
|
|
||||||
|
# OUTPUT: Returning result from block
|
||||||
|
result_url = await store_media_file(
|
||||||
|
file=generated_image_url,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", result_url
|
||||||
|
# In CoPilot: result_url = "workspace://abc123"
|
||||||
|
# In graphs: result_url = "data:image/png;base64,..."
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key points:**
|
||||||
|
|
||||||
|
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||||
|
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||||
|
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||||
|
|
||||||
|
### Modifying the API
|
||||||
|
|
||||||
|
1. Update route in `backend/api/features/`
|
||||||
|
2. Add/update Pydantic models in same directory
|
||||||
|
3. Write tests alongside the route file
|
||||||
|
4. Run `poetry run test` to verify
|
||||||
|
|
||||||
|
## Security Implementation
|
||||||
|
|
||||||
|
### Cache Protection Middleware
|
||||||
|
|
||||||
|
- Located in `backend/api/middleware/security.py`
|
||||||
|
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||||
|
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||||
|
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||||
|
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||||
|
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||||
|
- Applied to both main API server and external API applications
|
||||||
@@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as
|
|||||||
|
|
||||||
#### Using Global Auth Fixtures
|
#### Using Global Auth Fixtures
|
||||||
|
|
||||||
Two global auth fixtures are provided by `backend/server/conftest.py`:
|
Two global auth fixtures are provided by `backend/api/conftest.py`:
|
||||||
|
|
||||||
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
||||||
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ router = fastapi.APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Taken from backend/server/v2/store/db.py
|
# Taken from backend/api/features/store/db.py
|
||||||
def sanitize_query(query: str | None) -> str | None:
|
def sanitize_query(query: str | None) -> str | None:
|
||||||
if query is None:
|
if query is None:
|
||||||
return query
|
return query
|
||||||
|
|||||||
@@ -1834,6 +1834,11 @@ async def _execute_long_running_tool(
|
|||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
result=error_response.model_dump_json(),
|
result=error_response.model_dump_json(),
|
||||||
)
|
)
|
||||||
|
# Generate LLM continuation so user sees explanation even for errors
|
||||||
|
try:
|
||||||
|
await _generate_llm_continuation(session_id=session_id, user_id=user_id)
|
||||||
|
except Exception as llm_err:
|
||||||
|
logger.warning(f"Failed to generate LLM continuation for error: {llm_err}")
|
||||||
finally:
|
finally:
|
||||||
await _mark_operation_completed(tool_call_id)
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,79 @@
|
|||||||
|
# CoPilot Tools - Future Ideas
|
||||||
|
|
||||||
|
## Multimodal Image Support for CoPilot
|
||||||
|
|
||||||
|
**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).
|
||||||
|
|
||||||
|
**Backend Solution:**
|
||||||
|
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Before sending to LLM, scan for workspace image references
|
||||||
|
# and inject them as image content parts
|
||||||
|
|
||||||
|
# Example message transformation:
|
||||||
|
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
|
||||||
|
# TO: {"role": "assistant", "content": [
|
||||||
|
# {"type": "text", "text": "Generated image: workspace://abc123"},
|
||||||
|
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
|
||||||
|
# ]}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Where to implement:**
|
||||||
|
- In the chat stream handler before calling the LLM
|
||||||
|
- Or in a message preprocessing step
|
||||||
|
- Need to fetch image from workspace, convert to base64, add as image content
|
||||||
|
|
||||||
|
**Considerations:**
|
||||||
|
- Only do this for image MIME types (image/png, image/jpeg, etc.)
|
||||||
|
- May want a size limit (don't pass 10MB images)
|
||||||
|
- Track which images were "shown" to the AI for frontend indicator
|
||||||
|
- Cost implications - vision API calls are more expensive
|
||||||
|
|
||||||
|
**Frontend Solution:**
|
||||||
|
Show visual indicator on workspace files in chat:
|
||||||
|
- If AI saw the image: normal display
|
||||||
|
- If AI didn't see it: overlay icon saying "AI can't see this image"
|
||||||
|
|
||||||
|
Requires response metadata indicating which `workspace://` refs were passed to the model.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Output Post-Processing Layer for run_block
|
||||||
|
|
||||||
|
**Problem:** Many blocks produce large outputs that:
|
||||||
|
- Consume massive context (100KB base64 image = ~133KB tokens)
|
||||||
|
- Can't fit in conversation
|
||||||
|
- Break things and cause high LLM costs
|
||||||
|
|
||||||
|
**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.
|
||||||
|
|
||||||
|
**Benefits:**
|
||||||
|
1. **Centralized** - one place to handle all output processing
|
||||||
|
2. **Future-proof** - new blocks automatically get output processing
|
||||||
|
3. **Keeps blocks pure** - they don't need to know about context constraints
|
||||||
|
4. **Handles all large outputs** - not just images
|
||||||
|
|
||||||
|
**Processing Rules:**
|
||||||
|
- Detect base64 data URIs → save to workspace, return `workspace://` reference
|
||||||
|
- Truncate very long strings (>N chars) with truncation note
|
||||||
|
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
|
||||||
|
- Handle nested large outputs in dicts recursively
|
||||||
|
- Cap total output size
|
||||||
|
|
||||||
|
**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```python
|
||||||
|
def _process_outputs_for_context(
|
||||||
|
outputs: dict[str, list[Any]],
|
||||||
|
workspace_manager: WorkspaceManager,
|
||||||
|
max_string_length: int = 10000,
|
||||||
|
max_array_preview: int = 5,
|
||||||
|
) -> dict[str, list[Any]]:
|
||||||
|
"""Process block outputs to prevent context bloat."""
|
||||||
|
processed = {}
|
||||||
|
for name, values in outputs.items():
|
||||||
|
processed[name] = [_process_value(v, workspace_manager) for v in values]
|
||||||
|
return processed
|
||||||
|
```
|
||||||
@@ -18,6 +18,12 @@ from .get_doc_page import GetDocPageTool
|
|||||||
from .run_agent import RunAgentTool
|
from .run_agent import RunAgentTool
|
||||||
from .run_block import RunBlockTool
|
from .run_block import RunBlockTool
|
||||||
from .search_docs import SearchDocsTool
|
from .search_docs import SearchDocsTool
|
||||||
|
from .workspace_files import (
|
||||||
|
DeleteWorkspaceFileTool,
|
||||||
|
ListWorkspaceFilesTool,
|
||||||
|
ReadWorkspaceFileTool,
|
||||||
|
WriteWorkspaceFileTool,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||||
@@ -37,6 +43,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
|||||||
"view_agent_output": AgentOutputTool(),
|
"view_agent_output": AgentOutputTool(),
|
||||||
"search_docs": SearchDocsTool(),
|
"search_docs": SearchDocsTool(),
|
||||||
"get_doc_page": GetDocPageTool(),
|
"get_doc_page": GetDocPageTool(),
|
||||||
|
# Workspace tools for CoPilot file operations
|
||||||
|
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||||
|
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||||
|
"write_workspace_file": WriteWorkspaceFileTool(),
|
||||||
|
"delete_workspace_file": DeleteWorkspaceFileTool(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Export individual tool instances for backwards compatibility
|
# Export individual tool instances for backwards compatibility
|
||||||
|
|||||||
@@ -2,27 +2,52 @@
|
|||||||
|
|
||||||
from .core import (
|
from .core import (
|
||||||
AgentGeneratorNotConfiguredError,
|
AgentGeneratorNotConfiguredError,
|
||||||
|
AgentSummary,
|
||||||
|
DecompositionResult,
|
||||||
|
DecompositionStep,
|
||||||
|
LibraryAgentSummary,
|
||||||
|
MarketplaceAgentSummary,
|
||||||
decompose_goal,
|
decompose_goal,
|
||||||
|
enrich_library_agents_from_steps,
|
||||||
|
extract_search_terms_from_steps,
|
||||||
|
extract_uuids_from_text,
|
||||||
generate_agent,
|
generate_agent,
|
||||||
generate_agent_patch,
|
generate_agent_patch,
|
||||||
get_agent_as_json,
|
get_agent_as_json,
|
||||||
|
get_all_relevant_agents_for_generation,
|
||||||
|
get_library_agent_by_graph_id,
|
||||||
|
get_library_agent_by_id,
|
||||||
|
get_library_agents_for_generation,
|
||||||
json_to_graph,
|
json_to_graph,
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
|
search_marketplace_agents_for_generation,
|
||||||
)
|
)
|
||||||
|
from .errors import get_user_message_for_error
|
||||||
from .service import health_check as check_external_service_health
|
from .service import health_check as check_external_service_health
|
||||||
from .service import is_external_service_configured
|
from .service import is_external_service_configured
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Core functions
|
"AgentGeneratorNotConfiguredError",
|
||||||
|
"AgentSummary",
|
||||||
|
"DecompositionResult",
|
||||||
|
"DecompositionStep",
|
||||||
|
"LibraryAgentSummary",
|
||||||
|
"MarketplaceAgentSummary",
|
||||||
|
"check_external_service_health",
|
||||||
"decompose_goal",
|
"decompose_goal",
|
||||||
|
"enrich_library_agents_from_steps",
|
||||||
|
"extract_search_terms_from_steps",
|
||||||
|
"extract_uuids_from_text",
|
||||||
"generate_agent",
|
"generate_agent",
|
||||||
"generate_agent_patch",
|
"generate_agent_patch",
|
||||||
"save_agent_to_library",
|
|
||||||
"get_agent_as_json",
|
"get_agent_as_json",
|
||||||
"json_to_graph",
|
"get_all_relevant_agents_for_generation",
|
||||||
# Exceptions
|
"get_library_agent_by_graph_id",
|
||||||
"AgentGeneratorNotConfiguredError",
|
"get_library_agent_by_id",
|
||||||
# Service
|
"get_library_agents_for_generation",
|
||||||
|
"get_user_message_for_error",
|
||||||
"is_external_service_configured",
|
"is_external_service_configured",
|
||||||
"check_external_service_health",
|
"json_to_graph",
|
||||||
|
"save_agent_to_library",
|
||||||
|
"search_marketplace_agents_for_generation",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,11 +1,21 @@
|
|||||||
"""Core agent generation functions."""
|
"""Core agent generation functions."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.data.graph import Graph, Link, Node, create_graph
|
from backend.api.features.store import db as store_db
|
||||||
|
from backend.data.graph import (
|
||||||
|
Graph,
|
||||||
|
Link,
|
||||||
|
Node,
|
||||||
|
create_graph,
|
||||||
|
get_graph,
|
||||||
|
get_graph_all_versions,
|
||||||
|
)
|
||||||
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .service import (
|
||||||
decompose_goal_external,
|
decompose_goal_external,
|
||||||
@@ -17,6 +27,60 @@ from .service import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LibraryAgentSummary(TypedDict):
|
||||||
|
"""Summary of a library agent for sub-agent composition."""
|
||||||
|
|
||||||
|
graph_id: str
|
||||||
|
graph_version: int
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
input_schema: dict[str, Any]
|
||||||
|
output_schema: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class MarketplaceAgentSummary(TypedDict):
|
||||||
|
"""Summary of a marketplace agent for sub-agent composition."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
sub_heading: str
|
||||||
|
creator: str
|
||||||
|
is_marketplace_agent: bool
|
||||||
|
|
||||||
|
|
||||||
|
class DecompositionStep(TypedDict, total=False):
|
||||||
|
"""A single step in decomposed instructions."""
|
||||||
|
|
||||||
|
description: str
|
||||||
|
action: str
|
||||||
|
block_name: str
|
||||||
|
tool: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class DecompositionResult(TypedDict, total=False):
|
||||||
|
"""Result from decompose_goal - can be instructions, questions, or error."""
|
||||||
|
|
||||||
|
type: str # "instructions", "clarifying_questions", "error", etc.
|
||||||
|
steps: list[DecompositionStep]
|
||||||
|
questions: list[dict[str, Any]]
|
||||||
|
error: str
|
||||||
|
error_type: str
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for agent summaries (can be either library or marketplace)
|
||||||
|
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def _to_dict_list(
|
||||||
|
agents: list[AgentSummary] | list[dict[str, Any]] | None,
|
||||||
|
) -> list[dict[str, Any]] | None:
|
||||||
|
"""Convert typed agent summaries to plain dicts for external service calls."""
|
||||||
|
if agents is None:
|
||||||
|
return None
|
||||||
|
return [dict(a) for a in agents]
|
||||||
|
|
||||||
|
|
||||||
class AgentGeneratorNotConfiguredError(Exception):
|
class AgentGeneratorNotConfiguredError(Exception):
|
||||||
"""Raised when the external Agent Generator service is not configured."""
|
"""Raised when the external Agent Generator service is not configured."""
|
||||||
|
|
||||||
@@ -36,15 +100,394 @@ def _check_service_configured() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
_UUID_PATTERN = re.compile(
|
||||||
|
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_uuids_from_text(text: str) -> list[str]:
|
||||||
|
"""Extract all UUID v4 strings from text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text that may contain UUIDs (e.g., user's goal description)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of unique UUIDs found in the text (lowercase)
|
||||||
|
"""
|
||||||
|
matches = _UUID_PATTERN.findall(text)
|
||||||
|
return list({m.lower() for m in matches})
|
||||||
|
|
||||||
|
|
||||||
|
async def get_library_agent_by_id(
|
||||||
|
user_id: str, agent_id: str
|
||||||
|
) -> LibraryAgentSummary | None:
|
||||||
|
"""Fetch a specific library agent by its ID (library agent ID or graph_id).
|
||||||
|
|
||||||
|
This function tries multiple lookup strategies:
|
||||||
|
1. First tries to find by graph_id (AgentGraph primary key)
|
||||||
|
2. If not found, tries to find by library agent ID (LibraryAgent primary key)
|
||||||
|
|
||||||
|
This handles both cases:
|
||||||
|
- User provides graph_id (e.g., from AgentExecutorBlock)
|
||||||
|
- User provides library agent ID (e.g., from library URL)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LibraryAgentSummary if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||||
|
return LibraryAgentSummary(
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
graph_version=agent.graph_version,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=agent.input_schema,
|
||||||
|
output_schema=agent.output_schema,
|
||||||
|
)
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||||
|
return LibraryAgentSummary(
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
graph_version=agent.graph_version,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=agent.input_schema,
|
||||||
|
output_schema=agent.output_schema,
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for backward compatibility
|
||||||
|
get_library_agent_by_graph_id = get_library_agent_by_id
|
||||||
|
|
||||||
|
|
||||||
|
async def get_library_agents_for_generation(
|
||||||
|
user_id: str,
|
||||||
|
search_query: str | None = None,
|
||||||
|
exclude_graph_id: str | None = None,
|
||||||
|
max_results: int = 15,
|
||||||
|
) -> list[LibraryAgentSummary]:
|
||||||
|
"""Fetch user's library agents formatted for Agent Generator.
|
||||||
|
|
||||||
|
Uses search-based fetching to return relevant agents instead of all agents.
|
||||||
|
This is more scalable for users with large libraries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
search_query: Optional search term to find relevant agents (user's goal/description)
|
||||||
|
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||||
|
max_results: Maximum number of agents to return (default 15)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LibraryAgentSummary with schemas for sub-agent composition
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Future enhancement: Add quality filtering based on execution success rate
|
||||||
|
or correctness_score from AgentGraphExecution stats. The current
|
||||||
|
LibraryAgentStatus.ERROR is too aggressive (1 failed run = ERROR).
|
||||||
|
Better approach: filter by success rate (e.g., >50% successful runs)
|
||||||
|
or require at least 1 successful execution.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await library_db.list_library_agents(
|
||||||
|
user_id=user_id,
|
||||||
|
search_term=search_query,
|
||||||
|
page=1,
|
||||||
|
page_size=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[LibraryAgentSummary] = []
|
||||||
|
for agent in response.agents:
|
||||||
|
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
LibraryAgentSummary(
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
graph_version=agent.graph_version,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=agent.input_schema,
|
||||||
|
output_schema=agent.output_schema,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch library agents: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def search_marketplace_agents_for_generation(
|
||||||
|
search_query: str,
|
||||||
|
max_results: int = 10,
|
||||||
|
) -> list[MarketplaceAgentSummary]:
|
||||||
|
"""Search marketplace agents formatted for Agent Generator.
|
||||||
|
|
||||||
|
Note: This returns basic agent info. Full input/output schemas would require
|
||||||
|
additional graph fetches and is a potential future enhancement.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_query: Search term to find relevant public agents
|
||||||
|
max_results: Maximum number of agents to return (default 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of MarketplaceAgentSummary (without detailed schemas for now)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await store_db.get_store_agents(
|
||||||
|
search_query=search_query,
|
||||||
|
page=1,
|
||||||
|
page_size=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[MarketplaceAgentSummary] = []
|
||||||
|
for agent in response.agents:
|
||||||
|
results.append(
|
||||||
|
MarketplaceAgentSummary(
|
||||||
|
name=agent.agent_name,
|
||||||
|
description=agent.description,
|
||||||
|
sub_heading=agent.sub_heading,
|
||||||
|
creator=agent.creator,
|
||||||
|
is_marketplace_agent=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to search marketplace agents: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all_relevant_agents_for_generation(
|
||||||
|
user_id: str,
|
||||||
|
search_query: str | None = None,
|
||||||
|
exclude_graph_id: str | None = None,
|
||||||
|
include_library: bool = True,
|
||||||
|
include_marketplace: bool = True,
|
||||||
|
max_library_results: int = 15,
|
||||||
|
max_marketplace_results: int = 10,
|
||||||
|
) -> list[AgentSummary]:
|
||||||
|
"""Fetch relevant agents from library and/or marketplace.
|
||||||
|
|
||||||
|
Searches both user's library and marketplace by default.
|
||||||
|
Explicitly mentioned UUIDs in the search query are always looked up.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
search_query: Search term to find relevant agents (user's goal/description)
|
||||||
|
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||||
|
include_library: Whether to search user's library (default True)
|
||||||
|
include_marketplace: Whether to also search marketplace (default True)
|
||||||
|
max_library_results: Max library agents to return (default 15)
|
||||||
|
max_marketplace_results: Max marketplace agents to return (default 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of AgentSummary, library agents first (with full schemas),
|
||||||
|
then marketplace agents (basic info only)
|
||||||
|
"""
|
||||||
|
agents: list[AgentSummary] = []
|
||||||
|
seen_graph_ids: set[str] = set()
|
||||||
|
|
||||||
|
if search_query:
|
||||||
|
mentioned_uuids = extract_uuids_from_text(search_query)
|
||||||
|
for graph_id in mentioned_uuids:
|
||||||
|
if graph_id == exclude_graph_id:
|
||||||
|
continue
|
||||||
|
agent = await get_library_agent_by_graph_id(user_id, graph_id)
|
||||||
|
if agent and agent["graph_id"] not in seen_graph_ids:
|
||||||
|
agents.append(agent)
|
||||||
|
seen_graph_ids.add(agent["graph_id"])
|
||||||
|
logger.debug(f"Found explicitly mentioned agent: {agent['name']}")
|
||||||
|
|
||||||
|
if include_library:
|
||||||
|
library_agents = await get_library_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=search_query,
|
||||||
|
exclude_graph_id=exclude_graph_id,
|
||||||
|
max_results=max_library_results,
|
||||||
|
)
|
||||||
|
for agent in library_agents:
|
||||||
|
if agent["graph_id"] not in seen_graph_ids:
|
||||||
|
agents.append(agent)
|
||||||
|
seen_graph_ids.add(agent["graph_id"])
|
||||||
|
|
||||||
|
if include_marketplace and search_query:
|
||||||
|
marketplace_agents = await search_marketplace_agents_for_generation(
|
||||||
|
search_query=search_query,
|
||||||
|
max_results=max_marketplace_results,
|
||||||
|
)
|
||||||
|
library_names = {a["name"].lower() for a in agents if a.get("name")}
|
||||||
|
for agent in marketplace_agents:
|
||||||
|
agent_name = agent.get("name")
|
||||||
|
if agent_name and agent_name.lower() not in library_names:
|
||||||
|
agents.append(agent)
|
||||||
|
|
||||||
|
return agents
|
||||||
|
|
||||||
|
|
||||||
|
def extract_search_terms_from_steps(
|
||||||
|
decomposition_result: DecompositionResult | dict[str, Any],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Extract search terms from decomposed instruction steps.
|
||||||
|
|
||||||
|
Analyzes the decomposition result to extract relevant keywords
|
||||||
|
for additional library agent searches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decomposition_result: Result from decompose_goal containing steps
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of unique search terms extracted from steps
|
||||||
|
"""
|
||||||
|
search_terms: list[str] = []
|
||||||
|
|
||||||
|
if decomposition_result.get("type") != "instructions":
|
||||||
|
return search_terms
|
||||||
|
|
||||||
|
steps = decomposition_result.get("steps", [])
|
||||||
|
if not steps:
|
||||||
|
return search_terms
|
||||||
|
|
||||||
|
step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
|
||||||
|
|
||||||
|
for step in steps:
|
||||||
|
for key in step_keys:
|
||||||
|
value = step.get(key) # type: ignore[union-attr]
|
||||||
|
if isinstance(value, str) and len(value) > 3:
|
||||||
|
search_terms.append(value)
|
||||||
|
|
||||||
|
seen: set[str] = set()
|
||||||
|
unique_terms: list[str] = []
|
||||||
|
for term in search_terms:
|
||||||
|
term_lower = term.lower()
|
||||||
|
if term_lower not in seen:
|
||||||
|
seen.add(term_lower)
|
||||||
|
unique_terms.append(term)
|
||||||
|
|
||||||
|
return unique_terms
|
||||||
|
|
||||||
|
|
||||||
|
async def enrich_library_agents_from_steps(
|
||||||
|
user_id: str,
|
||||||
|
decomposition_result: DecompositionResult | dict[str, Any],
|
||||||
|
existing_agents: list[AgentSummary] | list[dict[str, Any]],
|
||||||
|
exclude_graph_id: str | None = None,
|
||||||
|
include_marketplace: bool = True,
|
||||||
|
max_additional_results: int = 10,
|
||||||
|
) -> list[AgentSummary] | list[dict[str, Any]]:
|
||||||
|
"""Enrich library agents list with additional searches based on decomposed steps.
|
||||||
|
|
||||||
|
This implements two-phase search: after decomposition, we search for additional
|
||||||
|
relevant agents based on the specific steps identified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
decomposition_result: Result from decompose_goal containing steps
|
||||||
|
existing_agents: Already fetched library agents from initial search
|
||||||
|
exclude_graph_id: Optional graph ID to exclude
|
||||||
|
include_marketplace: Whether to also search marketplace
|
||||||
|
max_additional_results: Max additional agents per search term (default 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined list of library agents (existing + newly discovered)
|
||||||
|
"""
|
||||||
|
search_terms = extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
if not search_terms:
|
||||||
|
return existing_agents
|
||||||
|
|
||||||
|
existing_ids: set[str] = set()
|
||||||
|
existing_names: set[str] = set()
|
||||||
|
|
||||||
|
for agent in existing_agents:
|
||||||
|
agent_name = agent.get("name", "")
|
||||||
|
if agent_name:
|
||||||
|
existing_names.add(agent_name.lower())
|
||||||
|
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||||
|
if graph_id:
|
||||||
|
existing_ids.add(graph_id)
|
||||||
|
|
||||||
|
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
|
||||||
|
|
||||||
|
for term in search_terms[:3]:
|
||||||
|
try:
|
||||||
|
additional_agents = await get_all_relevant_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=term,
|
||||||
|
exclude_graph_id=exclude_graph_id,
|
||||||
|
include_marketplace=include_marketplace,
|
||||||
|
max_library_results=max_additional_results,
|
||||||
|
max_marketplace_results=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
for agent in additional_agents:
|
||||||
|
agent_name = agent.get("name", "")
|
||||||
|
if not agent_name:
|
||||||
|
continue
|
||||||
|
agent_name_lower = agent_name.lower()
|
||||||
|
|
||||||
|
if agent_name_lower in existing_names:
|
||||||
|
continue
|
||||||
|
|
||||||
|
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||||
|
if graph_id and graph_id in existing_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
all_agents.append(agent)
|
||||||
|
existing_names.add(agent_name_lower)
|
||||||
|
if graph_id:
|
||||||
|
existing_ids.add(graph_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to search for additional agents with term '{term}': {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Enriched library agents: {len(existing_agents)} initial + "
|
||||||
|
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_agents
|
||||||
|
|
||||||
|
|
||||||
|
async def decompose_goal(
|
||||||
|
description: str,
|
||||||
|
context: str = "",
|
||||||
|
library_agents: list[AgentSummary] | None = None,
|
||||||
|
) -> DecompositionResult | None:
|
||||||
"""Break down a goal into steps or return clarifying questions.
|
"""Break down a goal into steps or return clarifying questions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Natural language goal description
|
description: Natural language goal description
|
||||||
context: Additional context (e.g., answers to previous questions)
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with either:
|
DecompositionResult with either:
|
||||||
- {"type": "clarifying_questions", "questions": [...]}
|
- {"type": "clarifying_questions", "questions": [...]}
|
||||||
- {"type": "instructions", "steps": [...]}
|
- {"type": "instructions", "steps": [...]}
|
||||||
Or None on error
|
Or None on error
|
||||||
@@ -54,26 +497,41 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
|
|||||||
"""
|
"""
|
||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||||
return await decompose_goal_external(description, context)
|
# Convert typed dicts to plain dicts for external service
|
||||||
|
result = await decompose_goal_external(
|
||||||
|
description, context, _to_dict_list(library_agents)
|
||||||
|
)
|
||||||
|
# Cast the result to DecompositionResult (external service returns dict)
|
||||||
|
return result # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
async def generate_agent(
|
||||||
|
instructions: DecompositionResult | dict[str, Any],
|
||||||
|
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
"""Generate agent JSON from instructions.
|
"""Generate agent JSON from instructions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: Structured instructions from decompose_goal
|
instructions: Structured instructions from decompose_goal
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict or None on error
|
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for generate_agent")
|
logger.info("Calling external Agent Generator service for generate_agent")
|
||||||
result = await generate_agent_external(instructions)
|
# Convert typed dicts to plain dicts for external service
|
||||||
|
result = await generate_agent_external(
|
||||||
|
dict(instructions), _to_dict_list(library_agents)
|
||||||
|
)
|
||||||
if result:
|
if result:
|
||||||
# Ensure required fields
|
# Check if it's an error response - pass through as-is
|
||||||
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
|
return result
|
||||||
|
# Ensure required fields for successful agent generation
|
||||||
if "id" not in result:
|
if "id" not in result:
|
||||||
result["id"] = str(uuid.uuid4())
|
result["id"] = str(uuid.uuid4())
|
||||||
if "version" not in result:
|
if "version" not in result:
|
||||||
@@ -159,8 +617,6 @@ async def save_agent_to_library(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (created Graph, LibraryAgent)
|
Tuple of (created Graph, LibraryAgent)
|
||||||
"""
|
"""
|
||||||
from backend.data.graph import get_graph_all_versions
|
|
||||||
|
|
||||||
graph = json_to_graph(agent_json)
|
graph = json_to_graph(agent_json)
|
||||||
|
|
||||||
if is_update:
|
if is_update:
|
||||||
@@ -197,25 +653,31 @@ async def save_agent_to_library(
|
|||||||
|
|
||||||
|
|
||||||
async def get_agent_as_json(
|
async def get_agent_as_json(
|
||||||
graph_id: str, user_id: str | None
|
agent_id: str, user_id: str | None
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Fetch an agent and convert to JSON format for editing.
|
"""Fetch an agent and convert to JSON format for editing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: Graph ID or library agent ID
|
agent_id: Graph ID or library agent ID
|
||||||
user_id: User ID
|
user_id: User ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent as JSON dict or None if not found
|
Agent as JSON dict or None if not found
|
||||||
"""
|
"""
|
||||||
from backend.data.graph import get_graph
|
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||||
|
|
||||||
|
if not graph and user_id:
|
||||||
|
try:
|
||||||
|
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
|
graph = await get_graph(
|
||||||
|
library_agent.graph_id, version=None, user_id=user_id
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Try to get the graph (version=None gets the active version)
|
|
||||||
graph = await get_graph(graph_id, version=None, user_id=user_id)
|
|
||||||
if not graph:
|
if not graph:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Convert to JSON format
|
|
||||||
nodes = []
|
nodes = []
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
nodes.append(
|
nodes.append(
|
||||||
@@ -253,7 +715,9 @@ async def get_agent_as_json(
|
|||||||
|
|
||||||
|
|
||||||
async def generate_agent_patch(
|
async def generate_agent_patch(
|
||||||
update_request: str, current_agent: dict[str, Any]
|
update_request: str,
|
||||||
|
current_agent: dict[str, Any],
|
||||||
|
library_agents: list[AgentSummary] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Update an existing agent using natural language.
|
"""Update an existing agent using natural language.
|
||||||
|
|
||||||
@@ -265,13 +729,18 @@ async def generate_agent_patch(
|
|||||||
Args:
|
Args:
|
||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated agent JSON, clarifying questions dict, or None on error
|
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||||
|
error dict {"type": "error", ...}, or None on unexpected error
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||||
return await generate_agent_patch_external(update_request, current_agent)
|
# Convert typed dicts to plain dicts for external service
|
||||||
|
return await generate_agent_patch_external(
|
||||||
|
update_request, current_agent, _to_dict_list(library_agents)
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,104 @@
|
|||||||
|
"""Error handling utilities for agent generator."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_error_details(details: str) -> str:
|
||||||
|
"""Sanitize error details to remove sensitive information.
|
||||||
|
|
||||||
|
Strips common patterns that could expose internal system info:
|
||||||
|
- File paths (Unix and Windows)
|
||||||
|
- Database connection strings
|
||||||
|
- URLs with credentials
|
||||||
|
- Stack trace internals
|
||||||
|
|
||||||
|
Args:
|
||||||
|
details: Raw error details string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized error details safe for user display
|
||||||
|
"""
|
||||||
|
# Remove file paths (Unix-style)
|
||||||
|
sanitized = re.sub(
|
||||||
|
r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
|
||||||
|
)
|
||||||
|
# Remove file paths (Windows-style)
|
||||||
|
sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
|
||||||
|
# Remove database URLs
|
||||||
|
sanitized = re.sub(
|
||||||
|
r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
|
||||||
|
)
|
||||||
|
# Remove URLs with credentials
|
||||||
|
sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
|
||||||
|
# Remove line numbers from stack traces
|
||||||
|
sanitized = re.sub(r", line \d+", "", sanitized)
|
||||||
|
# Remove "File" references from stack traces
|
||||||
|
sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
|
||||||
|
|
||||||
|
return sanitized.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_message_for_error(
|
||||||
|
error_type: str,
|
||||||
|
operation: str = "process the request",
|
||||||
|
llm_parse_message: str | None = None,
|
||||||
|
validation_message: str | None = None,
|
||||||
|
error_details: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Get a user-friendly error message based on error type.
|
||||||
|
|
||||||
|
This function maps internal error types to user-friendly messages,
|
||||||
|
providing a consistent experience across different agent operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_type: The error type from the external service
|
||||||
|
(e.g., "llm_parse_error", "timeout", "rate_limit")
|
||||||
|
operation: Description of what operation failed, used in the default
|
||||||
|
message (e.g., "analyze the goal", "generate the agent")
|
||||||
|
llm_parse_message: Custom message for llm_parse_error type
|
||||||
|
validation_message: Custom message for validation_error type
|
||||||
|
error_details: Optional additional details about the error
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User-friendly error message suitable for display to the user
|
||||||
|
"""
|
||||||
|
base_message = ""
|
||||||
|
|
||||||
|
if error_type == "llm_parse_error":
|
||||||
|
base_message = (
|
||||||
|
llm_parse_message
|
||||||
|
or "The AI had trouble processing this request. Please try again."
|
||||||
|
)
|
||||||
|
elif error_type == "validation_error":
|
||||||
|
base_message = (
|
||||||
|
validation_message
|
||||||
|
or "The generated agent failed validation. "
|
||||||
|
"This usually happens when the agent structure doesn't match "
|
||||||
|
"what the platform expects. Please try simplifying your goal "
|
||||||
|
"or breaking it into smaller parts."
|
||||||
|
)
|
||||||
|
elif error_type == "patch_error":
|
||||||
|
base_message = (
|
||||||
|
"Failed to apply the changes. The modification couldn't be "
|
||||||
|
"validated. Please try a different approach or simplify the change."
|
||||||
|
)
|
||||||
|
elif error_type in ("timeout", "llm_timeout"):
|
||||||
|
base_message = (
|
||||||
|
"The request took too long to process. This can happen with "
|
||||||
|
"complex agents. Please try again or simplify your goal."
|
||||||
|
)
|
||||||
|
elif error_type in ("rate_limit", "llm_rate_limit"):
|
||||||
|
base_message = "The service is currently busy. Please try again in a moment."
|
||||||
|
else:
|
||||||
|
base_message = f"Failed to {operation}. Please try again."
|
||||||
|
|
||||||
|
# Add error details if provided (sanitized and truncated)
|
||||||
|
if error_details:
|
||||||
|
# Sanitize to remove sensitive information
|
||||||
|
details = _sanitize_error_details(error_details)
|
||||||
|
# Truncate long error details
|
||||||
|
if len(details) > 200:
|
||||||
|
details = details[:200] + "..."
|
||||||
|
base_message += f"\n\nTechnical details: {details}"
|
||||||
|
|
||||||
|
return base_message
|
||||||
@@ -14,6 +14,70 @@ from backend.util.settings import Settings
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_error_response(
|
||||||
|
error_message: str,
|
||||||
|
error_type: str = "unknown",
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create a standardized error response dict.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_message: Human-readable error message
|
||||||
|
error_type: Machine-readable error type
|
||||||
|
details: Optional additional error details
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Error dict with type="error" and error details
|
||||||
|
"""
|
||||||
|
response: dict[str, Any] = {
|
||||||
|
"type": "error",
|
||||||
|
"error": error_message,
|
||||||
|
"error_type": error_type,
|
||||||
|
}
|
||||||
|
if details:
|
||||||
|
response["details"] = details
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||||
|
"""Classify an HTTP error into error_type and message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
e: The HTTP status error
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (error_type, error_message)
|
||||||
|
"""
|
||||||
|
status = e.response.status_code
|
||||||
|
if status == 429:
|
||||||
|
return "rate_limit", f"Agent Generator rate limited: {e}"
|
||||||
|
elif status == 503:
|
||||||
|
return "service_unavailable", f"Agent Generator unavailable: {e}"
|
||||||
|
elif status == 504 or status == 408:
|
||||||
|
return "timeout", f"Agent Generator timed out: {e}"
|
||||||
|
else:
|
||||||
|
return "http_error", f"HTTP error calling Agent Generator: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||||
|
"""Classify a request error into error_type and message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
e: The request error
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (error_type, error_message)
|
||||||
|
"""
|
||||||
|
error_str = str(e).lower()
|
||||||
|
if "timeout" in error_str or "timed out" in error_str:
|
||||||
|
return "timeout", f"Agent Generator request timed out: {e}"
|
||||||
|
elif "connect" in error_str:
|
||||||
|
return "connection_error", f"Could not connect to Agent Generator: {e}"
|
||||||
|
else:
|
||||||
|
return "request_error", f"Request error calling Agent Generator: {e}"
|
||||||
|
|
||||||
|
|
||||||
_client: httpx.AsyncClient | None = None
|
_client: httpx.AsyncClient | None = None
|
||||||
_settings: Settings | None = None
|
_settings: Settings | None = None
|
||||||
|
|
||||||
@@ -53,13 +117,16 @@ def _get_client() -> httpx.AsyncClient:
|
|||||||
|
|
||||||
|
|
||||||
async def decompose_goal_external(
|
async def decompose_goal_external(
|
||||||
description: str, context: str = ""
|
description: str,
|
||||||
|
context: str = "",
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to decompose a goal.
|
"""Call the external service to decompose a goal.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Natural language goal description
|
description: Natural language goal description
|
||||||
context: Additional context (e.g., answers to previous questions)
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with either:
|
Dict with either:
|
||||||
@@ -67,7 +134,8 @@ async def decompose_goal_external(
|
|||||||
- {"type": "instructions", "steps": [...]}
|
- {"type": "instructions", "steps": [...]}
|
||||||
- {"type": "unachievable_goal", ...}
|
- {"type": "unachievable_goal", ...}
|
||||||
- {"type": "vague_goal", ...}
|
- {"type": "vague_goal", ...}
|
||||||
Or None on error
|
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||||
|
Or None on unexpected error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
@@ -76,6 +144,8 @@ async def decompose_goal_external(
|
|||||||
if context:
|
if context:
|
||||||
# The external service uses user_instruction for additional context
|
# The external service uses user_instruction for additional context
|
||||||
payload["user_instruction"] = context
|
payload["user_instruction"] = context
|
||||||
|
if library_agents:
|
||||||
|
payload["library_agents"] = library_agents
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post("/api/decompose-description", json=payload)
|
response = await client.post("/api/decompose-description", json=payload)
|
||||||
@@ -83,8 +153,13 @@ async def decompose_goal_external(
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
if not data.get("success"):
|
if not data.get("success"):
|
||||||
logger.error(f"External service returned error: {data.get('error')}")
|
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||||
return None
|
error_type = data.get("error_type", "unknown")
|
||||||
|
logger.error(
|
||||||
|
f"Agent Generator decomposition failed: {error_msg} "
|
||||||
|
f"(type: {error_type})"
|
||||||
|
)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
|
||||||
# Map the response to the expected format
|
# Map the response to the expected format
|
||||||
response_type = data.get("type")
|
response_type = data.get("type")
|
||||||
@@ -106,88 +181,120 @@ async def decompose_goal_external(
|
|||||||
"type": "vague_goal",
|
"type": "vague_goal",
|
||||||
"suggested_goal": data.get("suggested_goal"),
|
"suggested_goal": data.get("suggested_goal"),
|
||||||
}
|
}
|
||||||
|
elif response_type == "error":
|
||||||
|
# Pass through error from the service
|
||||||
|
return _create_error_response(
|
||||||
|
data.get("error", "Unknown error"),
|
||||||
|
data.get("error_type", "unknown"),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Unknown response type from external service: {response_type}"
|
f"Unknown response type from external service: {response_type}"
|
||||||
)
|
)
|
||||||
return None
|
return _create_error_response(
|
||||||
|
f"Unknown response type from Agent Generator: {response_type}",
|
||||||
|
"invalid_response",
|
||||||
|
)
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
error_type, error_msg = _classify_http_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"Request error calling external agent generator: {e}")
|
error_type, error_msg = _classify_request_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_external(
|
async def generate_agent_external(
|
||||||
instructions: dict[str, Any]
|
instructions: dict[str, Any],
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to generate an agent from instructions.
|
"""Call the external service to generate an agent from instructions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: Structured instructions from decompose_goal
|
instructions: Structured instructions from decompose_goal
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict or None on error
|
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {"instructions": instructions}
|
||||||
|
if library_agents:
|
||||||
|
payload["library_agents"] = library_agents
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post(
|
response = await client.post("/api/generate-agent", json=payload)
|
||||||
"/api/generate-agent", json={"instructions": instructions}
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
if not data.get("success"):
|
if not data.get("success"):
|
||||||
logger.error(f"External service returned error: {data.get('error')}")
|
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||||
return None
|
error_type = data.get("error_type", "unknown")
|
||||||
|
logger.error(
|
||||||
|
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
||||||
|
)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
|
||||||
return data.get("agent_json")
|
return data.get("agent_json")
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
error_type, error_msg = _classify_http_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"Request error calling external agent generator: {e}")
|
error_type, error_msg = _classify_request_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_patch_external(
|
async def generate_agent_patch_external(
|
||||||
update_request: str, current_agent: dict[str, Any]
|
update_request: str,
|
||||||
|
current_agent: dict[str, Any],
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to generate a patch for an existing agent.
|
"""Call the external service to generate a patch for an existing agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated agent JSON, clarifying questions dict, or None on error
|
Updated agent JSON, clarifying questions dict, or error dict on error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
try:
|
payload: dict[str, Any] = {
|
||||||
response = await client.post(
|
|
||||||
"/api/update-agent",
|
|
||||||
json={
|
|
||||||
"update_request": update_request,
|
"update_request": update_request,
|
||||||
"current_agent_json": current_agent,
|
"current_agent_json": current_agent,
|
||||||
},
|
}
|
||||||
)
|
if library_agents:
|
||||||
|
payload["library_agents"] = library_agents
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post("/api/update-agent", json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
if not data.get("success"):
|
if not data.get("success"):
|
||||||
logger.error(f"External service returned error: {data.get('error')}")
|
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||||
return None
|
error_type = data.get("error_type", "unknown")
|
||||||
|
logger.error(
|
||||||
|
f"Agent Generator patch generation failed: {error_msg} "
|
||||||
|
f"(type: {error_type})"
|
||||||
|
)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
|
||||||
# Check if it's clarifying questions
|
# Check if it's clarifying questions
|
||||||
if data.get("type") == "clarifying_questions":
|
if data.get("type") == "clarifying_questions":
|
||||||
@@ -196,18 +303,28 @@ async def generate_agent_patch_external(
|
|||||||
"questions": data.get("questions", []),
|
"questions": data.get("questions", []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Check if it's an error passed through
|
||||||
|
if data.get("type") == "error":
|
||||||
|
return _create_error_response(
|
||||||
|
data.get("error", "Unknown error"),
|
||||||
|
data.get("error_type", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
# Otherwise return the updated agent JSON
|
# Otherwise return the updated agent JSON
|
||||||
return data.get("agent_json")
|
return data.get("agent_json")
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
error_type, error_msg = _classify_http_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"Request error calling external agent generator: {e}")
|
error_type, error_msg = _classify_request_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
@@ -19,6 +20,86 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
SearchSource = Literal["marketplace", "library"]
|
SearchSource = Literal["marketplace", "library"]
|
||||||
|
|
||||||
|
# UUID v4 pattern for direct agent ID lookup
|
||||||
|
_UUID_PATTERN = re.compile(
|
||||||
|
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(text: str) -> bool:
|
||||||
|
"""Check if text is a valid UUID v4."""
|
||||||
|
return bool(_UUID_PATTERN.match(text.strip()))
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
|
||||||
|
"""Fetch a library agent by ID (library agent ID or graph_id).
|
||||||
|
|
||||||
|
Tries multiple lookup strategies:
|
||||||
|
1. First by graph_id (AgentGraph primary key)
|
||||||
|
2. Then by library agent ID (LibraryAgent primary key)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInfo if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||||
|
return AgentInfo(
|
||||||
|
id=agent.id,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description or "",
|
||||||
|
source="library",
|
||||||
|
in_library=True,
|
||||||
|
creator=agent.creator_name,
|
||||||
|
status=agent.status.value,
|
||||||
|
can_access_graph=agent.can_access_graph,
|
||||||
|
has_external_trigger=agent.has_external_trigger,
|
||||||
|
new_output=agent.new_output,
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
)
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not fetch library agent by graph_id {agent_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||||
|
return AgentInfo(
|
||||||
|
id=agent.id,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description or "",
|
||||||
|
source="library",
|
||||||
|
in_library=True,
|
||||||
|
creator=agent.creator_name,
|
||||||
|
status=agent.status.value,
|
||||||
|
can_access_graph=agent.can_access_graph,
|
||||||
|
has_external_trigger=agent.has_external_trigger,
|
||||||
|
new_output=agent.new_output,
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def search_agents(
|
async def search_agents(
|
||||||
query: str,
|
query: str,
|
||||||
@@ -70,6 +151,16 @@ async def search_agents(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
else: # library
|
else: # library
|
||||||
|
# If query looks like a UUID, try direct lookup first
|
||||||
|
if _is_uuid(query):
|
||||||
|
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||||
|
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
|
||||||
|
if agent:
|
||||||
|
agents.append(agent)
|
||||||
|
logger.info(f"Found agent by direct ID lookup: {agent.name}")
|
||||||
|
|
||||||
|
# If no results from UUID lookup, do text search
|
||||||
|
if not agents:
|
||||||
logger.info(f"Searching user library for: {query}")
|
logger.info(f"Searching user library for: {query}")
|
||||||
results = await library_db.list_library_agents(
|
results = await library_db.list_library_agents(
|
||||||
user_id=user_id, # type: ignore[arg-type]
|
user_id=user_id, # type: ignore[arg-type]
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ from backend.api.features.chat.model import ChatSession
|
|||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
AgentGeneratorNotConfiguredError,
|
AgentGeneratorNotConfiguredError,
|
||||||
decompose_goal,
|
decompose_goal,
|
||||||
|
enrich_library_agents_from_steps,
|
||||||
generate_agent,
|
generate_agent,
|
||||||
|
get_all_relevant_agents_for_generation,
|
||||||
|
get_user_message_for_error,
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
)
|
)
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
@@ -102,9 +105,27 @@ class CreateAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Fetch relevant library and marketplace agents for sub-agent composition
|
||||||
|
library_agents = None
|
||||||
|
if user_id:
|
||||||
|
try:
|
||||||
|
library_agents = await get_all_relevant_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=description, # Use goal as search term
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Log but don't fail - agent generation can work without sub-agents
|
||||||
|
logger.warning(f"Failed to fetch library agents: {e}")
|
||||||
|
|
||||||
# Step 1: Decompose goal into steps
|
# Step 1: Decompose goal into steps
|
||||||
try:
|
try:
|
||||||
decomposition_result = await decompose_goal(description, context)
|
decomposition_result = await decompose_goal(
|
||||||
|
description, context, library_agents
|
||||||
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -117,11 +138,29 @@ class CreateAgentTool(BaseTool):
|
|||||||
|
|
||||||
if decomposition_result is None:
|
if decomposition_result is None:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to analyze the goal. The agent generation service may be unavailable or timed out. Please try again.",
|
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
||||||
error="decomposition_failed",
|
error="decomposition_failed",
|
||||||
|
details={"description": description[:100]},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the result is an error from the external service
|
||||||
|
if decomposition_result.get("type") == "error":
|
||||||
|
error_msg = decomposition_result.get("error", "Unknown error")
|
||||||
|
error_type = decomposition_result.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="analyze the goal",
|
||||||
|
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"decomposition_failed:{error_type}",
|
||||||
details={
|
details={
|
||||||
"description": description[:100]
|
"description": description[:100],
|
||||||
}, # Include context for debugging
|
"service_error": error_msg,
|
||||||
|
"error_type": error_type,
|
||||||
|
},
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -171,9 +210,26 @@ class CreateAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Step 1.5: Enrich library agents with step-based search (two-phase search)
|
||||||
|
# After decomposition, search for additional relevant agents based on the steps
|
||||||
|
if user_id and library_agents is not None:
|
||||||
|
try:
|
||||||
|
library_agents = await enrich_library_agents_from_steps(
|
||||||
|
user_id=user_id,
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=library_agents,
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Log but don't fail - continue with existing agents
|
||||||
|
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||||
|
|
||||||
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
||||||
try:
|
try:
|
||||||
agent_json = await generate_agent(decomposition_result)
|
agent_json = await generate_agent(decomposition_result, library_agents)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -186,11 +242,35 @@ class CreateAgentTool(BaseTool):
|
|||||||
|
|
||||||
if agent_json is None:
|
if agent_json is None:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to generate the agent. The agent generation service may be unavailable or timed out. Please try again.",
|
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
||||||
error="generation_failed",
|
error="generation_failed",
|
||||||
|
details={"description": description[:100]},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the result is an error from the external service
|
||||||
|
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
|
||||||
|
error_msg = agent_json.get("error", "Unknown error")
|
||||||
|
error_type = agent_json.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="generate the agent",
|
||||||
|
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
|
||||||
|
validation_message=(
|
||||||
|
"I wasn't able to create a valid agent for this request. "
|
||||||
|
"The generated workflow had some structural issues. "
|
||||||
|
"Please try simplifying your goal or breaking it into smaller steps."
|
||||||
|
),
|
||||||
|
error_details=error_msg if error_type == "validation_error" else None,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"generation_failed:{error_type}",
|
||||||
details={
|
details={
|
||||||
"description": description[:100]
|
"description": description[:100],
|
||||||
}, # Include context for debugging
|
"service_error": error_msg,
|
||||||
|
"error_type": error_type,
|
||||||
|
},
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -232,7 +312,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
agent_id=created_graph.id,
|
agent_id=created_graph.id,
|
||||||
agent_name=created_graph.name,
|
agent_name=created_graph.name,
|
||||||
library_agent_id=library_agent.id,
|
library_agent_id=library_agent.id,
|
||||||
library_agent_link=f"/library/{library_agent.id}",
|
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from .agent_generator import (
|
|||||||
AgentGeneratorNotConfiguredError,
|
AgentGeneratorNotConfiguredError,
|
||||||
generate_agent_patch,
|
generate_agent_patch,
|
||||||
get_agent_as_json,
|
get_agent_as_json,
|
||||||
|
get_all_relevant_agents_for_generation,
|
||||||
|
get_user_message_for_error,
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
)
|
)
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
@@ -126,6 +128,22 @@ class EditAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
library_agents = None
|
||||||
|
if user_id:
|
||||||
|
try:
|
||||||
|
exclude_id = current_agent.get("id") or agent_id
|
||||||
|
library_agents = await get_all_relevant_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=changes,
|
||||||
|
exclude_graph_id=exclude_id,
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch library agents: {e}")
|
||||||
|
|
||||||
# Build the update request with context
|
# Build the update request with context
|
||||||
update_request = changes
|
update_request = changes
|
||||||
if context:
|
if context:
|
||||||
@@ -133,7 +151,9 @@ class EditAgentTool(BaseTool):
|
|||||||
|
|
||||||
# Step 2: Generate updated agent (external service handles fixing and validation)
|
# Step 2: Generate updated agent (external service handles fixing and validation)
|
||||||
try:
|
try:
|
||||||
result = await generate_agent_patch(update_request, current_agent)
|
result = await generate_agent_patch(
|
||||||
|
update_request, current_agent, library_agents
|
||||||
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -152,6 +172,28 @@ class EditAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if the result is an error from the external service
|
||||||
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
|
error_msg = result.get("error", "Unknown error")
|
||||||
|
error_type = result.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="generate the changes",
|
||||||
|
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
||||||
|
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"update_generation_failed:{error_type}",
|
||||||
|
details={
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"changes": changes[:100],
|
||||||
|
"service_error": error_msg,
|
||||||
|
"error_type": error_type,
|
||||||
|
},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Check if LLM returned clarifying questions
|
# Check if LLM returned clarifying questions
|
||||||
if result.get("type") == "clarifying_questions":
|
if result.get("type") == "clarifying_questions":
|
||||||
questions = result.get("questions", [])
|
questions = result.get("questions", [])
|
||||||
@@ -213,7 +255,7 @@ class EditAgentTool(BaseTool):
|
|||||||
agent_id=created_graph.id,
|
agent_id=created_graph.id,
|
||||||
agent_name=created_graph.name,
|
agent_name=created_graph.name,
|
||||||
library_agent_id=library_agent.id,
|
library_agent_id=library_agent.id,
|
||||||
library_agent_link=f"/library/{library_agent.id}",
|
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,6 +28,12 @@ class ResponseType(str, Enum):
|
|||||||
BLOCK_OUTPUT = "block_output"
|
BLOCK_OUTPUT = "block_output"
|
||||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||||
DOC_PAGE = "doc_page"
|
DOC_PAGE = "doc_page"
|
||||||
|
# Workspace response types
|
||||||
|
WORKSPACE_FILE_LIST = "workspace_file_list"
|
||||||
|
WORKSPACE_FILE_CONTENT = "workspace_file_content"
|
||||||
|
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
|
||||||
|
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
||||||
|
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
||||||
# Long-running operation types
|
# Long-running operation types
|
||||||
OPERATION_STARTED = "operation_started"
|
OPERATION_STARTED = "operation_started"
|
||||||
OPERATION_PENDING = "operation_pending"
|
OPERATION_PENDING = "operation_pending"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Tool for executing blocks directly."""
|
"""Tool for executing blocks directly."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -8,6 +9,7 @@ from backend.api.features.chat.model import ChatSession
|
|||||||
from backend.data.block import get_block
|
from backend.data.block import get_block
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
@@ -223,11 +225,48 @@ class RunBlockTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch actual credentials and prepare kwargs for block execution
|
# Get or create user's workspace for CoPilot file operations
|
||||||
# Create execution context with defaults (blocks may require it)
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
|
||||||
|
# Generate synthetic IDs for CoPilot context
|
||||||
|
# Each chat session is treated as its own agent with one continuous run
|
||||||
|
# This means:
|
||||||
|
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
||||||
|
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
||||||
|
# - node_exec_id = unique per block execution
|
||||||
|
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||||
|
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||||
|
synthetic_node_id = f"copilot-node-{block_id}"
|
||||||
|
synthetic_node_exec_id = (
|
||||||
|
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create unified execution context with all required fields
|
||||||
|
execution_context = ExecutionContext(
|
||||||
|
# Execution identity
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=synthetic_graph_id,
|
||||||
|
graph_exec_id=synthetic_graph_exec_id,
|
||||||
|
graph_version=1, # Versions are 1-indexed
|
||||||
|
node_id=synthetic_node_id,
|
||||||
|
node_exec_id=synthetic_node_exec_id,
|
||||||
|
# Workspace with session scoping
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
session_id=session.session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare kwargs for block execution
|
||||||
|
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||||
exec_kwargs: dict[str, Any] = {
|
exec_kwargs: dict[str, Any] = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"execution_context": ExecutionContext(),
|
"execution_context": execution_context,
|
||||||
|
# Legacy: individual kwargs for blocks not yet using execution_context
|
||||||
|
"workspace_id": workspace.id,
|
||||||
|
"graph_exec_id": synthetic_graph_exec_id,
|
||||||
|
"node_exec_id": synthetic_node_exec_id,
|
||||||
|
"node_id": synthetic_node_id,
|
||||||
|
"graph_version": 1, # Versions are 1-indexed
|
||||||
|
"graph_id": synthetic_graph_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
for field_name, cred_meta in matched_credentials.items():
|
for field_name, cred_meta in matched_credentials.items():
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
|
|||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
@@ -266,13 +266,14 @@ async def match_user_credentials_to_graph(
|
|||||||
credential_requirements,
|
credential_requirements,
|
||||||
_node_fields,
|
_node_fields,
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider and type
|
# Find first matching credential by provider, type, and scopes
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
(
|
(
|
||||||
cred
|
cred
|
||||||
for cred in available_creds
|
for cred in available_creds
|
||||||
if cred.provider in credential_requirements.provider
|
if cred.provider in credential_requirements.provider
|
||||||
and cred.type in credential_requirements.supported_types
|
and cred.type in credential_requirements.supported_types
|
||||||
|
and _credential_has_required_scopes(cred, credential_requirements)
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -296,10 +297,17 @@ async def match_user_credentials_to_graph(
|
|||||||
f"{credential_field_name} (validation failed: {e})"
|
f"{credential_field_name} (validation failed: {e})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Build a helpful error message including scope requirements
|
||||||
|
error_parts = [
|
||||||
|
f"provider in {list(credential_requirements.provider)}",
|
||||||
|
f"type in {list(credential_requirements.supported_types)}",
|
||||||
|
]
|
||||||
|
if credential_requirements.required_scopes:
|
||||||
|
error_parts.append(
|
||||||
|
f"scopes including {list(credential_requirements.required_scopes)}"
|
||||||
|
)
|
||||||
missing_creds.append(
|
missing_creds.append(
|
||||||
f"{credential_field_name} "
|
f"{credential_field_name} (requires {', '.join(error_parts)})"
|
||||||
f"(requires provider in {list(credential_requirements.provider)}, "
|
|
||||||
f"type in {list(credential_requirements.supported_types)})"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -309,6 +317,28 @@ async def match_user_credentials_to_graph(
|
|||||||
return graph_credentials_inputs, missing_creds
|
return graph_credentials_inputs, missing_creds
|
||||||
|
|
||||||
|
|
||||||
|
def _credential_has_required_scopes(
|
||||||
|
credential: Credentials,
|
||||||
|
requirements: CredentialsFieldInfo,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a credential has all the scopes required by the block.
|
||||||
|
|
||||||
|
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||||
|
of the required scopes. For other credential types, returns True (no scope check).
|
||||||
|
"""
|
||||||
|
# Only OAuth2 credentials have scopes to check
|
||||||
|
if credential.type != "oauth2":
|
||||||
|
return True
|
||||||
|
|
||||||
|
# If no scopes are required, any credential matches
|
||||||
|
if not requirements.required_scopes:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check that credential scopes are a superset of required scopes
|
||||||
|
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||||
|
|
||||||
|
|
||||||
async def check_user_has_required_credentials(
|
async def check_user_has_required_credentials(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
required_credentials: list[CredentialsMetaInput],
|
required_credentials: list[CredentialsMetaInput],
|
||||||
|
|||||||
@@ -0,0 +1,620 @@
|
|||||||
|
"""CoPilot tools for workspace file operations."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.data.workspace import get_or_create_workspace
|
||||||
|
from backend.util.settings import Config
|
||||||
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
from backend.util.workspace import WorkspaceManager
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileInfoData(BaseModel):
|
||||||
|
"""Data model for workspace file information (not a response itself)."""
|
||||||
|
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileListResponse(ToolResponseBase):
|
||||||
|
"""Response containing list of workspace files."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
||||||
|
files: list[WorkspaceFileInfoData]
|
||||||
|
total_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileContentResponse(ToolResponseBase):
|
||||||
|
"""Response containing workspace file content (legacy, for small text files)."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
content_base64: str
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||||
|
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
size_bytes: int
|
||||||
|
download_url: str
|
||||||
|
preview: str | None = None # First 500 chars for text files
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceWriteResponse(ToolResponseBase):
|
||||||
|
"""Response after writing a file to workspace."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||||
|
"""Response after deleting a file from workspace."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
||||||
|
file_id: str
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
class ListWorkspaceFilesTool(BaseTool):
|
||||||
|
"""Tool for listing files in user's workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "list_workspace_files"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"List files in the user's workspace. "
|
||||||
|
"Returns file names, paths, sizes, and metadata. "
|
||||||
|
"Optionally filter by path prefix."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path_prefix": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional path prefix to filter files "
|
||||||
|
"(e.g., '/documents/' to list only files in documents folder). "
|
||||||
|
"By default, only files from the current session are listed."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of files to return (default 50, max 100)",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 100,
|
||||||
|
},
|
||||||
|
"include_all_sessions": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"If true, list files from all sessions. "
|
||||||
|
"Default is false (only current session's files)."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||||
|
limit = min(kwargs.get("limit", 50), 100)
|
||||||
|
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
files = await manager.list_files(
|
||||||
|
path=path_prefix,
|
||||||
|
limit=limit,
|
||||||
|
include_all_sessions=include_all_sessions,
|
||||||
|
)
|
||||||
|
total = await manager.get_file_count(
|
||||||
|
path=path_prefix,
|
||||||
|
include_all_sessions=include_all_sessions,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_infos = [
|
||||||
|
WorkspaceFileInfoData(
|
||||||
|
file_id=f.id,
|
||||||
|
name=f.name,
|
||||||
|
path=f.path,
|
||||||
|
mime_type=f.mimeType,
|
||||||
|
size_bytes=f.sizeBytes,
|
||||||
|
)
|
||||||
|
for f in files
|
||||||
|
]
|
||||||
|
|
||||||
|
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||||
|
return WorkspaceFileListResponse(
|
||||||
|
files=file_infos,
|
||||||
|
total_count=total,
|
||||||
|
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to list workspace files: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadWorkspaceFileTool(BaseTool):
|
||||||
|
"""Tool for reading file content from workspace."""
|
||||||
|
|
||||||
|
# Size threshold for returning full content vs metadata+URL
|
||||||
|
# Files larger than this return metadata with download URL to prevent context bloat
|
||||||
|
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||||
|
# Preview size for text files
|
||||||
|
PREVIEW_SIZE = 500
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "read_workspace_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Read a file from the user's workspace. "
|
||||||
|
"Specify either file_id or path to identify the file. "
|
||||||
|
"For small text files, returns content directly. "
|
||||||
|
"For large or binary files, returns metadata and a download URL. "
|
||||||
|
"Paths are scoped to the current session by default. "
|
||||||
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file's unique ID (from list_workspace_files)",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||||
|
"Scoped to current session by default."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"force_download_url": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"If true, always return metadata+URL instead of inline content. "
|
||||||
|
"Default is false (auto-selects based on file size/type)."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [], # At least one must be provided
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _is_text_mime_type(self, mime_type: str) -> bool:
|
||||||
|
"""Check if the MIME type is a text-based type."""
|
||||||
|
text_types = [
|
||||||
|
"text/",
|
||||||
|
"application/json",
|
||||||
|
"application/xml",
|
||||||
|
"application/javascript",
|
||||||
|
"application/x-python",
|
||||||
|
"application/x-sh",
|
||||||
|
]
|
||||||
|
return any(mime_type.startswith(t) for t in text_types)
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id: Optional[str] = kwargs.get("file_id")
|
||||||
|
path: Optional[str] = kwargs.get("path")
|
||||||
|
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||||
|
|
||||||
|
if not file_id and not path:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide either file_id or path",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
# Get file info
|
||||||
|
if file_id:
|
||||||
|
file_info = await manager.get_file_info(file_id)
|
||||||
|
if file_info is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found: {file_id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
target_file_id = file_id
|
||||||
|
else:
|
||||||
|
# path is guaranteed to be non-None here due to the check above
|
||||||
|
assert path is not None
|
||||||
|
file_info = await manager.get_file_info_by_path(path)
|
||||||
|
if file_info is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found at path: {path}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
target_file_id = file_info.id
|
||||||
|
|
||||||
|
# Decide whether to return inline content or metadata+URL
|
||||||
|
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
|
||||||
|
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
||||||
|
|
||||||
|
# Return inline content for small text files (unless force_download_url)
|
||||||
|
if is_small_file and is_text_file and not force_download_url:
|
||||||
|
content = await manager.read_file_by_id(target_file_id)
|
||||||
|
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||||
|
|
||||||
|
return WorkspaceFileContentResponse(
|
||||||
|
file_id=file_info.id,
|
||||||
|
name=file_info.name,
|
||||||
|
path=file_info.path,
|
||||||
|
mime_type=file_info.mimeType,
|
||||||
|
content_base64=content_b64,
|
||||||
|
message=f"Successfully read file: {file_info.name}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return metadata + workspace:// reference for large or binary files
|
||||||
|
# This prevents context bloat (100KB file = ~133KB as base64)
|
||||||
|
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
||||||
|
download_url = f"workspace://{target_file_id}"
|
||||||
|
|
||||||
|
# Generate preview for text files
|
||||||
|
preview: str | None = None
|
||||||
|
if is_text_file:
|
||||||
|
try:
|
||||||
|
content = await manager.read_file_by_id(target_file_id)
|
||||||
|
preview_text = content[: self.PREVIEW_SIZE].decode(
|
||||||
|
"utf-8", errors="replace"
|
||||||
|
)
|
||||||
|
if len(content) > self.PREVIEW_SIZE:
|
||||||
|
preview_text += "..."
|
||||||
|
preview = preview_text
|
||||||
|
except Exception:
|
||||||
|
pass # Preview is optional
|
||||||
|
|
||||||
|
return WorkspaceFileMetadataResponse(
|
||||||
|
file_id=file_info.id,
|
||||||
|
name=file_info.name,
|
||||||
|
path=file_info.path,
|
||||||
|
mime_type=file_info.mimeType,
|
||||||
|
size_bytes=file_info.sizeBytes,
|
||||||
|
download_url=download_url,
|
||||||
|
preview=preview,
|
||||||
|
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to read workspace file: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteWorkspaceFileTool(BaseTool):
|
||||||
|
"""Tool for writing files to workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "write_workspace_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Write or create a file in the user's workspace. "
|
||||||
|
"Provide the content as a base64-encoded string. "
|
||||||
|
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||||
|
"Files are saved to the current session's folder by default. "
|
||||||
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"filename": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Name for the file (e.g., 'report.pdf')",
|
||||||
|
},
|
||||||
|
"content_base64": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Base64-encoded file content",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional virtual path where to save the file "
|
||||||
|
"(e.g., '/documents/report.pdf'). "
|
||||||
|
"Defaults to '/{filename}'. Scoped to current session."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"mime_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional MIME type of the file. "
|
||||||
|
"Auto-detected from filename if not provided."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"overwrite": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["filename", "content_base64"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
filename: str = kwargs.get("filename", "")
|
||||||
|
content_b64: str = kwargs.get("content_base64", "")
|
||||||
|
path: Optional[str] = kwargs.get("path")
|
||||||
|
mime_type: Optional[str] = kwargs.get("mime_type")
|
||||||
|
overwrite: bool = kwargs.get("overwrite", False)
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide a filename",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not content_b64:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide content_base64",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode content
|
||||||
|
try:
|
||||||
|
content = base64.b64decode(content_b64)
|
||||||
|
except Exception:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Invalid base64-encoded content",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check size
|
||||||
|
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||||
|
if len(content) > max_file_size:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Virus scan
|
||||||
|
await scan_content_safe(content, filename=filename)
|
||||||
|
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
file_record = await manager.write_file(
|
||||||
|
content=content,
|
||||||
|
filename=filename,
|
||||||
|
path=path,
|
||||||
|
mime_type=mime_type,
|
||||||
|
overwrite=overwrite,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WorkspaceWriteResponse(
|
||||||
|
file_id=file_record.id,
|
||||||
|
name=file_record.name,
|
||||||
|
path=file_record.path,
|
||||||
|
size_bytes=file_record.sizeBytes,
|
||||||
|
message=f"Successfully wrote file: {file_record.name}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to write workspace file: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteWorkspaceFileTool(BaseTool):
|
||||||
|
"""Tool for deleting files from workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "delete_workspace_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Delete a file from the user's workspace. "
|
||||||
|
"Specify either file_id or path to identify the file. "
|
||||||
|
"Paths are scoped to the current session by default. "
|
||||||
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file's unique ID (from list_workspace_files)",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||||
|
"Scoped to current session by default."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [], # At least one must be provided
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id: Optional[str] = kwargs.get("file_id")
|
||||||
|
path: Optional[str] = kwargs.get("path")
|
||||||
|
|
||||||
|
if not file_id and not path:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide either file_id or path",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
# Determine the file_id to delete
|
||||||
|
target_file_id: str
|
||||||
|
if file_id:
|
||||||
|
target_file_id = file_id
|
||||||
|
else:
|
||||||
|
# path is guaranteed to be non-None here due to the check above
|
||||||
|
assert path is not None
|
||||||
|
file_info = await manager.get_file_info_by_path(path)
|
||||||
|
if file_info is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found at path: {path}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
target_file_id = file_info.id
|
||||||
|
|
||||||
|
success = await manager.delete_file(target_file_id)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found: {target_file_id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WorkspaceDeleteResponse(
|
||||||
|
file_id=target_file_id,
|
||||||
|
success=True,
|
||||||
|
message="File deleted successfully",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to delete workspace file: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
@@ -77,21 +77,32 @@ async def list_library_agents(
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Build search filter if applicable
|
# Build search filter if applicable
|
||||||
|
# Split into words and match ANY word in name or description
|
||||||
if search_term:
|
if search_term:
|
||||||
where_clause["OR"] = [
|
words = [w.strip() for w in search_term.split() if len(w.strip()) >= 3]
|
||||||
|
if words:
|
||||||
|
or_conditions: list[prisma.types.LibraryAgentWhereInput] = []
|
||||||
|
for word in words:
|
||||||
|
or_conditions.append(
|
||||||
{
|
{
|
||||||
"AgentGraph": {
|
"AgentGraph": {
|
||||||
"is": {"name": {"contains": search_term, "mode": "insensitive"}}
|
"is": {"name": {"contains": word, "mode": "insensitive"}}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
|
)
|
||||||
|
or_conditions.append(
|
||||||
{
|
{
|
||||||
"AgentGraph": {
|
"AgentGraph": {
|
||||||
"is": {
|
"is": {
|
||||||
"description": {"contains": search_term, "mode": "insensitive"}
|
"description": {
|
||||||
|
"contains": word,
|
||||||
|
"mode": "insensitive",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
]
|
}
|
||||||
|
)
|
||||||
|
where_clause["OR"] = or_conditions
|
||||||
|
|
||||||
# Determine sorting
|
# Determine sorting
|
||||||
order_by: prisma.types.LibraryAgentOrderByInput | None = None
|
order_by: prisma.types.LibraryAgentOrderByInput | None = None
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
# Workspace API feature module
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
"""
|
||||||
|
Workspace API routes for managing user file storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Annotated
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from backend.data.workspace import get_workspace, get_workspace_file
|
||||||
|
from backend.util.workspace_storage import get_workspace_storage
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_filename_for_header(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize filename for Content-Disposition header to prevent header injection.
|
||||||
|
|
||||||
|
Removes/replaces characters that could break the header or inject new headers.
|
||||||
|
Uses RFC5987 encoding for non-ASCII characters.
|
||||||
|
"""
|
||||||
|
# Remove CR, LF, and null bytes (header injection prevention)
|
||||||
|
sanitized = re.sub(r"[\r\n\x00]", "", filename)
|
||||||
|
# Escape quotes
|
||||||
|
sanitized = sanitized.replace('"', '\\"')
|
||||||
|
# For non-ASCII, use RFC5987 filename* parameter
|
||||||
|
# Check if filename has non-ASCII characters
|
||||||
|
try:
|
||||||
|
sanitized.encode("ascii")
|
||||||
|
return f'attachment; filename="{sanitized}"'
|
||||||
|
except UnicodeEncodeError:
|
||||||
|
# Use RFC5987 encoding for UTF-8 filenames
|
||||||
|
encoded = quote(sanitized, safe="")
|
||||||
|
return f"attachment; filename*=UTF-8''{encoded}"
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = fastapi.APIRouter(
|
||||||
|
dependencies=[fastapi.Security(requires_user)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_streaming_response(content: bytes, file) -> Response:
|
||||||
|
"""Create a streaming response for file content."""
|
||||||
|
return Response(
|
||||||
|
content=content,
|
||||||
|
media_type=file.mimeType,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||||
|
"Content-Length": str(len(content)),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_file_download_response(file) -> Response:
|
||||||
|
"""
|
||||||
|
Create a download response for a workspace file.
|
||||||
|
|
||||||
|
Handles both local storage (direct streaming) and GCS (signed URL redirect
|
||||||
|
with fallback to streaming).
|
||||||
|
"""
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
|
||||||
|
# For local storage, stream the file directly
|
||||||
|
if file.storagePath.startswith("local://"):
|
||||||
|
content = await storage.retrieve(file.storagePath)
|
||||||
|
return _create_streaming_response(content, file)
|
||||||
|
|
||||||
|
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||||
|
try:
|
||||||
|
url = await storage.get_download_url(file.storagePath, expires_in=300)
|
||||||
|
# If we got back an API path (fallback), stream directly instead
|
||||||
|
if url.startswith("/api/"):
|
||||||
|
content = await storage.retrieve(file.storagePath)
|
||||||
|
return _create_streaming_response(content, file)
|
||||||
|
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||||
|
except Exception as e:
|
||||||
|
# Log the signed URL failure with context
|
||||||
|
logger.error(
|
||||||
|
f"Failed to get signed URL for file {file.id} "
|
||||||
|
f"(storagePath={file.storagePath}): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# Fall back to streaming directly from GCS
|
||||||
|
try:
|
||||||
|
content = await storage.retrieve(file.storagePath)
|
||||||
|
return _create_streaming_response(content, file)
|
||||||
|
except Exception as fallback_error:
|
||||||
|
logger.error(
|
||||||
|
f"Fallback streaming also failed for file {file.id} "
|
||||||
|
f"(storagePath={file.storagePath}): {fallback_error}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/files/{file_id}/download",
|
||||||
|
summary="Download file by ID",
|
||||||
|
)
|
||||||
|
async def download_file(
|
||||||
|
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||||
|
file_id: str,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
Download a file by its ID.
|
||||||
|
|
||||||
|
Returns the file content directly or redirects to a signed URL for GCS.
|
||||||
|
"""
|
||||||
|
workspace = await get_workspace(user_id)
|
||||||
|
if workspace is None:
|
||||||
|
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||||
|
|
||||||
|
file = await get_workspace_file(file_id, workspace.id)
|
||||||
|
if file is None:
|
||||||
|
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
return await _create_file_download_response(file)
|
||||||
@@ -32,6 +32,7 @@ import backend.api.features.postmark.postmark
|
|||||||
import backend.api.features.store.model
|
import backend.api.features.store.model
|
||||||
import backend.api.features.store.routes
|
import backend.api.features.store.routes
|
||||||
import backend.api.features.v1
|
import backend.api.features.v1
|
||||||
|
import backend.api.features.workspace.routes as workspace_routes
|
||||||
import backend.data.block
|
import backend.data.block
|
||||||
import backend.data.db
|
import backend.data.db
|
||||||
import backend.data.graph
|
import backend.data.graph
|
||||||
@@ -52,6 +53,7 @@ from backend.util.exceptions import (
|
|||||||
)
|
)
|
||||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||||
from backend.util.service import UnhealthyServiceError
|
from backend.util.service import UnhealthyServiceError
|
||||||
|
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||||
|
|
||||||
from .external.fastapi_app import external_api
|
from .external.fastapi_app import external_api
|
||||||
from .features.analytics import router as analytics_router
|
from .features.analytics import router as analytics_router
|
||||||
@@ -124,6 +126,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error shutting down cloud storage handler: {e}")
|
logger.warning(f"Error shutting down cloud storage handler: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await shutdown_workspace_storage()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error shutting down workspace storage: {e}")
|
||||||
|
|
||||||
await backend.data.db.disconnect()
|
await backend.data.db.disconnect()
|
||||||
|
|
||||||
|
|
||||||
@@ -315,6 +322,11 @@ app.include_router(
|
|||||||
tags=["v2", "chat"],
|
tags=["v2", "chat"],
|
||||||
prefix="/api/chat",
|
prefix="/api/chat",
|
||||||
)
|
)
|
||||||
|
app.include_router(
|
||||||
|
workspace_routes.router,
|
||||||
|
tags=["workspace"],
|
||||||
|
prefix="/api/workspace",
|
||||||
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.oauth.router,
|
backend.api.features.oauth.router,
|
||||||
tags=["oauth"],
|
tags=["oauth"],
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block):
|
|||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
test_output=[
|
test_output=[
|
||||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
|
# Use data URI to avoid HTTP requests during tests
|
||||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||||
"https://replicate.delivery/generated-image.jpg"
|
"data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q=="
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
@@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
@@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block):
|
|||||||
processed_images = await asyncio.gather(
|
processed_images = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
store_media_file(
|
store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=img,
|
file=img,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True,
|
return_format="for_external_api", # Get content for Replicate API
|
||||||
)
|
)
|
||||||
for img in input_data.images
|
for img in input_data.images
|
||||||
)
|
)
|
||||||
@@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block):
|
|||||||
aspect_ratio=input_data.aspect_ratio.value,
|
aspect_ratio=input_data.aspect_ratio.value,
|
||||||
output_format=input_data.output_format.value,
|
output_format=input_data.output_format.value,
|
||||||
)
|
)
|
||||||
yield "image_url", result
|
|
||||||
|
# Store the generated image to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=result,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", stored_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "error", str(e)
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient
|
|||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -13,6 +14,8 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.file import store_media_file
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
|
|
||||||
class ImageSize(str, Enum):
|
class ImageSize(str, Enum):
|
||||||
@@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block):
|
|||||||
test_output=[
|
test_output=[
|
||||||
(
|
(
|
||||||
"image_url",
|
"image_url",
|
||||||
"https://replicate.delivery/generated-image.webp",
|
# Test output is a data URI since we now store images
|
||||||
|
lambda x: x.startswith("data:image/"),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"_run_client": lambda *args, **kwargs: "https://replicate.delivery/generated-image.webp"
|
# Return a data URI directly so store_media_file doesn't need to download
|
||||||
|
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block):
|
|||||||
style_text = style_map.get(style, "")
|
style_text = style_map.get(style, "")
|
||||||
return f"{style_text} of" if style_text else ""
|
return f"{style_text} of" if style_text else ""
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
url = await self.generate_image(input_data, credentials)
|
url = await self.generate_image(input_data, credentials)
|
||||||
if url:
|
if url:
|
||||||
yield "image_url", url
|
# Store the generated image to the user's workspace/execution folder
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", stored_url
|
||||||
else:
|
else:
|
||||||
yield "error", "Image generation returned an empty result."
|
yield "error", "Image generation returned an empty result."
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -21,7 +22,9 @@ from backend.data.model import (
|
|||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.exceptions import BlockExecutionError
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
@@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
"voice": Voice.LILY,
|
"voice": Voice.LILY,
|
||||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/video.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/video.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create a new Webhook.site URL
|
# Create a new Webhook.site URL
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
@@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
)
|
)
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
logger.debug(f"Video ready: {video_url}")
|
logger.debug(f"Video ready: {video_url}")
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|
||||||
|
|
||||||
class AIAdMakerVideoCreatorBlock(Block):
|
class AIAdMakerVideoCreatorBlock(Block):
|
||||||
@@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/ad.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
raise RuntimeError("Failed to create video: No project ID returned")
|
raise RuntimeError("Failed to create video: No project ID returned")
|
||||||
|
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|
||||||
|
|
||||||
class AIScreenshotToVideoAdBlock(Block):
|
class AIScreenshotToVideoAdBlock(Block):
|
||||||
@@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
"script": "Amazing numbers!",
|
"script": "Amazing numbers!",
|
||||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/screenshot.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
raise RuntimeError("Failed to create video: No project ID returned")
|
raise RuntimeError("Failed to create video: No project ID returned")
|
||||||
|
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.sdk import (
|
from backend.sdk import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
Block,
|
Block,
|
||||||
@@ -17,6 +18,8 @@ from backend.sdk import (
|
|||||||
Requests,
|
Requests,
|
||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
|
from backend.util.file import store_media_file
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
from ._config import bannerbear
|
from ._config import bannerbear
|
||||||
|
|
||||||
@@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block):
|
|||||||
},
|
},
|
||||||
test_output=[
|
test_output=[
|
||||||
("success", True),
|
("success", True),
|
||||||
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
("uid", "test-uid-123"),
|
("uid", "test-uid-123"),
|
||||||
("status", "completed"),
|
("status", "completed"),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
|
# Use data URI to avoid HTTP requests during tests
|
||||||
"_make_api_request": lambda *args, **kwargs: {
|
"_make_api_request": lambda *args, **kwargs: {
|
||||||
"uid": "test-uid-123",
|
"uid": "test-uid-123",
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
|
"image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAABAAEBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+v//Z",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
@@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block):
|
|||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Build the modifications array
|
# Build the modifications array
|
||||||
modifications = []
|
modifications = []
|
||||||
@@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block):
|
|||||||
|
|
||||||
# Synchronous request - image should be ready
|
# Synchronous request - image should be ready
|
||||||
yield "success", True
|
yield "success", True
|
||||||
yield "image_url", data.get("image_url", "")
|
|
||||||
|
# Store the generated image to workspace for persistence
|
||||||
|
image_url = data.get("image_url", "")
|
||||||
|
if image_url:
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(image_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", stored_url
|
||||||
|
else:
|
||||||
|
yield "image_url", ""
|
||||||
|
|
||||||
yield "uid", data.get("uid", "")
|
yield "uid", data.get("uid", "")
|
||||||
yield "status", data.get("status", "completed")
|
yield "status", data.get("status", "completed")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.type import MediaFileType, convert
|
from backend.util.type import MediaFileType, convert
|
||||||
@@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert
|
|||||||
class FileStoreBlock(Block):
|
class FileStoreBlock(Block):
|
||||||
class Input(BlockSchemaInput):
|
class Input(BlockSchemaInput):
|
||||||
file_in: MediaFileType = SchemaField(
|
file_in: MediaFileType = SchemaField(
|
||||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
description="The file to download and store. Can be a URL (https://...), data URI, or local path."
|
||||||
)
|
)
|
||||||
base_64: bool = SchemaField(
|
base_64: bool = SchemaField(
|
||||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
|
||||||
default=False,
|
default=False,
|
||||||
advanced=True,
|
advanced=True,
|
||||||
title="Produce Base64 Output",
|
title="Produce Base64 Output",
|
||||||
@@ -28,13 +29,18 @@ class FileStoreBlock(Block):
|
|||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
file_out: MediaFileType = SchemaField(
|
file_out: MediaFileType = SchemaField(
|
||||||
description="The relative path to the stored file in the temporary directory."
|
description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
|
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
|
||||||
description="Stores the input file in the temporary directory.",
|
description=(
|
||||||
|
"Downloads and stores a file from a URL, data URI, or local path. "
|
||||||
|
"Use this to fetch images, documents, or other files for processing. "
|
||||||
|
"In CoPilot: saves to workspace (use list_workspace_files to see it). "
|
||||||
|
"In graphs: outputs a data URI to pass to other blocks."
|
||||||
|
),
|
||||||
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
|
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
|
||||||
input_schema=FileStoreBlock.Input,
|
input_schema=FileStoreBlock.Input,
|
||||||
output_schema=FileStoreBlock.Output,
|
output_schema=FileStoreBlock.Output,
|
||||||
@@ -45,15 +51,18 @@ class FileStoreBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
# Determine return format based on user preference
|
||||||
|
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
|
||||||
|
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
|
||||||
|
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
|
||||||
|
|
||||||
yield "file_out", await store_media_file(
|
yield "file_out", await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.file_in,
|
file=input_data.file_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.base_64,
|
return_format=return_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import APIKeyCredentials, SchemaField
|
from backend.data.model import APIKeyCredentials, SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
@@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block):
|
|||||||
file: MediaFileType,
|
file: MediaFileType,
|
||||||
filename: str,
|
filename: str,
|
||||||
message_content: str,
|
message_content: str,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
intents.guilds = True
|
intents.guilds = True
|
||||||
@@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block):
|
|||||||
# Local file path - read from stored media file
|
# Local file path - read from stored media file
|
||||||
# This would be a path from a previous block's output
|
# This would be a path from a previous block's output
|
||||||
stored_file = await store_media_file(
|
stored_file = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=file,
|
file=file,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True, # Get as data URI
|
return_format="for_external_api", # Get content to send to Discord
|
||||||
)
|
)
|
||||||
# Now process as data URI
|
# Now process as data URI
|
||||||
header, encoded = stored_file.split(",", 1)
|
header, encoded = stored_file.split(",", 1)
|
||||||
@@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
@@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block):
|
|||||||
file=input_data.file,
|
file=input_data.file,
|
||||||
filename=input_data.filename,
|
filename=input_data.filename,
|
||||||
message_content=input_data.message_content,
|
message_content=input_data.message_content,
|
||||||
graph_exec_id=graph_exec_id,
|
execution_context=execution_context,
|
||||||
user_id=user_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "status", result.get("status", "Unknown error")
|
yield "status", result.get("status", "Unknown error")
|
||||||
|
|||||||
@@ -17,8 +17,11 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import ClientResponseError, Requests
|
from backend.util.request import ClientResponseError, Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block):
|
|||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
|
test_output=[
|
||||||
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("video_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block):
|
|||||||
raise RuntimeError(f"API request failed: {str(e)}")
|
raise RuntimeError(f"API request failed: {str(e)}")
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: FalCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: FalCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
video_url = await self.generate_video(input_data, credentials)
|
video_url = await self.generate_video(input_data, credentials)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
yield "error", error_message
|
yield "error", error_message
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -121,10 +122,12 @@ class AIImageEditorBlock(Block):
|
|||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
test_output=[
|
test_output=[
|
||||||
("output_image", "https://replicate.com/output/edited-image.png"),
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
@@ -134,8 +137,7 @@ class AIImageEditorBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
result = await self.run_model(
|
result = await self.run_model(
|
||||||
@@ -144,20 +146,25 @@ class AIImageEditorBlock(Block):
|
|||||||
prompt=input_data.prompt,
|
prompt=input_data.prompt,
|
||||||
input_image_b64=(
|
input_image_b64=(
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.input_image,
|
file=input_data.input_image,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True,
|
return_format="for_external_api", # Get content for Replicate API
|
||||||
)
|
)
|
||||||
if input_data.input_image
|
if input_data.input_image
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
aspect_ratio=input_data.aspect_ratio.value,
|
aspect_ratio=input_data.aspect_ratio.value,
|
||||||
seed=input_data.seed,
|
seed=input_data.seed,
|
||||||
user_id=user_id,
|
user_id=execution_context.user_id or "",
|
||||||
graph_exec_id=graph_exec_id,
|
graph_exec_id=execution_context.graph_exec_id or "",
|
||||||
)
|
)
|
||||||
yield "output_image", result
|
# Store the generated image to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=result,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "output_image", stored_url
|
||||||
|
|
||||||
async def run_model(
|
async def run_model(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
@@ -95,8 +96,7 @@ def _make_mime_text(
|
|||||||
|
|
||||||
async def create_mime_message(
|
async def create_mime_message(
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||||
|
|
||||||
@@ -117,12 +117,12 @@ async def create_mime_message(
|
|||||||
if input_data.attachments:
|
if input_data.attachments:
|
||||||
for attach in input_data.attachments:
|
for attach in input_data.attachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=attach,
|
file=attach,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
@@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._send_email(
|
result = await self._send_email(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "result", result
|
yield "result", result
|
||||||
|
|
||||||
async def _send_email(
|
async def _send_email(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to or not input_data.subject or not input_data.body:
|
if not input_data.to or not input_data.subject or not input_data.body:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"At least one recipient, subject, and body are required for sending an email"
|
"At least one recipient, subject, and body are required for sending an email"
|
||||||
)
|
)
|
||||||
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
raw_message = await create_mime_message(input_data, execution_context)
|
||||||
sent_message = await asyncio.to_thread(
|
sent_message = await asyncio.to_thread(
|
||||||
lambda: service.users()
|
lambda: service.users()
|
||||||
.messages()
|
.messages()
|
||||||
@@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._create_draft(
|
result = await self._create_draft(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "result", GmailDraftResult(
|
yield "result", GmailDraftResult(
|
||||||
id=result["id"], message_id=result["message"]["id"], status="draft_created"
|
id=result["id"], message_id=result["message"]["id"], status="draft_created"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _create_draft(
|
async def _create_draft(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to or not input_data.subject:
|
if not input_data.to or not input_data.subject:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"At least one recipient and subject are required for creating a draft"
|
"At least one recipient and subject are required for creating a draft"
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
raw_message = await create_mime_message(input_data, execution_context)
|
||||||
draft = await asyncio.to_thread(
|
draft = await asyncio.to_thread(
|
||||||
lambda: service.users()
|
lambda: service.users()
|
||||||
.drafts()
|
.drafts()
|
||||||
@@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase):
|
|||||||
|
|
||||||
|
|
||||||
async def _build_reply_message(
|
async def _build_reply_message(
|
||||||
service, input_data, graph_exec_id: str, user_id: str
|
service, input_data, execution_context: ExecutionContext
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Builds a reply MIME message for Gmail threads.
|
Builds a reply MIME message for Gmail threads.
|
||||||
@@ -1190,12 +1186,12 @@ async def _build_reply_message(
|
|||||||
# Handle attachments
|
# Handle attachments
|
||||||
for attach in input_data.attachments:
|
for attach in input_data.attachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=attach,
|
file=attach,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
@@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
message = await self._reply(
|
message = await self._reply(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "messageId", message["id"]
|
yield "messageId", message["id"]
|
||||||
yield "threadId", message.get("threadId", input_data.threadId)
|
yield "threadId", message.get("threadId", input_data.threadId)
|
||||||
@@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase):
|
|||||||
yield "email", email
|
yield "email", email
|
||||||
|
|
||||||
async def _reply(
|
async def _reply(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Build the reply message using the shared helper
|
# Build the reply message using the shared helper
|
||||||
raw, thread_id = await _build_reply_message(
|
raw, thread_id = await _build_reply_message(
|
||||||
service, input_data, graph_exec_id, user_id
|
service, input_data, execution_context
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send the message
|
# Send the message
|
||||||
@@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
draft = await self._create_draft_reply(
|
draft = await self._create_draft_reply(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "draftId", draft["id"]
|
yield "draftId", draft["id"]
|
||||||
yield "messageId", draft["message"]["id"]
|
yield "messageId", draft["message"]["id"]
|
||||||
@@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase):
|
|||||||
yield "status", "draft_created"
|
yield "status", "draft_created"
|
||||||
|
|
||||||
async def _create_draft_reply(
|
async def _create_draft_reply(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Build the reply message using the shared helper
|
# Build the reply message using the shared helper
|
||||||
raw, thread_id = await _build_reply_message(
|
raw, thread_id = await _build_reply_message(
|
||||||
service, input_data, graph_exec_id, user_id
|
service, input_data, execution_context
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create draft with proper thread association
|
# Create draft with proper thread association
|
||||||
@@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._forward_message(
|
result = await self._forward_message(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "messageId", result["id"]
|
yield "messageId", result["id"]
|
||||||
yield "threadId", result.get("threadId", "")
|
yield "threadId", result.get("threadId", "")
|
||||||
yield "status", "forwarded"
|
yield "status", "forwarded"
|
||||||
|
|
||||||
async def _forward_message(
|
async def _forward_message(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to:
|
if not input_data.to:
|
||||||
raise ValueError("At least one recipient is required for forwarding")
|
raise ValueError("At least one recipient is required for forwarding")
|
||||||
@@ -1727,12 +1717,12 @@ To: {original_to}
|
|||||||
# Add any additional attachments
|
# Add any additional attachments
|
||||||
for attach in input_data.additionalAttachments:
|
for attach in input_data.additionalAttachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=attach,
|
file=attach,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
@@ -116,10 +117,9 @@ class SendWebRequestBlock(Block):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _prepare_files(
|
async def _prepare_files(
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
files_name: str,
|
files_name: str,
|
||||||
files: list[MediaFileType],
|
files: list[MediaFileType],
|
||||||
user_id: str,
|
|
||||||
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
|
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
|
||||||
"""
|
"""
|
||||||
Prepare files for the request by storing them and reading their content.
|
Prepare files for the request by storing them and reading their content.
|
||||||
@@ -127,11 +127,16 @@ class SendWebRequestBlock(Block):
|
|||||||
(files_name, (filename, BytesIO, mime_type))
|
(files_name, (filename, BytesIO, mime_type))
|
||||||
"""
|
"""
|
||||||
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
if graph_exec_id is None:
|
||||||
|
raise ValueError("graph_exec_id is required for file operations")
|
||||||
|
|
||||||
for media in files:
|
for media in files:
|
||||||
# Normalise to a list so we can repeat the same key
|
# Normalise to a list so we can repeat the same key
|
||||||
rel_path = await store_media_file(
|
rel_path = await store_media_file(
|
||||||
graph_exec_id, media, user_id, return_content=False
|
file=media,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, rel_path)
|
abs_path = get_exec_file_path(graph_exec_id, rel_path)
|
||||||
async with aiofiles.open(abs_path, "rb") as f:
|
async with aiofiles.open(abs_path, "rb") as f:
|
||||||
@@ -143,7 +148,7 @@ class SendWebRequestBlock(Block):
|
|||||||
return files_payload
|
return files_payload
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
|
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# ─── Parse/normalise body ────────────────────────────────────
|
# ─── Parse/normalise body ────────────────────────────────────
|
||||||
body = input_data.body
|
body = input_data.body
|
||||||
@@ -174,7 +179,7 @@ class SendWebRequestBlock(Block):
|
|||||||
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
||||||
if use_files:
|
if use_files:
|
||||||
files_payload = await self._prepare_files(
|
files_payload = await self._prepare_files(
|
||||||
graph_exec_id, input_data.files_name, input_data.files, user_id
|
execution_context, input_data.files_name, input_data.files
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enforce body format rules
|
# Enforce body format rules
|
||||||
@@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
credentials: HostScopedCredentials,
|
credentials: HostScopedCredentials,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
||||||
@@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
|||||||
|
|
||||||
# Use parent class run method
|
# Use parent class run method
|
||||||
async for output_name, output_data in super().run(
|
async for output_name, output_data in super().run(
|
||||||
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
|
base_input, execution_context=execution_context, **kwargs
|
||||||
):
|
):
|
||||||
yield output_name, output_data
|
yield output_name, output_data
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.mock import MockObject
|
from backend.util.mock import MockObject
|
||||||
@@ -462,18 +463,21 @@ class AgentFileInputBlock(AgentInputBlock):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
if not input_data.value:
|
if not input_data.value:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Determine return format based on user preference
|
||||||
|
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
|
||||||
|
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
|
||||||
|
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
|
||||||
|
|
||||||
yield "result", await store_media_file(
|
yield "result", await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.value,
|
file=input_data.value,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.base_64,
|
return_format=return_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -115,7 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
|
||||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||||
# AI/ML API models
|
# AI/ML API models
|
||||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||||
@@ -280,9 +279,6 @@ MODEL_METADATA = {
|
|||||||
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
|
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
|
||||||
), # claude-haiku-4-5-20251001
|
), # claude-haiku-4-5-20251001
|
||||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
|
||||||
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
|
|
||||||
), # claude-3-7-sonnet-20250219
|
|
||||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
||||||
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
|
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
|
||||||
), # claude-3-haiku-20240307
|
), # claude-3-haiku-20240307
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Literal, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
from moviepy.video.fx.Loop import Loop
|
from moviepy.video.fx.Loop import Loop
|
||||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
@@ -46,18 +47,19 @@ class MediaDurationBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# 1) Store the input media locally
|
# 1) Store the input media locally
|
||||||
local_media_path = await store_media_file(
|
local_media_path = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.media_in,
|
file=input_data.media_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=False,
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
media_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_media_path
|
||||||
)
|
)
|
||||||
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
|
|
||||||
|
|
||||||
# 2) Load the clip
|
# 2) Load the clip
|
||||||
if input_data.is_video:
|
if input_data.is_video:
|
||||||
@@ -88,10 +90,6 @@ class LoopVideoBlock(Block):
|
|||||||
default=None,
|
default=None,
|
||||||
ge=1,
|
ge=1,
|
||||||
)
|
)
|
||||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
|
||||||
description="How to return the output video. Either a relative path or base64 data URI.",
|
|
||||||
default="file_path",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
video_out: str = SchemaField(
|
video_out: str = SchemaField(
|
||||||
@@ -111,17 +109,19 @@ class LoopVideoBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
node_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
graph_exec_id: str,
|
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
# 1) Store the input video locally
|
# 1) Store the input video locally
|
||||||
local_video_path = await store_media_file(
|
local_video_path = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.video_in,
|
file=input_data.video_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=False,
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||||
|
|
||||||
@@ -149,12 +149,11 @@ class LoopVideoBlock(Block):
|
|||||||
looped_clip = looped_clip.with_audio(clip.audio)
|
looped_clip = looped_clip.with_audio(clip.audio)
|
||||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||||
|
|
||||||
# Return as data URI
|
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
video_out = await store_media_file(
|
video_out = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=output_filename,
|
file=output_filename,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.output_return_type == "data_uri",
|
return_format="for_block_output",
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "video_out", video_out
|
yield "video_out", video_out
|
||||||
@@ -177,10 +176,6 @@ class AddAudioToVideoBlock(Block):
|
|||||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||||
default=1.0,
|
default=1.0,
|
||||||
)
|
)
|
||||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
|
||||||
description="Return the final output as a relative path or base64 data URI.",
|
|
||||||
default="file_path",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
video_out: MediaFileType = SchemaField(
|
video_out: MediaFileType = SchemaField(
|
||||||
@@ -200,23 +195,24 @@ class AddAudioToVideoBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
node_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
graph_exec_id: str,
|
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
# 1) Store the inputs locally
|
# 1) Store the inputs locally
|
||||||
local_video_path = await store_media_file(
|
local_video_path = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.video_in,
|
file=input_data.video_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=False,
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
local_audio_path = await store_media_file(
|
local_audio_path = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.audio_in,
|
file=input_data.audio_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=False,
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||||
@@ -240,12 +236,11 @@ class AddAudioToVideoBlock(Block):
|
|||||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||||
|
|
||||||
# 5) Return either path or data URI
|
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
video_out = await store_media_file(
|
video_out = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=output_filename,
|
file=output_filename,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.output_return_type == "data_uri",
|
return_format="for_block_output",
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "video_out", video_out
|
yield "video_out", video_out
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def take_screenshot(
|
async def take_screenshot(
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
url: str,
|
url: str,
|
||||||
viewport_width: int,
|
viewport_width: int,
|
||||||
viewport_height: int,
|
viewport_height: int,
|
||||||
@@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block):
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"image": await store_media_file(
|
"image": await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=MediaFileType(
|
file=MediaFileType(
|
||||||
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
|
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
|
||||||
),
|
),
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True,
|
return_format="for_block_output",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
screenshot_data = await self.take_screenshot(
|
screenshot_data = await self.take_screenshot(
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
graph_exec_id=graph_exec_id,
|
execution_context=execution_context,
|
||||||
user_id=user_id,
|
|
||||||
url=input_data.url,
|
url=input_data.url,
|
||||||
viewport_width=input_data.viewport_width,
|
viewport_width=input_data.viewport_width,
|
||||||
viewport_height=input_data.viewport_height,
|
viewport_height=input_data.viewport_height,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import ContributorDetails, SchemaField
|
from backend.data.model import ContributorDetails, SchemaField
|
||||||
from backend.util.file import get_exec_file_path, store_media_file
|
from backend.util.file import get_exec_file_path, store_media_file
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
@@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
|
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
import csv
|
import csv
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
@@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block):
|
|||||||
# Determine data source - prefer file_input if provided, otherwise use contents
|
# Determine data source - prefer file_input if provided, otherwise use contents
|
||||||
if input_data.file_input:
|
if input_data.file_input:
|
||||||
stored_file_path = await store_media_file(
|
stored_file_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.file_input,
|
file=input_data.file_input,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get full file path
|
# Get full file path
|
||||||
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
file_path = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, stored_file_path
|
||||||
|
)
|
||||||
if not Path(file_path).exists():
|
if not Path(file_path).exists():
|
||||||
raise ValueError(f"File does not exist: {file_path}")
|
raise ValueError(f"File does not exist: {file_path}")
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
|
|||||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||||
|
|
||||||
# Anthropic
|
# Anthropic
|
||||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_name(self) -> str:
|
def provider_name(self) -> str:
|
||||||
@@ -137,7 +137,7 @@ class StagehandObserveBlock(Block):
|
|||||||
model: StagehandRecommendedLlmModel = SchemaField(
|
model: StagehandRecommendedLlmModel = SchemaField(
|
||||||
title="LLM Model",
|
title="LLM Model",
|
||||||
description="LLM to use for Stagehand (provider is inferred)",
|
description="LLM to use for Stagehand (provider is inferred)",
|
||||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||||
advanced=False,
|
advanced=False,
|
||||||
)
|
)
|
||||||
model_credentials: AICredentials = AICredentialsField()
|
model_credentials: AICredentials = AICredentialsField()
|
||||||
@@ -230,7 +230,7 @@ class StagehandActBlock(Block):
|
|||||||
model: StagehandRecommendedLlmModel = SchemaField(
|
model: StagehandRecommendedLlmModel = SchemaField(
|
||||||
title="LLM Model",
|
title="LLM Model",
|
||||||
description="LLM to use for Stagehand (provider is inferred)",
|
description="LLM to use for Stagehand (provider is inferred)",
|
||||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||||
advanced=False,
|
advanced=False,
|
||||||
)
|
)
|
||||||
model_credentials: AICredentials = AICredentialsField()
|
model_credentials: AICredentials = AICredentialsField()
|
||||||
@@ -330,7 +330,7 @@ class StagehandExtractBlock(Block):
|
|||||||
model: StagehandRecommendedLlmModel = SchemaField(
|
model: StagehandRecommendedLlmModel = SchemaField(
|
||||||
title="LLM Model",
|
title="LLM Model",
|
||||||
description="LLM to use for Stagehand (provider is inferred)",
|
description="LLM to use for Stagehand (provider is inferred)",
|
||||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||||
advanced=False,
|
advanced=False,
|
||||||
)
|
)
|
||||||
model_credentials: AICredentials = AICredentialsField()
|
model_credentials: AICredentials = AICredentialsField()
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -17,7 +18,9 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
@@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
test_output=[
|
test_output=[
|
||||||
(
|
(
|
||||||
"video_url",
|
"video_url",
|
||||||
"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
@@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
|
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
|
||||||
"status": "created",
|
"status": "created",
|
||||||
},
|
},
|
||||||
|
# Use data URI to avoid HTTP requests during tests
|
||||||
"get_clip_status": lambda *args, **kwargs: {
|
"get_clip_status": lambda *args, **kwargs: {
|
||||||
"status": "done",
|
"status": "done",
|
||||||
"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
|
"result_url": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
@@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create the clip
|
# Create the clip
|
||||||
payload = {
|
payload = {
|
||||||
@@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
for _ in range(input_data.max_polling_attempts):
|
for _ in range(input_data.max_polling_attempts):
|
||||||
status_response = await self.get_clip_status(credentials.api_key, clip_id)
|
status_response = await self.get_clip_status(credentials.api_key, clip_id)
|
||||||
if status_response["status"] == "done":
|
if status_response["status"] == "done":
|
||||||
yield "video_url", status_response["result_url"]
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
video_url = status_response["result_url"]
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
return
|
return
|
||||||
elif status_response["status"] == "error":
|
elif status_response["status"] == "error":
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock
|
|||||||
from backend.blocks.llm import AITextSummarizerBlock
|
from backend.blocks.llm import AITextSummarizerBlock
|
||||||
from backend.blocks.text import ExtractTextInformationBlock
|
from backend.blocks.text import ExtractTextInformationBlock
|
||||||
from backend.blocks.xml_parser import XMLParserBlock
|
from backend.blocks.xml_parser import XMLParserBlock
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
@@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity:
|
|||||||
|
|
||||||
with pytest.raises(ValueError, match="File too large"):
|
with pytest.raises(ValueError, match="File too large"):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id="test",
|
|
||||||
file=MediaFileType(large_data_uri),
|
file=MediaFileType(large_data_uri),
|
||||||
|
execution_context=ExecutionContext(
|
||||||
user_id="test_user",
|
user_id="test_user",
|
||||||
|
graph_exec_id="test",
|
||||||
|
),
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("backend.util.file.Path")
|
@patch("backend.util.file.Path")
|
||||||
@@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity:
|
|||||||
# Should raise an error when directory size exceeds limit
|
# Should raise an error when directory size exceeds limit
|
||||||
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
|
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id="test",
|
|
||||||
file=MediaFileType(
|
file=MediaFileType(
|
||||||
"data:text/plain;base64,dGVzdA=="
|
"data:text/plain;base64,dGVzdA=="
|
||||||
), # Small test file
|
), # Small test file
|
||||||
|
execution_context=ExecutionContext(
|
||||||
user_id="test_user",
|
user_id="test_user",
|
||||||
|
graph_exec_id="test",
|
||||||
|
),
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,10 +11,22 @@ from backend.blocks.http import (
|
|||||||
HttpMethod,
|
HttpMethod,
|
||||||
SendAuthenticatedWebRequestBlock,
|
SendAuthenticatedWebRequestBlock,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import HostScopedCredentials
|
from backend.data.model import HostScopedCredentials
|
||||||
from backend.util.request import Response
|
from backend.util.request import Response
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_context(
|
||||||
|
graph_exec_id: str = "test-exec-id",
|
||||||
|
user_id: str = "test-user-id",
|
||||||
|
) -> ExecutionContext:
|
||||||
|
"""Helper to create test ExecutionContext."""
|
||||||
|
return ExecutionContext(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestHttpBlockWithHostScopedCredentials:
|
class TestHttpBlockWithHostScopedCredentials:
|
||||||
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
||||||
|
|
||||||
@@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=exact_match_credentials,
|
credentials=exact_match_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=wildcard_credentials,
|
credentials=wildcard_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=non_matching_credentials,
|
credentials=non_matching_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=exact_match_credentials,
|
credentials=exact_match_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=auto_discovered_creds, # Execution manager found these
|
credentials=auto_discovered_creds, # Execution manager found these
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=multi_header_creds,
|
credentials=multi_header_creds,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=test_creds,
|
credentials=test_creds,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util import json, text
|
from backend.util import json, text
|
||||||
from backend.util.file import get_exec_file_path, store_media_file
|
from backend.util.file import get_exec_file_path, store_media_file
|
||||||
@@ -444,18 +445,21 @@ class FileReadBlock(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
|
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Store the media file properly (handles URLs, data URIs, etc.)
|
# Store the media file properly (handles URLs, data URIs, etc.)
|
||||||
stored_file_path = await store_media_file(
|
stored_file_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.file_input,
|
file=input_data.file_input,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get full file path
|
# Get full file path (graph_exec_id validated by store_media_file above)
|
||||||
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
|
if not execution_context.graph_exec_id:
|
||||||
|
raise ValueError("execution_context.graph_exec_id is required")
|
||||||
|
file_path = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, stored_file_path
|
||||||
|
)
|
||||||
|
|
||||||
if not Path(file_path).exists():
|
if not Path(file_path).exists():
|
||||||
raise ValueError(f"File does not exist: {file_path}")
|
raise ValueError(f"File does not exist: {file_path}")
|
||||||
|
|||||||
@@ -81,7 +81,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
|||||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
|
||||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
||||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
||||||
|
|||||||
@@ -83,12 +83,29 @@ class ExecutionContext(BaseModel):
|
|||||||
|
|
||||||
model_config = {"extra": "ignore"}
|
model_config = {"extra": "ignore"}
|
||||||
|
|
||||||
|
# Execution identity
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
graph_id: Optional[str] = None
|
||||||
|
graph_exec_id: Optional[str] = None
|
||||||
|
graph_version: Optional[int] = None
|
||||||
|
node_id: Optional[str] = None
|
||||||
|
node_exec_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Safety settings
|
||||||
human_in_the_loop_safe_mode: bool = True
|
human_in_the_loop_safe_mode: bool = True
|
||||||
sensitive_action_safe_mode: bool = False
|
sensitive_action_safe_mode: bool = False
|
||||||
|
|
||||||
|
# User settings
|
||||||
user_timezone: str = "UTC"
|
user_timezone: str = "UTC"
|
||||||
|
|
||||||
|
# Execution hierarchy
|
||||||
root_execution_id: Optional[str] = None
|
root_execution_id: Optional[str] = None
|
||||||
parent_execution_id: Optional[str] = None
|
parent_execution_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Workspace
|
||||||
|
workspace_id: Optional[str] = None
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# -------------------------- Models -------------------------- #
|
# -------------------------- Models -------------------------- #
|
||||||
|
|
||||||
|
|||||||
@@ -666,10 +666,16 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
if not (self.discriminator and self.discriminator_mapping):
|
if not (self.discriminator and self.discriminator_mapping):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider = self.discriminator_mapping[discriminator_value]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{discriminator_value}' is not supported. "
|
||||||
|
"It may have been deprecated. Please update your agent configuration."
|
||||||
|
)
|
||||||
|
|
||||||
return CredentialsFieldInfo(
|
return CredentialsFieldInfo(
|
||||||
credentials_provider=frozenset(
|
credentials_provider=frozenset([provider]),
|
||||||
[self.discriminator_mapping[discriminator_value]]
|
|
||||||
),
|
|
||||||
credentials_types=self.supported_types,
|
credentials_types=self.supported_types,
|
||||||
credentials_scopes=self.required_scopes,
|
credentials_scopes=self.required_scopes,
|
||||||
discriminator=self.discriminator,
|
discriminator=self.discriminator,
|
||||||
|
|||||||
276
autogpt_platform/backend/backend/data/workspace.py
Normal file
276
autogpt_platform/backend/backend/data/workspace.py
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
"""
|
||||||
|
Database CRUD operations for User Workspace.
|
||||||
|
|
||||||
|
This module provides functions for managing user workspaces and workspace files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from prisma.models import UserWorkspace, UserWorkspaceFile
|
||||||
|
from prisma.types import UserWorkspaceFileWhereInput
|
||||||
|
|
||||||
|
from backend.util.json import SafeJson
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
|
||||||
|
"""
|
||||||
|
Get user's workspace, creating one if it doesn't exist.
|
||||||
|
|
||||||
|
Uses upsert to handle race conditions when multiple concurrent requests
|
||||||
|
attempt to create a workspace for the same user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspace instance
|
||||||
|
"""
|
||||||
|
workspace = await UserWorkspace.prisma().upsert(
|
||||||
|
where={"userId": user_id},
|
||||||
|
data={
|
||||||
|
"create": {"userId": user_id},
|
||||||
|
"update": {}, # No updates needed if exists
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return workspace
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
|
||||||
|
"""
|
||||||
|
Get user's workspace if it exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspace instance or None
|
||||||
|
"""
|
||||||
|
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def create_workspace_file(
|
||||||
|
workspace_id: str,
|
||||||
|
file_id: str,
|
||||||
|
name: str,
|
||||||
|
path: str,
|
||||||
|
storage_path: str,
|
||||||
|
mime_type: str,
|
||||||
|
size_bytes: int,
|
||||||
|
checksum: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
) -> UserWorkspaceFile:
|
||||||
|
"""
|
||||||
|
Create a new workspace file record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
file_id: The file ID (same as used in storage path for consistency)
|
||||||
|
name: User-visible filename
|
||||||
|
path: Virtual path (e.g., "/documents/report.pdf")
|
||||||
|
storage_path: Actual storage path (GCS or local)
|
||||||
|
mime_type: MIME type of the file
|
||||||
|
size_bytes: File size in bytes
|
||||||
|
checksum: Optional SHA256 checksum
|
||||||
|
metadata: Optional additional metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created UserWorkspaceFile instance
|
||||||
|
"""
|
||||||
|
# Normalize path to start with /
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
|
||||||
|
file = await UserWorkspaceFile.prisma().create(
|
||||||
|
data={
|
||||||
|
"id": file_id,
|
||||||
|
"workspaceId": workspace_id,
|
||||||
|
"name": name,
|
||||||
|
"path": path,
|
||||||
|
"storagePath": storage_path,
|
||||||
|
"mimeType": mime_type,
|
||||||
|
"sizeBytes": size_bytes,
|
||||||
|
"checksum": checksum,
|
||||||
|
"metadata": SafeJson(metadata or {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created workspace file {file.id} at path {path} "
|
||||||
|
f"in workspace {workspace_id}"
|
||||||
|
)
|
||||||
|
return file
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_file(
|
||||||
|
file_id: str,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get a workspace file by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file ID
|
||||||
|
workspace_id: Optional workspace ID for validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
where_clause: dict = {"id": file_id, "isDeleted": False}
|
||||||
|
if workspace_id:
|
||||||
|
where_clause["workspaceId"] = workspace_id
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_file_by_path(
|
||||||
|
workspace_id: str,
|
||||||
|
path: str,
|
||||||
|
) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get a workspace file by its virtual path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
path: Virtual path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
# Normalize path
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().find_first(
|
||||||
|
where={
|
||||||
|
"workspaceId": workspace_id,
|
||||||
|
"path": path,
|
||||||
|
"isDeleted": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_workspace_files(
|
||||||
|
workspace_id: str,
|
||||||
|
path_prefix: Optional[str] = None,
|
||||||
|
include_deleted: bool = False,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
List files in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
path_prefix: Optional path prefix to filter (e.g., "/documents/")
|
||||||
|
include_deleted: Whether to include soft-deleted files
|
||||||
|
limit: Maximum number of files to return
|
||||||
|
offset: Number of files to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of UserWorkspaceFile instances
|
||||||
|
"""
|
||||||
|
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
|
||||||
|
|
||||||
|
if not include_deleted:
|
||||||
|
where_clause["isDeleted"] = False
|
||||||
|
|
||||||
|
if path_prefix:
|
||||||
|
# Normalize prefix
|
||||||
|
if not path_prefix.startswith("/"):
|
||||||
|
path_prefix = f"/{path_prefix}"
|
||||||
|
where_clause["path"] = {"startswith": path_prefix}
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().find_many(
|
||||||
|
where=where_clause,
|
||||||
|
order={"createdAt": "desc"},
|
||||||
|
take=limit,
|
||||||
|
skip=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def count_workspace_files(
|
||||||
|
workspace_id: str,
|
||||||
|
path_prefix: Optional[str] = None,
|
||||||
|
include_deleted: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count files in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
|
||||||
|
include_deleted: Whether to include soft-deleted files
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files
|
||||||
|
"""
|
||||||
|
where_clause: dict = {"workspaceId": workspace_id}
|
||||||
|
if not include_deleted:
|
||||||
|
where_clause["isDeleted"] = False
|
||||||
|
|
||||||
|
if path_prefix:
|
||||||
|
# Normalize prefix
|
||||||
|
if not path_prefix.startswith("/"):
|
||||||
|
path_prefix = f"/{path_prefix}"
|
||||||
|
where_clause["path"] = {"startswith": path_prefix}
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().count(where=where_clause)
|
||||||
|
|
||||||
|
|
||||||
|
async def soft_delete_workspace_file(
|
||||||
|
file_id: str,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Soft-delete a workspace file.
|
||||||
|
|
||||||
|
The path is modified to include a deletion timestamp to free up the original
|
||||||
|
path for new files while preserving the record for potential recovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file ID
|
||||||
|
workspace_id: Optional workspace ID for validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated UserWorkspaceFile instance or None if not found
|
||||||
|
"""
|
||||||
|
# First verify the file exists and belongs to workspace
|
||||||
|
file = await get_workspace_file(file_id, workspace_id)
|
||||||
|
if file is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
deleted_at = datetime.now(timezone.utc)
|
||||||
|
# Modify path to free up the unique constraint for new files at original path
|
||||||
|
# Format: {original_path}__deleted__{timestamp}
|
||||||
|
deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
|
||||||
|
|
||||||
|
updated = await UserWorkspaceFile.prisma().update(
|
||||||
|
where={"id": file_id},
|
||||||
|
data={
|
||||||
|
"isDeleted": True,
|
||||||
|
"deletedAt": deleted_at,
|
||||||
|
"path": deleted_path,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Soft-deleted workspace file {file_id}")
|
||||||
|
return updated
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_total_size(workspace_id: str) -> int:
|
||||||
|
"""
|
||||||
|
Get the total size of all files in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total size in bytes
|
||||||
|
"""
|
||||||
|
files = await list_workspace_files(workspace_id)
|
||||||
|
return sum(file.sizeBytes for file in files)
|
||||||
@@ -236,7 +236,14 @@ async def execute_node(
|
|||||||
input_size = len(input_data_str)
|
input_size = len(input_data_str)
|
||||||
log_metadata.debug("Executed node with input", input=input_data_str)
|
log_metadata.debug("Executed node with input", input=input_data_str)
|
||||||
|
|
||||||
|
# Create node-specific execution context to avoid race conditions
|
||||||
|
# (multiple nodes can execute concurrently and would otherwise mutate shared state)
|
||||||
|
execution_context = execution_context.model_copy(
|
||||||
|
update={"node_id": node_id, "node_exec_id": node_exec_id}
|
||||||
|
)
|
||||||
|
|
||||||
# Inject extra execution arguments for the blocks via kwargs
|
# Inject extra execution arguments for the blocks via kwargs
|
||||||
|
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||||
extra_exec_kwargs: dict = {
|
extra_exec_kwargs: dict = {
|
||||||
"graph_id": graph_id,
|
"graph_id": graph_id,
|
||||||
"graph_version": graph_version,
|
"graph_version": graph_version,
|
||||||
|
|||||||
@@ -892,11 +892,19 @@ async def add_graph_execution(
|
|||||||
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
||||||
|
|
||||||
execution_context = ExecutionContext(
|
execution_context = ExecutionContext(
|
||||||
|
# Execution identity
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_exec_id=graph_exec.id,
|
||||||
|
graph_version=graph_exec.graph_version,
|
||||||
|
# Safety settings
|
||||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||||
|
# User settings
|
||||||
user_timezone=(
|
user_timezone=(
|
||||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||||
),
|
),
|
||||||
|
# Execution hierarchy
|
||||||
root_execution_id=graph_exec.id,
|
root_execution_id=graph_exec.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -348,6 +348,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
|||||||
mock_graph_exec.id = "execution-id-123"
|
mock_graph_exec.id = "execution-id-123"
|
||||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||||
|
mock_graph_exec.graph_version = graph_version
|
||||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||||
|
|
||||||
# Mock the queue and event bus
|
# Mock the queue and event bus
|
||||||
@@ -434,6 +435,9 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
|||||||
# Create a second mock execution for the sanity check
|
# Create a second mock execution for the sanity check
|
||||||
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||||
mock_graph_exec_2.id = "execution-id-456"
|
mock_graph_exec_2.id = "execution-id-456"
|
||||||
|
mock_graph_exec_2.node_executions = []
|
||||||
|
mock_graph_exec_2.status = ExecutionStatus.QUEUED
|
||||||
|
mock_graph_exec_2.graph_version = graph_version
|
||||||
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
|
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||||
|
|
||||||
# Reset mocks and set up for second call
|
# Reset mocks and set up for second call
|
||||||
@@ -614,6 +618,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
|||||||
mock_graph_exec.id = "execution-id-123"
|
mock_graph_exec.id = "execution-id-123"
|
||||||
mock_graph_exec.node_executions = []
|
mock_graph_exec.node_executions = []
|
||||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||||
|
mock_graph_exec.graph_version = graph_version
|
||||||
|
|
||||||
# Track what's passed to to_graph_execution_entry
|
# Track what's passed to to_graph_execution_entry
|
||||||
captured_kwargs = {}
|
captured_kwargs = {}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import aiohttp
|
|||||||
from gcloud.aio import storage as async_gcs_storage
|
from gcloud.aio import storage as async_gcs_storage
|
||||||
from google.cloud import storage as gcs_storage
|
from google.cloud import storage as gcs_storage
|
||||||
|
|
||||||
|
from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -251,7 +252,7 @@ class CloudStorageHandler:
|
|||||||
f"in_task: {current_task is not None}"
|
f"in_task: {current_task is not None}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse bucket and blob name from path
|
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
|
||||||
parts = path.split("/", 1)
|
parts = path.split("/", 1)
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
raise ValueError(f"Invalid GCS path: {path}")
|
raise ValueError(f"Invalid GCS path: {path}")
|
||||||
@@ -261,50 +262,19 @@ class CloudStorageHandler:
|
|||||||
# Authorization check
|
# Authorization check
|
||||||
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
||||||
|
|
||||||
# Use a fresh client for each download to avoid session issues
|
|
||||||
# This is less efficient but more reliable with the executor's event loop
|
|
||||||
logger.info("[CloudStorage] Creating fresh GCS client for download")
|
|
||||||
|
|
||||||
# Create a new session specifically for this download
|
|
||||||
session = aiohttp.ClientSession(
|
|
||||||
connector=aiohttp.TCPConnector(limit=10, force_close=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
async_client = None
|
|
||||||
try:
|
|
||||||
# Create a new GCS client with the fresh session
|
|
||||||
async_client = async_gcs_storage.Storage(session=session)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
|
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Download content using the fresh client
|
try:
|
||||||
content = await async_client.download(bucket_name, blob_name)
|
content = await download_with_fresh_session(bucket_name, blob_name)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
|
f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clean up
|
|
||||||
await async_client.close()
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Always try to clean up
|
|
||||||
if async_client is not None:
|
|
||||||
try:
|
|
||||||
await async_client.close()
|
|
||||||
except Exception as cleanup_error:
|
|
||||||
logger.warning(
|
|
||||||
f"[CloudStorage] Error closing GCS client: {cleanup_error}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await session.close()
|
|
||||||
except Exception as cleanup_error:
|
|
||||||
logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}")
|
|
||||||
|
|
||||||
# Log the specific error for debugging
|
# Log the specific error for debugging
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[CloudStorage] GCS download failed - error: {str(e)}, "
|
f"[CloudStorage] GCS download failed - error: {str(e)}, "
|
||||||
@@ -319,10 +289,6 @@ class CloudStorageHandler:
|
|||||||
f"current_task: {current_task}, "
|
f"current_task: {current_task}, "
|
||||||
f"bucket: {bucket_name}, blob: redacted for privacy"
|
f"bucket: {bucket_name}, blob: redacted for privacy"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert gcloud-aio exceptions to standard ones
|
|
||||||
if "404" in str(e) or "Not Found" in str(e):
|
|
||||||
raise FileNotFoundError(f"File not found: gcs://{path}")
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _validate_file_access(
|
def _validate_file_access(
|
||||||
@@ -445,8 +411,7 @@ class CloudStorageHandler:
|
|||||||
graph_exec_id: str | None = None,
|
graph_exec_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate signed URL for GCS with authorization."""
|
"""Generate signed URL for GCS with authorization."""
|
||||||
|
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
|
||||||
# Parse bucket and blob name from path
|
|
||||||
parts = path.split("/", 1)
|
parts = path.split("/", 1)
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
raise ValueError(f"Invalid GCS path: {path}")
|
raise ValueError(f"Invalid GCS path: {path}")
|
||||||
@@ -456,21 +421,11 @@ class CloudStorageHandler:
|
|||||||
# Authorization check
|
# Authorization check
|
||||||
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
||||||
|
|
||||||
# Use sync client for signed URLs since gcloud-aio doesn't support them
|
|
||||||
sync_client = self._get_sync_gcs_client()
|
sync_client = self._get_sync_gcs_client()
|
||||||
bucket = sync_client.bucket(bucket_name)
|
return await generate_signed_url(
|
||||||
blob = bucket.blob(blob_name)
|
sync_client, bucket_name, blob_name, expiration_hours * 3600
|
||||||
|
|
||||||
# Generate signed URL asynchronously using sync client
|
|
||||||
url = await asyncio.to_thread(
|
|
||||||
blob.generate_signed_url,
|
|
||||||
version="v4",
|
|
||||||
expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours),
|
|
||||||
method="GET",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return url
|
|
||||||
|
|
||||||
async def delete_expired_files(self, provider: str = "gcs") -> int:
|
async def delete_expired_files(self, provider: str = "gcs") -> int:
|
||||||
"""
|
"""
|
||||||
Delete files that have passed their expiration time.
|
Delete files that have passed their expiration time.
|
||||||
|
|||||||
@@ -5,13 +5,26 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Literal
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
from backend.util.settings import Config
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
from backend.util.virus_scanner import scan_content_safe
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
|
||||||
|
# Return format options for store_media_file
|
||||||
|
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||||
|
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
||||||
|
# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
|
||||||
|
MediaReturnFormat = Literal[
|
||||||
|
"for_local_processing", "for_external_api", "for_block_output"
|
||||||
|
]
|
||||||
|
|
||||||
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
|
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
|
||||||
|
|
||||||
# Maximum filename length (conservative limit for most filesystems)
|
# Maximum filename length (conservative limit for most filesystems)
|
||||||
@@ -67,42 +80,56 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def store_media_file(
|
async def store_media_file(
|
||||||
graph_exec_id: str,
|
|
||||||
file: MediaFileType,
|
file: MediaFileType,
|
||||||
user_id: str,
|
execution_context: "ExecutionContext",
|
||||||
return_content: bool = False,
|
*,
|
||||||
|
return_format: MediaReturnFormat,
|
||||||
) -> MediaFileType:
|
) -> MediaFileType:
|
||||||
"""
|
"""
|
||||||
Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
|
Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path
|
||||||
placing or verifying it under:
|
relative to {temp}/exec_file/{exec_id}), placing or verifying it under:
|
||||||
{tempdir}/exec_file/{exec_id}/...
|
{tempdir}/exec_file/{exec_id}/...
|
||||||
|
|
||||||
If 'return_content=True', return a data URI (data:<mime>;base64,<content>).
|
For each MediaFileType input:
|
||||||
Otherwise, returns the file media path relative to the exec_id folder.
|
- Data URI: decode and store locally
|
||||||
|
- URL: download and store locally
|
||||||
|
- workspace:// reference: read from workspace, store locally
|
||||||
|
- Local path: verify it exists in exec_file directory
|
||||||
|
|
||||||
For each MediaFileType type:
|
Return format options:
|
||||||
- Data URI:
|
- "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||||
-> decode and store in a new random file in that folder
|
- "for_external_api": Returns data URI (base64) - use when sending to external APIs
|
||||||
- URL:
|
- "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
|
||||||
-> download and store in that folder
|
|
||||||
- Local path:
|
|
||||||
-> interpret as relative to that folder; verify it exists
|
|
||||||
(no copying, as it's presumably already there).
|
|
||||||
We realpath-check so no symlink or '..' can escape the folder.
|
|
||||||
|
|
||||||
|
:param file: Data URI, URL, workspace://, or local (relative) path.
|
||||||
:param graph_exec_id: The unique ID of the graph execution.
|
:param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id.
|
||||||
:param file: Data URI, URL, or local (relative) path.
|
:param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output".
|
||||||
:param return_content: If True, return a data URI of the file content.
|
:return: The requested result based on return_format.
|
||||||
If False, return the *relative* path inside the exec_id folder.
|
|
||||||
:return: The requested result: data URI or relative path of the media.
|
|
||||||
"""
|
"""
|
||||||
|
# Extract values from execution_context
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
user_id = execution_context.user_id
|
||||||
|
|
||||||
|
if not graph_exec_id:
|
||||||
|
raise ValueError("execution_context.graph_exec_id is required")
|
||||||
|
if not user_id:
|
||||||
|
raise ValueError("execution_context.user_id is required")
|
||||||
|
|
||||||
|
# Create workspace_manager if we have workspace_id (with session scoping)
|
||||||
|
# Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py)
|
||||||
|
from backend.util.workspace import WorkspaceManager
|
||||||
|
|
||||||
|
workspace_manager: WorkspaceManager | None = None
|
||||||
|
if execution_context.workspace_id:
|
||||||
|
workspace_manager = WorkspaceManager(
|
||||||
|
user_id, execution_context.workspace_id, execution_context.session_id
|
||||||
|
)
|
||||||
# Build base path
|
# Build base path
|
||||||
base_path = Path(get_exec_file_path(graph_exec_id, ""))
|
base_path = Path(get_exec_file_path(graph_exec_id, ""))
|
||||||
base_path.mkdir(parents=True, exist_ok=True)
|
base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Security fix: Add disk space limits to prevent DoS
|
# Security fix: Add disk space limits to prevent DoS
|
||||||
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB per file
|
MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024
|
||||||
MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory
|
MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory
|
||||||
|
|
||||||
# Check total disk usage in base_path
|
# Check total disk usage in base_path
|
||||||
@@ -142,9 +169,57 @@ async def store_media_file(
|
|||||||
"""
|
"""
|
||||||
return str(absolute_path.relative_to(base))
|
return str(absolute_path.relative_to(base))
|
||||||
|
|
||||||
# Check if this is a cloud storage path
|
# Get cloud storage handler for checking cloud paths
|
||||||
cloud_storage = await get_cloud_storage_handler()
|
cloud_storage = await get_cloud_storage_handler()
|
||||||
if cloud_storage.is_cloud_path(file):
|
|
||||||
|
# Track if the input came from workspace (don't re-save it)
|
||||||
|
is_from_workspace = file.startswith("workspace://")
|
||||||
|
|
||||||
|
# Check if this is a workspace file reference
|
||||||
|
if is_from_workspace:
|
||||||
|
if workspace_manager is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Workspace file reference requires workspace context. "
|
||||||
|
"This file type is only available in CoPilot sessions."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse workspace reference
|
||||||
|
# workspace://abc123 - by file ID
|
||||||
|
# workspace:///path/to/file.txt - by virtual path
|
||||||
|
file_ref = file[12:] # Remove "workspace://"
|
||||||
|
|
||||||
|
if file_ref.startswith("/"):
|
||||||
|
# Path reference
|
||||||
|
workspace_content = await workspace_manager.read_file(file_ref)
|
||||||
|
file_info = await workspace_manager.get_file_info_by_path(file_ref)
|
||||||
|
filename = sanitize_filename(
|
||||||
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# ID reference
|
||||||
|
workspace_content = await workspace_manager.read_file_by_id(file_ref)
|
||||||
|
file_info = await workspace_manager.get_file_info(file_ref)
|
||||||
|
filename = sanitize_filename(
|
||||||
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||||
|
except OSError as e:
|
||||||
|
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||||
|
|
||||||
|
# Check file size limit
|
||||||
|
if len(workspace_content) > MAX_FILE_SIZE_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Virus scan the workspace content before writing locally
|
||||||
|
await scan_content_safe(workspace_content, filename=filename)
|
||||||
|
target_path.write_bytes(workspace_content)
|
||||||
|
|
||||||
|
# Check if this is a cloud storage path
|
||||||
|
elif cloud_storage.is_cloud_path(file):
|
||||||
# Download from cloud storage and store locally
|
# Download from cloud storage and store locally
|
||||||
cloud_content = await cloud_storage.retrieve_file(
|
cloud_content = await cloud_storage.retrieve_file(
|
||||||
file, user_id=user_id, graph_exec_id=graph_exec_id
|
file, user_id=user_id, graph_exec_id=graph_exec_id
|
||||||
@@ -159,9 +234,9 @@ async def store_media_file(
|
|||||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||||
|
|
||||||
# Check file size limit
|
# Check file size limit
|
||||||
if len(cloud_content) > MAX_FILE_SIZE:
|
if len(cloud_content) > MAX_FILE_SIZE_BYTES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
|
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Virus scan the cloud content before writing locally
|
# Virus scan the cloud content before writing locally
|
||||||
@@ -189,9 +264,9 @@ async def store_media_file(
|
|||||||
content = base64.b64decode(b64_content)
|
content = base64.b64decode(b64_content)
|
||||||
|
|
||||||
# Check file size limit
|
# Check file size limit
|
||||||
if len(content) > MAX_FILE_SIZE:
|
if len(content) > MAX_FILE_SIZE_BYTES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
|
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Virus scan the base64 content before writing
|
# Virus scan the base64 content before writing
|
||||||
@@ -199,23 +274,31 @@ async def store_media_file(
|
|||||||
target_path.write_bytes(content)
|
target_path.write_bytes(content)
|
||||||
|
|
||||||
elif file.startswith(("http://", "https://")):
|
elif file.startswith(("http://", "https://")):
|
||||||
# URL
|
# URL - download first to get Content-Type header
|
||||||
|
resp = await Requests().get(file)
|
||||||
|
|
||||||
|
# Check file size limit
|
||||||
|
if len(resp.content) > MAX_FILE_SIZE_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract filename from URL path
|
||||||
parsed_url = urlparse(file)
|
parsed_url = urlparse(file)
|
||||||
filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
|
filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
|
||||||
|
|
||||||
|
# If filename lacks extension, add one from Content-Type header
|
||||||
|
if "." not in filename:
|
||||||
|
content_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
|
||||||
|
if content_type:
|
||||||
|
ext = _extension_from_mime(content_type)
|
||||||
|
filename = f"{filename}{ext}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||||
|
|
||||||
# Download and save
|
|
||||||
resp = await Requests().get(file)
|
|
||||||
|
|
||||||
# Check file size limit
|
|
||||||
if len(resp.content) > MAX_FILE_SIZE:
|
|
||||||
raise ValueError(
|
|
||||||
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Virus scan the downloaded content before writing
|
# Virus scan the downloaded content before writing
|
||||||
await scan_content_safe(resp.content, filename=filename)
|
await scan_content_safe(resp.content, filename=filename)
|
||||||
target_path.write_bytes(resp.content)
|
target_path.write_bytes(resp.content)
|
||||||
@@ -230,12 +313,44 @@ async def store_media_file(
|
|||||||
if not target_path.is_file():
|
if not target_path.is_file():
|
||||||
raise ValueError(f"Local file does not exist: {target_path}")
|
raise ValueError(f"Local file does not exist: {target_path}")
|
||||||
|
|
||||||
# Return result
|
# Return based on requested format
|
||||||
if return_content:
|
if return_format == "for_local_processing":
|
||||||
return MediaFileType(_file_to_data_uri(target_path))
|
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
||||||
else:
|
# Returns: relative path in exec_file directory (e.g., "image.png")
|
||||||
return MediaFileType(_strip_base_prefix(target_path, base_path))
|
return MediaFileType(_strip_base_prefix(target_path, base_path))
|
||||||
|
|
||||||
|
elif return_format == "for_external_api":
|
||||||
|
# Use when sending content to external APIs that need base64
|
||||||
|
# Returns: data URI (e.g., "data:image/png;base64,iVBORw0...")
|
||||||
|
return MediaFileType(_file_to_data_uri(target_path))
|
||||||
|
|
||||||
|
elif return_format == "for_block_output":
|
||||||
|
# Use when returning output from a block to user/next block
|
||||||
|
# Returns: workspace:// ref (CoPilot) or data URI (graph execution)
|
||||||
|
if workspace_manager is None:
|
||||||
|
# No workspace available (graph execution without CoPilot)
|
||||||
|
# Fallback to data URI so the content can still be used/displayed
|
||||||
|
return MediaFileType(_file_to_data_uri(target_path))
|
||||||
|
|
||||||
|
# Don't re-save if input was already from workspace
|
||||||
|
if is_from_workspace:
|
||||||
|
# Return original workspace reference
|
||||||
|
return MediaFileType(file)
|
||||||
|
|
||||||
|
# Save new content to workspace
|
||||||
|
content = target_path.read_bytes()
|
||||||
|
filename = target_path.name
|
||||||
|
|
||||||
|
file_record = await workspace_manager.write_file(
|
||||||
|
content=content,
|
||||||
|
filename=filename,
|
||||||
|
overwrite=True,
|
||||||
|
)
|
||||||
|
return MediaFileType(f"workspace://{file_record.id}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid return_format: {return_format}")
|
||||||
|
|
||||||
|
|
||||||
def get_dir_size(path: Path) -> int:
|
def get_dir_size(path: Path) -> int:
|
||||||
"""Get total size of directory."""
|
"""Get total size of directory."""
|
||||||
|
|||||||
@@ -7,10 +7,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_context(
|
||||||
|
graph_exec_id: str = "test-exec-123",
|
||||||
|
user_id: str = "test-user-123",
|
||||||
|
) -> ExecutionContext:
|
||||||
|
"""Helper to create test ExecutionContext."""
|
||||||
|
return ExecutionContext(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestFileCloudIntegration:
|
class TestFileCloudIntegration:
|
||||||
"""Test cases for cloud storage integration in file utilities."""
|
"""Test cases for cloud storage integration in file utilities."""
|
||||||
|
|
||||||
@@ -70,10 +82,9 @@ class TestFileCloudIntegration:
|
|||||||
mock_path_class.side_effect = path_constructor
|
mock_path_class.side_effect = path_constructor
|
||||||
|
|
||||||
result = await store_media_file(
|
result = await store_media_file(
|
||||||
graph_exec_id,
|
file=MediaFileType(cloud_path),
|
||||||
MediaFileType(cloud_path),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
"test-user-123",
|
return_format="for_local_processing",
|
||||||
return_content=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify cloud storage operations
|
# Verify cloud storage operations
|
||||||
@@ -144,10 +155,9 @@ class TestFileCloudIntegration:
|
|||||||
mock_path_obj.name = "image.png"
|
mock_path_obj.name = "image.png"
|
||||||
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
||||||
result = await store_media_file(
|
result = await store_media_file(
|
||||||
graph_exec_id,
|
file=MediaFileType(cloud_path),
|
||||||
MediaFileType(cloud_path),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
"test-user-123",
|
return_format="for_external_api",
|
||||||
return_content=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify result is a data URI
|
# Verify result is a data URI
|
||||||
@@ -198,10 +208,9 @@ class TestFileCloudIntegration:
|
|||||||
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
||||||
|
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id,
|
file=MediaFileType(data_uri),
|
||||||
MediaFileType(data_uri),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
"test-user-123",
|
return_format="for_local_processing",
|
||||||
return_content=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify cloud handler was checked but not used for retrieval
|
# Verify cloud handler was checked but not used for retrieval
|
||||||
@@ -234,5 +243,7 @@ class TestFileCloudIntegration:
|
|||||||
FileNotFoundError, match="File not found in cloud storage"
|
FileNotFoundError, match="File not found in cloud storage"
|
||||||
):
|
):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
|
file=MediaFileType(cloud_path),
|
||||||
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|||||||
108
autogpt_platform/backend/backend/util/gcs_utils.py
Normal file
108
autogpt_platform/backend/backend/util/gcs_utils.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Shared GCS utilities for workspace and cloud storage backends.
|
||||||
|
|
||||||
|
This module provides common functionality for working with Google Cloud Storage,
|
||||||
|
including path parsing, client management, and signed URL generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from gcloud.aio import storage as async_gcs_storage
|
||||||
|
from google.cloud import storage as gcs_storage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_gcs_path(path: str) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: GCS path string (e.g., "gcs://my-bucket/path/to/file")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (bucket_name, blob_name)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the path format is invalid
|
||||||
|
"""
|
||||||
|
if not path.startswith("gcs://"):
|
||||||
|
raise ValueError(f"Invalid GCS path: {path}")
|
||||||
|
|
||||||
|
path_without_prefix = path[6:] # Remove "gcs://"
|
||||||
|
parts = path_without_prefix.split("/", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError(f"Invalid GCS path format: {path}")
|
||||||
|
|
||||||
|
return parts[0], parts[1]
|
||||||
|
|
||||||
|
|
||||||
|
async def download_with_fresh_session(bucket: str, blob: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Download file content using a fresh session.
|
||||||
|
|
||||||
|
This approach avoids event loop issues that can occur when reusing
|
||||||
|
sessions across different async contexts (e.g., in executors).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bucket: GCS bucket name
|
||||||
|
blob: Blob path within the bucket
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the file doesn't exist
|
||||||
|
"""
|
||||||
|
session = aiohttp.ClientSession(
|
||||||
|
connector=aiohttp.TCPConnector(limit=10, force_close=True)
|
||||||
|
)
|
||||||
|
client: async_gcs_storage.Storage | None = None
|
||||||
|
try:
|
||||||
|
client = async_gcs_storage.Storage(session=session)
|
||||||
|
content = await client.download(bucket, blob)
|
||||||
|
return content
|
||||||
|
except Exception as e:
|
||||||
|
if "404" in str(e) or "Not Found" in str(e):
|
||||||
|
raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if client:
|
||||||
|
try:
|
||||||
|
await client.close()
|
||||||
|
except Exception:
|
||||||
|
pass # Best-effort cleanup
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_signed_url(
|
||||||
|
sync_client: gcs_storage.Client,
|
||||||
|
bucket_name: str,
|
||||||
|
blob_name: str,
|
||||||
|
expires_in: int,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate a signed URL for temporary access to a GCS file.
|
||||||
|
|
||||||
|
Uses asyncio.to_thread() to run the sync operation without blocking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sync_client: Sync GCS client with service account credentials
|
||||||
|
bucket_name: GCS bucket name
|
||||||
|
blob_name: Blob path within the bucket
|
||||||
|
expires_in: URL expiration time in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Signed URL string
|
||||||
|
"""
|
||||||
|
bucket = sync_client.bucket(bucket_name)
|
||||||
|
blob = bucket.blob(blob_name)
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
blob.generate_signed_url,
|
||||||
|
version="v4",
|
||||||
|
expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in),
|
||||||
|
method="GET",
|
||||||
|
)
|
||||||
@@ -263,6 +263,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
description="The name of the Google Cloud Storage bucket for media files",
|
description="The name of the Google Cloud Storage bucket for media files",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
workspace_storage_dir: str = Field(
|
||||||
|
default="",
|
||||||
|
description="Local directory for workspace file storage when GCS is not configured. "
|
||||||
|
"If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.",
|
||||||
|
)
|
||||||
|
|
||||||
reddit_user_agent: str = Field(
|
reddit_user_agent: str = Field(
|
||||||
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
|
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
|
||||||
description="The user agent for the Reddit API",
|
description="The user agent for the Reddit API",
|
||||||
@@ -389,6 +395,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
description="Maximum file size in MB for file uploads (1-1024 MB)",
|
description="Maximum file size in MB for file uploads (1-1024 MB)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
max_file_size_mb: int = Field(
|
||||||
|
default=100,
|
||||||
|
ge=1,
|
||||||
|
le=1024,
|
||||||
|
description="Maximum file size in MB for workspace files (1-1024 MB)",
|
||||||
|
)
|
||||||
|
|
||||||
# AutoMod configuration
|
# AutoMod configuration
|
||||||
automod_enabled: bool = Field(
|
automod_enabled: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
|
|||||||
@@ -140,14 +140,29 @@ async def execute_block_test(block: Block):
|
|||||||
setattr(block, mock_name, mock_obj)
|
setattr(block, mock_name, mock_obj)
|
||||||
|
|
||||||
# Populate credentials argument(s)
|
# Populate credentials argument(s)
|
||||||
|
# Generate IDs for execution context
|
||||||
|
graph_id = str(uuid.uuid4())
|
||||||
|
node_id = str(uuid.uuid4())
|
||||||
|
graph_exec_id = str(uuid.uuid4())
|
||||||
|
node_exec_id = str(uuid.uuid4())
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
graph_version = 1 # Default version for tests
|
||||||
|
|
||||||
extra_exec_kwargs: dict = {
|
extra_exec_kwargs: dict = {
|
||||||
"graph_id": str(uuid.uuid4()),
|
"graph_id": graph_id,
|
||||||
"node_id": str(uuid.uuid4()),
|
"node_id": node_id,
|
||||||
"graph_exec_id": str(uuid.uuid4()),
|
"graph_exec_id": graph_exec_id,
|
||||||
"node_exec_id": str(uuid.uuid4()),
|
"node_exec_id": node_exec_id,
|
||||||
"user_id": str(uuid.uuid4()),
|
"user_id": user_id,
|
||||||
"graph_version": 1, # Default version for tests
|
"graph_version": graph_version,
|
||||||
"execution_context": ExecutionContext(),
|
"execution_context": ExecutionContext(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
graph_version=graph_version,
|
||||||
|
node_id=node_id,
|
||||||
|
node_exec_id=node_exec_id,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
input_model = cast(type[BlockSchema], block.input_schema)
|
input_model = cast(type[BlockSchema], block.input_schema)
|
||||||
|
|
||||||
|
|||||||
419
autogpt_platform/backend/backend/util/workspace.py
Normal file
419
autogpt_platform/backend/backend/util/workspace.py
Normal file
@@ -0,0 +1,419 @@
|
|||||||
|
"""
|
||||||
|
WorkspaceManager for managing user workspace file operations.
|
||||||
|
|
||||||
|
This module provides a high-level interface for workspace file operations,
|
||||||
|
combining the storage backend and database layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from prisma.errors import UniqueViolationError
|
||||||
|
from prisma.models import UserWorkspaceFile
|
||||||
|
|
||||||
|
from backend.data.workspace import (
|
||||||
|
count_workspace_files,
|
||||||
|
create_workspace_file,
|
||||||
|
get_workspace_file,
|
||||||
|
get_workspace_file_by_path,
|
||||||
|
list_workspace_files,
|
||||||
|
soft_delete_workspace_file,
|
||||||
|
)
|
||||||
|
from backend.util.settings import Config
|
||||||
|
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceManager:
|
||||||
|
"""
|
||||||
|
Manages workspace file operations.
|
||||||
|
|
||||||
|
Combines storage backend operations with database record management.
|
||||||
|
Supports session-scoped file segmentation where files are stored in
|
||||||
|
session-specific virtual paths: /sessions/{session_id}/{filename}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, user_id: str, workspace_id: str, session_id: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize WorkspaceManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
session_id: Optional session ID for session-scoped file access
|
||||||
|
"""
|
||||||
|
self.user_id = user_id
|
||||||
|
self.workspace_id = workspace_id
|
||||||
|
self.session_id = session_id
|
||||||
|
# Session path prefix for file isolation
|
||||||
|
self.session_path = f"/sessions/{session_id}" if session_id else ""
|
||||||
|
|
||||||
|
def _resolve_path(self, path: str) -> str:
|
||||||
|
"""
|
||||||
|
Resolve a path, defaulting to session folder if session_id is set.
|
||||||
|
|
||||||
|
Cross-session access is allowed by explicitly using /sessions/other-session-id/...
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved path with session prefix if applicable
|
||||||
|
"""
|
||||||
|
# If path explicitly references a session folder, use it as-is
|
||||||
|
if path.startswith("/sessions/"):
|
||||||
|
return path
|
||||||
|
|
||||||
|
# If we have a session context, prepend session path
|
||||||
|
if self.session_path:
|
||||||
|
# Normalize the path
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
return f"{self.session_path}{path}"
|
||||||
|
|
||||||
|
# No session context, use path as-is
|
||||||
|
return path if path.startswith("/") else f"/{path}"
|
||||||
|
|
||||||
|
def _get_effective_path(
|
||||||
|
self, path: Optional[str], include_all_sessions: bool
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get effective path for list/count operations based on session context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Optional path prefix to filter
|
||||||
|
include_all_sessions: If True, don't apply session scoping
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Effective path prefix for database query
|
||||||
|
"""
|
||||||
|
if include_all_sessions:
|
||||||
|
# Normalize path to ensure leading slash (stored paths are normalized)
|
||||||
|
if path is not None and not path.startswith("/"):
|
||||||
|
return f"/{path}"
|
||||||
|
return path
|
||||||
|
elif path is not None:
|
||||||
|
# Resolve the provided path with session scoping
|
||||||
|
return self._resolve_path(path)
|
||||||
|
elif self.session_path:
|
||||||
|
# Default to session folder with trailing slash to prevent prefix collisions
|
||||||
|
# e.g., "/sessions/abc" should not match "/sessions/abc123"
|
||||||
|
return self.session_path.rstrip("/") + "/"
|
||||||
|
else:
|
||||||
|
# No session context, use path as-is
|
||||||
|
return path
|
||||||
|
|
||||||
|
async def read_file(self, path: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Read file from workspace by virtual path.
|
||||||
|
|
||||||
|
When session_id is set, paths are resolved relative to the session folder
|
||||||
|
unless they explicitly reference /sessions/...
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Virtual path (e.g., "/documents/report.pdf")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If file doesn't exist
|
||||||
|
"""
|
||||||
|
resolved_path = self._resolve_path(path)
|
||||||
|
file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||||
|
if file is None:
|
||||||
|
raise FileNotFoundError(f"File not found at path: {resolved_path}")
|
||||||
|
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
return await storage.retrieve(file.storagePath)
|
||||||
|
|
||||||
|
async def read_file_by_id(self, file_id: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Read file from workspace by file ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If file doesn't exist
|
||||||
|
"""
|
||||||
|
file = await get_workspace_file(file_id, self.workspace_id)
|
||||||
|
if file is None:
|
||||||
|
raise FileNotFoundError(f"File not found: {file_id}")
|
||||||
|
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
return await storage.retrieve(file.storagePath)
|
||||||
|
|
||||||
|
async def write_file(
|
||||||
|
self,
|
||||||
|
content: bytes,
|
||||||
|
filename: str,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
mime_type: Optional[str] = None,
|
||||||
|
overwrite: bool = False,
|
||||||
|
) -> UserWorkspaceFile:
|
||||||
|
"""
|
||||||
|
Write file to workspace.
|
||||||
|
|
||||||
|
When session_id is set, files are written to /sessions/{session_id}/...
|
||||||
|
by default. Use explicit /sessions/... paths for cross-session access.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: File content as bytes
|
||||||
|
filename: Filename for the file
|
||||||
|
path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
|
||||||
|
mime_type: MIME type (auto-detected if not provided)
|
||||||
|
overwrite: Whether to overwrite existing file at path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created UserWorkspaceFile instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If file exceeds size limit or path already exists
|
||||||
|
"""
|
||||||
|
# Enforce file size limit
|
||||||
|
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||||
|
if len(content) > max_file_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large: {len(content)} bytes exceeds "
|
||||||
|
f"{Config().max_file_size_mb}MB limit"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine path with session scoping
|
||||||
|
if path is None:
|
||||||
|
path = f"/{filename}"
|
||||||
|
elif not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
|
||||||
|
# Resolve path with session prefix
|
||||||
|
path = self._resolve_path(path)
|
||||||
|
|
||||||
|
# Check if file exists at path (only error for non-overwrite case)
|
||||||
|
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
|
||||||
|
# This ensures the new file is written to storage BEFORE the old one is deleted,
|
||||||
|
# preventing data loss if the new write fails
|
||||||
|
if not overwrite:
|
||||||
|
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
||||||
|
if existing is not None:
|
||||||
|
raise ValueError(f"File already exists at path: {path}")
|
||||||
|
|
||||||
|
# Auto-detect MIME type if not provided
|
||||||
|
if mime_type is None:
|
||||||
|
mime_type, _ = mimetypes.guess_type(filename)
|
||||||
|
mime_type = mime_type or "application/octet-stream"
|
||||||
|
|
||||||
|
# Compute checksum
|
||||||
|
checksum = compute_file_checksum(content)
|
||||||
|
|
||||||
|
# Generate unique file ID for storage
|
||||||
|
file_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Store file in storage backend
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
storage_path = await storage.store(
|
||||||
|
workspace_id=self.workspace_id,
|
||||||
|
file_id=file_id,
|
||||||
|
filename=filename,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create database record - handle race condition where another request
|
||||||
|
# created a file at the same path between our check and create
|
||||||
|
try:
|
||||||
|
file = await create_workspace_file(
|
||||||
|
workspace_id=self.workspace_id,
|
||||||
|
file_id=file_id,
|
||||||
|
name=filename,
|
||||||
|
path=path,
|
||||||
|
storage_path=storage_path,
|
||||||
|
mime_type=mime_type,
|
||||||
|
size_bytes=len(content),
|
||||||
|
checksum=checksum,
|
||||||
|
)
|
||||||
|
except UniqueViolationError:
|
||||||
|
# Race condition: another request created a file at this path
|
||||||
|
if overwrite:
|
||||||
|
# Re-fetch and delete the conflicting file, then retry
|
||||||
|
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
||||||
|
if existing:
|
||||||
|
await self.delete_file(existing.id)
|
||||||
|
# Retry the create - if this also fails, clean up storage file
|
||||||
|
try:
|
||||||
|
file = await create_workspace_file(
|
||||||
|
workspace_id=self.workspace_id,
|
||||||
|
file_id=file_id,
|
||||||
|
name=filename,
|
||||||
|
path=path,
|
||||||
|
storage_path=storage_path,
|
||||||
|
mime_type=mime_type,
|
||||||
|
size_bytes=len(content),
|
||||||
|
checksum=checksum,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Clean up orphaned storage file on retry failure
|
||||||
|
try:
|
||||||
|
await storage.delete(storage_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# Clean up the orphaned storage file before raising
|
||||||
|
try:
|
||||||
|
await storage.delete(storage_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||||
|
raise ValueError(f"File already exists at path: {path}")
|
||||||
|
except Exception:
|
||||||
|
# Any other database error (connection, validation, etc.) - clean up storage
|
||||||
|
try:
|
||||||
|
await storage.delete(storage_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} "
|
||||||
|
f"at path {path}, size={len(content)} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
return file
|
||||||
|
|
||||||
|
async def list_files(
|
||||||
|
self,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
include_all_sessions: bool = False,
|
||||||
|
) -> list[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
List files in workspace.
|
||||||
|
|
||||||
|
When session_id is set and include_all_sessions is False (default),
|
||||||
|
only files in the current session's folder are listed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Optional path prefix to filter (e.g., "/documents/")
|
||||||
|
limit: Maximum number of files to return
|
||||||
|
offset: Number of files to skip
|
||||||
|
include_all_sessions: If True, list files from all sessions.
|
||||||
|
If False (default), only list current session's files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of UserWorkspaceFile instances
|
||||||
|
"""
|
||||||
|
effective_path = self._get_effective_path(path, include_all_sessions)
|
||||||
|
|
||||||
|
return await list_workspace_files(
|
||||||
|
workspace_id=self.workspace_id,
|
||||||
|
path_prefix=effective_path,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_file(self, file_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a file (soft-delete).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found
|
||||||
|
"""
|
||||||
|
file = await get_workspace_file(file_id, self.workspace_id)
|
||||||
|
if file is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Delete from storage
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
try:
|
||||||
|
await storage.delete(file.storagePath)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete file from storage: {e}")
|
||||||
|
# Continue with database soft-delete even if storage delete fails
|
||||||
|
|
||||||
|
# Soft-delete database record
|
||||||
|
result = await soft_delete_workspace_file(file_id, self.workspace_id)
|
||||||
|
return result is not None
|
||||||
|
|
||||||
|
async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
Get download URL for a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file's ID
|
||||||
|
expires_in: URL expiration in seconds (default 1 hour)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Download URL (signed URL for GCS, API endpoint for local)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If file doesn't exist
|
||||||
|
"""
|
||||||
|
file = await get_workspace_file(file_id, self.workspace_id)
|
||||||
|
if file is None:
|
||||||
|
raise FileNotFoundError(f"File not found: {file_id}")
|
||||||
|
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
return await storage.get_download_url(file.storagePath, expires_in)
|
||||||
|
|
||||||
|
async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get file metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
return await get_workspace_file(file_id, self.workspace_id)
|
||||||
|
|
||||||
|
async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get file metadata by path.
|
||||||
|
|
||||||
|
When session_id is set, paths are resolved relative to the session folder
|
||||||
|
unless they explicitly reference /sessions/...
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Virtual path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
resolved_path = self._resolve_path(path)
|
||||||
|
return await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||||
|
|
||||||
|
async def get_file_count(
|
||||||
|
self,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
include_all_sessions: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Get number of files in workspace.
|
||||||
|
|
||||||
|
When session_id is set and include_all_sessions is False (default),
|
||||||
|
only counts files in the current session's folder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Optional path prefix to filter (e.g., "/documents/")
|
||||||
|
include_all_sessions: If True, count all files in workspace.
|
||||||
|
If False (default), only count current session's files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files
|
||||||
|
"""
|
||||||
|
effective_path = self._get_effective_path(path, include_all_sessions)
|
||||||
|
|
||||||
|
return await count_workspace_files(
|
||||||
|
self.workspace_id, path_prefix=effective_path
|
||||||
|
)
|
||||||
398
autogpt_platform/backend/backend/util/workspace_storage.py
Normal file
398
autogpt_platform/backend/backend/util/workspace_storage.py
Normal file
@@ -0,0 +1,398 @@
|
|||||||
|
"""
|
||||||
|
Workspace storage backend abstraction for supporting both cloud and local deployments.
|
||||||
|
|
||||||
|
This module provides a unified interface for storing workspace files, with implementations
|
||||||
|
for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import aiohttp
|
||||||
|
from gcloud.aio import storage as async_gcs_storage
|
||||||
|
from google.cloud import storage as gcs_storage
|
||||||
|
|
||||||
|
from backend.util.data import get_data_path
|
||||||
|
from backend.util.gcs_utils import (
|
||||||
|
download_with_fresh_session,
|
||||||
|
generate_signed_url,
|
||||||
|
parse_gcs_path,
|
||||||
|
)
|
||||||
|
from backend.util.settings import Config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceStorageBackend(ABC):
|
||||||
|
"""Abstract interface for workspace file storage."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def store(
|
||||||
|
self,
|
||||||
|
workspace_id: str,
|
||||||
|
file_id: str,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Store file content, return storage path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
file_id: Unique file ID for storage
|
||||||
|
filename: Original filename
|
||||||
|
content: File content as bytes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage path string (cloud path or local path)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def retrieve(self, storage_path: str) -> bytes:
|
||||||
|
"""
|
||||||
|
Retrieve file content from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: The storage path returned from store()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as bytes
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete(self, storage_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete file from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: The storage path to delete
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
Get URL for downloading the file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: The storage path
|
||||||
|
expires_in: URL expiration time in seconds (default 1 hour)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Download URL (signed URL for GCS, direct API path for local)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GCSWorkspaceStorage(WorkspaceStorageBackend):
|
||||||
|
"""Google Cloud Storage implementation for workspace storage."""
|
||||||
|
|
||||||
|
def __init__(self, bucket_name: str):
|
||||||
|
self.bucket_name = bucket_name
|
||||||
|
self._async_client: Optional[async_gcs_storage.Storage] = None
|
||||||
|
self._sync_client: Optional[gcs_storage.Client] = None
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
|
||||||
|
async def _get_async_client(self) -> async_gcs_storage.Storage:
|
||||||
|
"""Get or create async GCS client."""
|
||||||
|
if self._async_client is None:
|
||||||
|
self._session = aiohttp.ClientSession(
|
||||||
|
connector=aiohttp.TCPConnector(limit=100, force_close=False)
|
||||||
|
)
|
||||||
|
self._async_client = async_gcs_storage.Storage(session=self._session)
|
||||||
|
return self._async_client
|
||||||
|
|
||||||
|
def _get_sync_client(self) -> gcs_storage.Client:
|
||||||
|
"""Get or create sync GCS client (for signed URLs)."""
|
||||||
|
if self._sync_client is None:
|
||||||
|
self._sync_client = gcs_storage.Client()
|
||||||
|
return self._sync_client
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close all client connections."""
|
||||||
|
if self._async_client is not None:
|
||||||
|
try:
|
||||||
|
await self._async_client.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing GCS client: {e}")
|
||||||
|
self._async_client = None
|
||||||
|
|
||||||
|
if self._session is not None:
|
||||||
|
try:
|
||||||
|
await self._session.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing session: {e}")
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str:
|
||||||
|
"""Build the blob path for workspace files."""
|
||||||
|
return f"workspaces/{workspace_id}/{file_id}/{filename}"
|
||||||
|
|
||||||
|
async def store(
|
||||||
|
self,
|
||||||
|
workspace_id: str,
|
||||||
|
file_id: str,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""Store file in GCS."""
|
||||||
|
client = await self._get_async_client()
|
||||||
|
blob_name = self._build_blob_name(workspace_id, file_id, filename)
|
||||||
|
|
||||||
|
# Upload with metadata
|
||||||
|
upload_time = datetime.now(timezone.utc)
|
||||||
|
await client.upload(
|
||||||
|
self.bucket_name,
|
||||||
|
blob_name,
|
||||||
|
content,
|
||||||
|
metadata={
|
||||||
|
"uploaded_at": upload_time.isoformat(),
|
||||||
|
"workspace_id": workspace_id,
|
||||||
|
"file_id": file_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"gcs://{self.bucket_name}/{blob_name}"
|
||||||
|
|
||||||
|
async def retrieve(self, storage_path: str) -> bytes:
|
||||||
|
"""Retrieve file from GCS."""
|
||||||
|
bucket_name, blob_name = parse_gcs_path(storage_path)
|
||||||
|
return await download_with_fresh_session(bucket_name, blob_name)
|
||||||
|
|
||||||
|
async def delete(self, storage_path: str) -> None:
|
||||||
|
"""Delete file from GCS."""
|
||||||
|
bucket_name, blob_name = parse_gcs_path(storage_path)
|
||||||
|
client = await self._get_async_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await client.delete(bucket_name, blob_name)
|
||||||
|
except Exception as e:
|
||||||
|
if "404" not in str(e) and "Not Found" not in str(e):
|
||||||
|
raise
|
||||||
|
# File already deleted, that's fine
|
||||||
|
|
||||||
|
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
Generate download URL for GCS file.
|
||||||
|
|
||||||
|
Attempts to generate a signed URL if running with service account credentials.
|
||||||
|
Falls back to an API proxy endpoint if signed URL generation fails
|
||||||
|
(e.g., when running locally with user OAuth credentials).
|
||||||
|
"""
|
||||||
|
bucket_name, blob_name = parse_gcs_path(storage_path)
|
||||||
|
|
||||||
|
# Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename}
|
||||||
|
blob_parts = blob_name.split("/")
|
||||||
|
file_id = blob_parts[2] if len(blob_parts) >= 3 else None
|
||||||
|
|
||||||
|
# Try to generate signed URL (requires service account credentials)
|
||||||
|
try:
|
||||||
|
sync_client = self._get_sync_client()
|
||||||
|
return await generate_signed_url(
|
||||||
|
sync_client, bucket_name, blob_name, expires_in
|
||||||
|
)
|
||||||
|
except AttributeError as e:
|
||||||
|
# Signed URL generation requires service account with private key.
|
||||||
|
# When running with user OAuth credentials, fall back to API proxy.
|
||||||
|
if "private key" in str(e) and file_id:
|
||||||
|
logger.debug(
|
||||||
|
"Cannot generate signed URL (no service account credentials), "
|
||||||
|
"falling back to API proxy endpoint"
|
||||||
|
)
|
||||||
|
return f"/api/workspace/files/{file_id}/download"
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class LocalWorkspaceStorage(WorkspaceStorageBackend):
|
||||||
|
"""Local filesystem implementation for workspace storage (self-hosted deployments)."""
|
||||||
|
|
||||||
|
def __init__(self, base_dir: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize local storage backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_dir: Base directory for workspace storage.
|
||||||
|
If None, defaults to {app_data}/workspaces
|
||||||
|
"""
|
||||||
|
if base_dir:
|
||||||
|
self.base_dir = Path(base_dir)
|
||||||
|
else:
|
||||||
|
self.base_dir = Path(get_data_path()) / "workspaces"
|
||||||
|
|
||||||
|
# Ensure base directory exists
|
||||||
|
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path:
|
||||||
|
"""Build the local file path with path traversal protection."""
|
||||||
|
# Import here to avoid circular import
|
||||||
|
# (file.py imports workspace.py which imports workspace_storage.py)
|
||||||
|
from backend.util.file import sanitize_filename
|
||||||
|
|
||||||
|
# Sanitize filename to prevent path traversal (removes / and \ among others)
|
||||||
|
safe_filename = sanitize_filename(filename)
|
||||||
|
file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve()
|
||||||
|
|
||||||
|
# Verify the resolved path is still under base_dir
|
||||||
|
if not file_path.is_relative_to(self.base_dir.resolve()):
|
||||||
|
raise ValueError("Invalid filename: path traversal detected")
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
def _parse_storage_path(self, storage_path: str) -> Path:
|
||||||
|
"""Parse local storage path to filesystem path."""
|
||||||
|
if storage_path.startswith("local://"):
|
||||||
|
relative_path = storage_path[8:] # Remove "local://"
|
||||||
|
else:
|
||||||
|
relative_path = storage_path
|
||||||
|
|
||||||
|
full_path = (self.base_dir / relative_path).resolve()
|
||||||
|
|
||||||
|
# Security check: ensure path is under base_dir
|
||||||
|
# Use is_relative_to() for robust path containment check
|
||||||
|
# (handles case-insensitive filesystems and edge cases)
|
||||||
|
if not full_path.is_relative_to(self.base_dir.resolve()):
|
||||||
|
raise ValueError("Invalid storage path: path traversal detected")
|
||||||
|
|
||||||
|
return full_path
|
||||||
|
|
||||||
|
async def store(
|
||||||
|
self,
|
||||||
|
workspace_id: str,
|
||||||
|
file_id: str,
|
||||||
|
filename: str,
|
||||||
|
content: bytes,
|
||||||
|
) -> str:
|
||||||
|
"""Store file locally."""
|
||||||
|
file_path = self._build_file_path(workspace_id, file_id, filename)
|
||||||
|
|
||||||
|
# Create parent directories
|
||||||
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Write file asynchronously
|
||||||
|
async with aiofiles.open(file_path, "wb") as f:
|
||||||
|
await f.write(content)
|
||||||
|
|
||||||
|
# Return relative path as storage path
|
||||||
|
relative_path = file_path.relative_to(self.base_dir)
|
||||||
|
return f"local://{relative_path}"
|
||||||
|
|
||||||
|
async def retrieve(self, storage_path: str) -> bytes:
|
||||||
|
"""Retrieve file from local storage."""
|
||||||
|
file_path = self._parse_storage_path(storage_path)
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {storage_path}")
|
||||||
|
|
||||||
|
async with aiofiles.open(file_path, "rb") as f:
|
||||||
|
return await f.read()
|
||||||
|
|
||||||
|
async def delete(self, storage_path: str) -> None:
|
||||||
|
"""Delete file from local storage."""
|
||||||
|
file_path = self._parse_storage_path(storage_path)
|
||||||
|
|
||||||
|
if file_path.exists():
|
||||||
|
# Remove file
|
||||||
|
file_path.unlink()
|
||||||
|
|
||||||
|
# Clean up empty parent directories
|
||||||
|
parent = file_path.parent
|
||||||
|
while parent != self.base_dir:
|
||||||
|
try:
|
||||||
|
if parent.exists() and not any(parent.iterdir()):
|
||||||
|
parent.rmdir()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
except OSError:
|
||||||
|
break
|
||||||
|
parent = parent.parent
|
||||||
|
|
||||||
|
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
Get download URL for local file.
|
||||||
|
|
||||||
|
For local storage, this returns an API endpoint path.
|
||||||
|
The actual serving is handled by the API layer.
|
||||||
|
"""
|
||||||
|
# Parse the storage path to get the components
|
||||||
|
if storage_path.startswith("local://"):
|
||||||
|
relative_path = storage_path[8:]
|
||||||
|
else:
|
||||||
|
relative_path = storage_path
|
||||||
|
|
||||||
|
# Return the API endpoint for downloading
|
||||||
|
# The file_id is extracted from the path: {workspace_id}/{file_id}/{filename}
|
||||||
|
parts = relative_path.split("/")
|
||||||
|
if len(parts) >= 2:
|
||||||
|
file_id = parts[1] # Second component is file_id
|
||||||
|
return f"/api/workspace/files/{file_id}/download"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid storage path format: {storage_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# Global storage backend instance
|
||||||
|
_workspace_storage: Optional[WorkspaceStorageBackend] = None
|
||||||
|
_storage_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_storage() -> WorkspaceStorageBackend:
|
||||||
|
"""
|
||||||
|
Get the workspace storage backend instance.
|
||||||
|
|
||||||
|
Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
|
||||||
|
"""
|
||||||
|
global _workspace_storage
|
||||||
|
|
||||||
|
if _workspace_storage is None:
|
||||||
|
async with _storage_lock:
|
||||||
|
if _workspace_storage is None:
|
||||||
|
config = Config()
|
||||||
|
|
||||||
|
if config.media_gcs_bucket_name:
|
||||||
|
logger.info(
|
||||||
|
f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
|
||||||
|
)
|
||||||
|
_workspace_storage = GCSWorkspaceStorage(
|
||||||
|
config.media_gcs_bucket_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
storage_dir = (
|
||||||
|
config.workspace_storage_dir
|
||||||
|
if config.workspace_storage_dir
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Using local workspace storage: {storage_dir or 'default'}"
|
||||||
|
)
|
||||||
|
_workspace_storage = LocalWorkspaceStorage(storage_dir)
|
||||||
|
|
||||||
|
return _workspace_storage
|
||||||
|
|
||||||
|
|
||||||
|
async def shutdown_workspace_storage() -> None:
|
||||||
|
"""
|
||||||
|
Properly shutdown the global workspace storage backend.
|
||||||
|
|
||||||
|
Closes aiohttp sessions and other resources for GCS backend.
|
||||||
|
Should be called during application shutdown.
|
||||||
|
"""
|
||||||
|
global _workspace_storage
|
||||||
|
|
||||||
|
if _workspace_storage is not None:
|
||||||
|
async with _storage_lock:
|
||||||
|
if _workspace_storage is not None:
|
||||||
|
if isinstance(_workspace_storage, GCSWorkspaceStorage):
|
||||||
|
await _workspace_storage.close()
|
||||||
|
_workspace_storage = None
|
||||||
|
|
||||||
|
|
||||||
|
def compute_file_checksum(content: bytes) -> str:
|
||||||
|
"""Compute SHA256 checksum of file content."""
|
||||||
|
return hashlib.sha256(content).hexdigest()
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
-- Migrate Claude 3.7 Sonnet to Claude 4.5 Sonnet
|
||||||
|
-- This updates all AgentNode blocks that use the deprecated Claude 3.7 Sonnet model
|
||||||
|
-- Anthropic is retiring claude-3-7-sonnet-20250219 on February 19, 2026
|
||||||
|
|
||||||
|
-- Update AgentNode constant inputs
|
||||||
|
UPDATE "AgentNode"
|
||||||
|
SET "constantInput" = JSONB_SET(
|
||||||
|
"constantInput"::jsonb,
|
||||||
|
'{model}',
|
||||||
|
'"claude-sonnet-4-5-20250929"'::jsonb
|
||||||
|
)
|
||||||
|
WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';
|
||||||
|
|
||||||
|
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
|
||||||
|
UPDATE "AgentNodeExecutionInputOutput"
|
||||||
|
SET "data" = JSONB_SET(
|
||||||
|
"data"::jsonb,
|
||||||
|
'{model}',
|
||||||
|
'"claude-sonnet-4-5-20250929"'::jsonb
|
||||||
|
)
|
||||||
|
WHERE "agentPresetId" IS NOT NULL
|
||||||
|
AND "data"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
-- CreateEnum
|
||||||
|
CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT');
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "UserWorkspace" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
"userId" TEXT NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "UserWorkspaceFile" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||||
|
"workspaceId" TEXT NOT NULL,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"path" TEXT NOT NULL,
|
||||||
|
"storagePath" TEXT NOT NULL,
|
||||||
|
"mimeType" TEXT NOT NULL,
|
||||||
|
"sizeBytes" BIGINT NOT NULL,
|
||||||
|
"checksum" TEXT,
|
||||||
|
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
"deletedAt" TIMESTAMP(3),
|
||||||
|
"source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD',
|
||||||
|
"sourceExecId" TEXT,
|
||||||
|
"sourceSessionId" TEXT,
|
||||||
|
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||||
|
|
||||||
|
CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path");
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
/*
|
||||||
|
Warnings:
|
||||||
|
|
||||||
|
- You are about to drop the column `source` on the `UserWorkspaceFile` table. All the data in the column will be lost.
|
||||||
|
- You are about to drop the column `sourceExecId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
|
||||||
|
- You are about to drop the column `sourceSessionId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "UserWorkspaceFile" DROP COLUMN "source",
|
||||||
|
DROP COLUMN "sourceExecId",
|
||||||
|
DROP COLUMN "sourceSessionId";
|
||||||
|
|
||||||
|
-- DropEnum
|
||||||
|
DROP TYPE "WorkspaceFileSource";
|
||||||
@@ -63,6 +63,7 @@ model User {
|
|||||||
IntegrationWebhooks IntegrationWebhook[]
|
IntegrationWebhooks IntegrationWebhook[]
|
||||||
NotificationBatches UserNotificationBatch[]
|
NotificationBatches UserNotificationBatch[]
|
||||||
PendingHumanReviews PendingHumanReview[]
|
PendingHumanReviews PendingHumanReview[]
|
||||||
|
Workspace UserWorkspace?
|
||||||
|
|
||||||
// OAuth Provider relations
|
// OAuth Provider relations
|
||||||
OAuthApplications OAuthApplication[]
|
OAuthApplications OAuthApplication[]
|
||||||
@@ -137,6 +138,53 @@ model CoPilotUnderstanding {
|
|||||||
@@index([userId])
|
@@index([userId])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
//////////////// USER WORKSPACE TABLES /////////////////
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// User's persistent file storage workspace
|
||||||
|
model UserWorkspace {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
|
userId String @unique
|
||||||
|
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
Files UserWorkspaceFile[]
|
||||||
|
|
||||||
|
@@index([userId])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Individual files in a user's workspace
|
||||||
|
model UserWorkspaceFile {
|
||||||
|
id String @id @default(uuid())
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
|
||||||
|
workspaceId String
|
||||||
|
Workspace UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
// File metadata
|
||||||
|
name String // User-visible filename
|
||||||
|
path String // Virtual path (e.g., "/documents/report.pdf")
|
||||||
|
storagePath String // Actual GCS or local storage path
|
||||||
|
mimeType String
|
||||||
|
sizeBytes BigInt
|
||||||
|
checksum String? // SHA256 for integrity
|
||||||
|
|
||||||
|
// File state
|
||||||
|
isDeleted Boolean @default(false)
|
||||||
|
deletedAt DateTime?
|
||||||
|
|
||||||
|
metadata Json @default("{}")
|
||||||
|
|
||||||
|
@@unique([workspaceId, path])
|
||||||
|
@@index([workspaceId, isDeleted])
|
||||||
|
}
|
||||||
|
|
||||||
model BuilderSearchHistory {
|
model BuilderSearchHistory {
|
||||||
id String @id @default(uuid())
|
id String @id @default(uuid())
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
|
|||||||
@@ -57,7 +57,8 @@ class TestDecomposeGoal:
|
|||||||
|
|
||||||
result = await core.decompose_goal("Build a chatbot")
|
result = await core.decompose_goal("Build a chatbot")
|
||||||
|
|
||||||
mock_external.assert_called_once_with("Build a chatbot", "")
|
# library_agents defaults to None
|
||||||
|
mock_external.assert_called_once_with("Build a chatbot", "", None)
|
||||||
assert result == expected_result
|
assert result == expected_result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -74,7 +75,8 @@ class TestDecomposeGoal:
|
|||||||
|
|
||||||
await core.decompose_goal("Build a chatbot", "Use Python")
|
await core.decompose_goal("Build a chatbot", "Use Python")
|
||||||
|
|
||||||
mock_external.assert_called_once_with("Build a chatbot", "Use Python")
|
# library_agents defaults to None
|
||||||
|
mock_external.assert_called_once_with("Build a chatbot", "Use Python", None)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_returns_none_on_service_failure(self):
|
async def test_returns_none_on_service_failure(self):
|
||||||
@@ -109,7 +111,8 @@ class TestGenerateAgent:
|
|||||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||||
result = await core.generate_agent(instructions)
|
result = await core.generate_agent(instructions)
|
||||||
|
|
||||||
mock_external.assert_called_once_with(instructions)
|
# library_agents defaults to None
|
||||||
|
mock_external.assert_called_once_with(instructions, None)
|
||||||
# Result should have id, version, is_active added if not present
|
# Result should have id, version, is_active added if not present
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["name"] == "Test Agent"
|
assert result["name"] == "Test Agent"
|
||||||
@@ -174,7 +177,8 @@ class TestGenerateAgentPatch:
|
|||||||
current_agent = {"nodes": [], "links": []}
|
current_agent = {"nodes": [], "links": []}
|
||||||
result = await core.generate_agent_patch("Add a node", current_agent)
|
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||||
|
|
||||||
mock_external.assert_called_once_with("Add a node", current_agent)
|
# library_agents defaults to None
|
||||||
|
mock_external.assert_called_once_with("Add a node", current_agent, None)
|
||||||
assert result == expected_result
|
assert result == expected_result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -0,0 +1,838 @@
|
|||||||
|
"""
|
||||||
|
Tests for library agent fetching functionality in agent generator.
|
||||||
|
|
||||||
|
This test suite verifies the search-based library agent fetching,
|
||||||
|
including the combination of library and marketplace agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.agent_generator import core
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLibraryAgentsForGeneration:
|
||||||
|
"""Test get_library_agents_for_generation function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetches_agents_with_search_term(self):
|
||||||
|
"""Test that search_term is passed to the library db."""
|
||||||
|
# Create a mock agent with proper attribute values
|
||||||
|
mock_agent = MagicMock()
|
||||||
|
mock_agent.graph_id = "agent-123"
|
||||||
|
mock_agent.graph_version = 1
|
||||||
|
mock_agent.name = "Email Agent"
|
||||||
|
mock_agent.description = "Sends emails"
|
||||||
|
mock_agent.input_schema = {"properties": {}}
|
||||||
|
mock_agent.output_schema = {"properties": {}}
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.agents = [mock_agent]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"list_library_agents",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_list:
|
||||||
|
result = await core.get_library_agents_for_generation(
|
||||||
|
user_id="user-123",
|
||||||
|
search_query="send email",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify search_term was passed
|
||||||
|
mock_list.assert_called_once_with(
|
||||||
|
user_id="user-123",
|
||||||
|
search_term="send email",
|
||||||
|
page=1,
|
||||||
|
page_size=15,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result format
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["graph_id"] == "agent-123"
|
||||||
|
assert result[0]["name"] == "Email Agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_excludes_specified_graph_id(self):
|
||||||
|
"""Test that agents with excluded graph_id are filtered out."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.agents = [
|
||||||
|
MagicMock(
|
||||||
|
graph_id="agent-123",
|
||||||
|
graph_version=1,
|
||||||
|
name="Agent 1",
|
||||||
|
description="First agent",
|
||||||
|
input_schema={},
|
||||||
|
output_schema={},
|
||||||
|
),
|
||||||
|
MagicMock(
|
||||||
|
graph_id="agent-456",
|
||||||
|
graph_version=1,
|
||||||
|
name="Agent 2",
|
||||||
|
description="Second agent",
|
||||||
|
input_schema={},
|
||||||
|
output_schema={},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"list_library_agents",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_response,
|
||||||
|
):
|
||||||
|
result = await core.get_library_agents_for_generation(
|
||||||
|
user_id="user-123",
|
||||||
|
exclude_graph_id="agent-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the excluded agent is not in results
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["graph_id"] == "agent-456"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_respects_max_results(self):
|
||||||
|
"""Test that max_results parameter limits the page_size."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.agents = []
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"list_library_agents",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_list:
|
||||||
|
await core.get_library_agents_for_generation(
|
||||||
|
user_id="user-123",
|
||||||
|
max_results=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify page_size was set to max_results
|
||||||
|
mock_list.assert_called_once_with(
|
||||||
|
user_id="user-123",
|
||||||
|
search_term=None,
|
||||||
|
page=1,
|
||||||
|
page_size=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchMarketplaceAgentsForGeneration:
|
||||||
|
"""Test search_marketplace_agents_for_generation function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_searches_marketplace_with_query(self):
|
||||||
|
"""Test that marketplace is searched with the query."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.agents = [
|
||||||
|
MagicMock(
|
||||||
|
agent_name="Public Agent",
|
||||||
|
description="A public agent",
|
||||||
|
sub_heading="Does something useful",
|
||||||
|
creator="creator-1",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# The store_db is dynamically imported, so patch the import path
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.db.get_store_agents",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_response,
|
||||||
|
) as mock_search:
|
||||||
|
result = await core.search_marketplace_agents_for_generation(
|
||||||
|
search_query="automation",
|
||||||
|
max_results=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_search.assert_called_once_with(
|
||||||
|
search_query="automation",
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["name"] == "Public Agent"
|
||||||
|
assert result[0]["is_marketplace_agent"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_marketplace_error_gracefully(self):
|
||||||
|
"""Test that marketplace errors don't crash the function."""
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.store.db.get_store_agents",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=Exception("Marketplace unavailable"),
|
||||||
|
):
|
||||||
|
result = await core.search_marketplace_agents_for_generation(
|
||||||
|
search_query="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return empty list, not raise exception
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAllRelevantAgentsForGeneration:
|
||||||
|
"""Test get_all_relevant_agents_for_generation function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_combines_library_and_marketplace_agents(self):
|
||||||
|
"""Test that agents from both sources are combined."""
|
||||||
|
library_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "lib-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Library Agent",
|
||||||
|
"description": "From library",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
marketplace_agents = [
|
||||||
|
{
|
||||||
|
"name": "Market Agent",
|
||||||
|
"description": "From marketplace",
|
||||||
|
"sub_heading": "Sub heading",
|
||||||
|
"creator": "creator-1",
|
||||||
|
"is_marketplace_agent": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_library_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=library_agents,
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"search_marketplace_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=marketplace_agents,
|
||||||
|
):
|
||||||
|
result = await core.get_all_relevant_agents_for_generation(
|
||||||
|
user_id="user-123",
|
||||||
|
search_query="test query",
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Library agents should come first
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0]["name"] == "Library Agent"
|
||||||
|
assert result[1]["name"] == "Market Agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduplicates_by_name(self):
|
||||||
|
"""Test that marketplace agents with same name as library are excluded."""
|
||||||
|
library_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "lib-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Shared Agent",
|
||||||
|
"description": "From library",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
marketplace_agents = [
|
||||||
|
{
|
||||||
|
"name": "Shared Agent", # Same name, should be deduplicated
|
||||||
|
"description": "From marketplace",
|
||||||
|
"sub_heading": "Sub heading",
|
||||||
|
"creator": "creator-1",
|
||||||
|
"is_marketplace_agent": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Unique Agent",
|
||||||
|
"description": "Only in marketplace",
|
||||||
|
"sub_heading": "Sub heading",
|
||||||
|
"creator": "creator-2",
|
||||||
|
"is_marketplace_agent": True,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_library_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=library_agents,
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"search_marketplace_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=marketplace_agents,
|
||||||
|
):
|
||||||
|
result = await core.get_all_relevant_agents_for_generation(
|
||||||
|
user_id="user-123",
|
||||||
|
search_query="test",
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shared Agent from marketplace should be excluded
|
||||||
|
assert len(result) == 2
|
||||||
|
names = [a["name"] for a in result]
|
||||||
|
assert "Shared Agent" in names
|
||||||
|
assert "Unique Agent" in names
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_marketplace_when_disabled(self):
|
||||||
|
"""Test that marketplace is not searched when include_marketplace=False."""
|
||||||
|
library_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "lib-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Library Agent",
|
||||||
|
"description": "From library",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_library_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=library_agents,
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"search_marketplace_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_marketplace:
|
||||||
|
result = await core.get_all_relevant_agents_for_generation(
|
||||||
|
user_id="user-123",
|
||||||
|
search_query="test",
|
||||||
|
include_marketplace=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Marketplace should not be called
|
||||||
|
mock_marketplace.assert_not_called()
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_marketplace_when_no_search_query(self):
|
||||||
|
"""Test that marketplace is not searched without a search query."""
|
||||||
|
library_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "lib-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Library Agent",
|
||||||
|
"description": "From library",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_library_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=library_agents,
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"search_marketplace_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_marketplace:
|
||||||
|
result = await core.get_all_relevant_agents_for_generation(
|
||||||
|
user_id="user-123",
|
||||||
|
search_query=None, # No search query
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Marketplace should not be called without search query
|
||||||
|
mock_marketplace.assert_not_called()
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractSearchTermsFromSteps:
|
||||||
|
"""Test extract_search_terms_from_steps function."""
|
||||||
|
|
||||||
|
def test_extracts_terms_from_instructions_type(self):
|
||||||
|
"""Test extraction from valid instructions decomposition result."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{
|
||||||
|
"description": "Send an email notification",
|
||||||
|
"block_name": "GmailSendBlock",
|
||||||
|
},
|
||||||
|
{"description": "Fetch weather data", "action": "Get weather API"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
assert "Send an email notification" in result
|
||||||
|
assert "GmailSendBlock" in result
|
||||||
|
assert "Fetch weather data" in result
|
||||||
|
assert "Get weather API" in result
|
||||||
|
|
||||||
|
def test_returns_empty_for_non_instructions_type(self):
|
||||||
|
"""Test that non-instructions types return empty list."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": [{"question": "What email?"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_deduplicates_terms_case_insensitively(self):
|
||||||
|
"""Test that duplicate terms are removed (case-insensitive)."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{"description": "Send Email", "name": "send email"},
|
||||||
|
{"description": "Other task"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
# Should only have one "send email" variant
|
||||||
|
email_terms = [t for t in result if "email" in t.lower()]
|
||||||
|
assert len(email_terms) == 1
|
||||||
|
|
||||||
|
def test_filters_short_terms(self):
|
||||||
|
"""Test that terms with 3 or fewer characters are filtered out."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{"description": "ab", "action": "xyz"}, # Both too short
|
||||||
|
{"description": "Valid term here"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
assert "ab" not in result
|
||||||
|
assert "xyz" not in result
|
||||||
|
assert "Valid term here" in result
|
||||||
|
|
||||||
|
def test_handles_empty_steps(self):
|
||||||
|
"""Test handling of empty steps list."""
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnrichLibraryAgentsFromSteps:
|
||||||
|
"""Test enrich_library_agents_from_steps function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enriches_with_additional_agents(self):
|
||||||
|
"""Test that additional agents are found based on steps."""
|
||||||
|
existing_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "existing-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Existing Agent",
|
||||||
|
"description": "Already fetched",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
additional_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "new-456",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Email Agent",
|
||||||
|
"description": "For sending emails",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{"description": "Send email notification"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_all_relevant_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=additional_agents,
|
||||||
|
):
|
||||||
|
result = await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have both existing and new agents
|
||||||
|
assert len(result) == 2
|
||||||
|
names = [a["name"] for a in result]
|
||||||
|
assert "Existing Agent" in names
|
||||||
|
assert "Email Agent" in names
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduplicates_by_graph_id(self):
|
||||||
|
"""Test that agents with same graph_id are not duplicated."""
|
||||||
|
existing_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Existing Agent",
|
||||||
|
"description": "Already fetched",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Additional search returns same agent
|
||||||
|
additional_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-123", # Same ID
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Existing Agent Copy",
|
||||||
|
"description": "Same agent different name",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [{"description": "Some action"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_all_relevant_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=additional_agents,
|
||||||
|
):
|
||||||
|
result = await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not duplicate
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduplicates_by_name(self):
|
||||||
|
"""Test that agents with same name are not duplicated."""
|
||||||
|
existing_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Email Agent",
|
||||||
|
"description": "Already fetched",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Additional search returns agent with same name but different ID
|
||||||
|
additional_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-456", # Different ID
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Email Agent", # Same name
|
||||||
|
"description": "Different agent same name",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [{"description": "Send email"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_all_relevant_agents_for_generation",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=additional_agents,
|
||||||
|
):
|
||||||
|
result = await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not duplicate by name
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].get("graph_id") == "agent-123" # Original kept
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_existing_when_no_steps(self):
|
||||||
|
"""Test that existing agents are returned when no search terms extracted."""
|
||||||
|
existing_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "existing-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Existing Agent",
|
||||||
|
"description": "Already fetched",
|
||||||
|
"input_schema": {},
|
||||||
|
"output_schema": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "clarifying_questions", # Not instructions type
|
||||||
|
"questions": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return existing unchanged
|
||||||
|
assert result == existing_agents
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_limits_search_terms_to_three(self):
|
||||||
|
"""Test that only first 3 search terms are used."""
|
||||||
|
existing_agents = []
|
||||||
|
|
||||||
|
decomposition_result = {
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": [
|
||||||
|
{"description": "First action"},
|
||||||
|
{"description": "Second action"},
|
||||||
|
{"description": "Third action"},
|
||||||
|
{"description": "Fourth action"},
|
||||||
|
{"description": "Fifth action"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def mock_get_agents(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core,
|
||||||
|
"get_all_relevant_agents_for_generation",
|
||||||
|
side_effect=mock_get_agents,
|
||||||
|
):
|
||||||
|
await core.enrich_library_agents_from_steps(
|
||||||
|
user_id="user-123",
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=existing_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only make 3 calls (limited to first 3 terms)
|
||||||
|
assert call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractUuidsFromText:
|
||||||
|
"""Test extract_uuids_from_text function."""
|
||||||
|
|
||||||
|
def test_extracts_single_uuid(self):
|
||||||
|
"""Test extraction of a single UUID from text."""
|
||||||
|
text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task"
|
||||||
|
result = core.extract_uuids_from_text(text)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "46631191-e8a8-486f-ad90-84f89738321d" in result
|
||||||
|
|
||||||
|
def test_extracts_multiple_uuids(self):
|
||||||
|
"""Test extraction of multiple UUIDs from text."""
|
||||||
|
text = (
|
||||||
|
"Combine agents 11111111-1111-4111-8111-111111111111 "
|
||||||
|
"and 22222222-2222-4222-9222-222222222222"
|
||||||
|
)
|
||||||
|
result = core.extract_uuids_from_text(text)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "11111111-1111-4111-8111-111111111111" in result
|
||||||
|
assert "22222222-2222-4222-9222-222222222222" in result
|
||||||
|
|
||||||
|
def test_deduplicates_uuids(self):
|
||||||
|
"""Test that duplicate UUIDs are deduplicated."""
|
||||||
|
text = (
|
||||||
|
"Use 46631191-e8a8-486f-ad90-84f89738321d twice: "
|
||||||
|
"46631191-e8a8-486f-ad90-84f89738321d"
|
||||||
|
)
|
||||||
|
result = core.extract_uuids_from_text(text)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_normalizes_to_lowercase(self):
|
||||||
|
"""Test that UUIDs are normalized to lowercase."""
|
||||||
|
text = "Use 46631191-E8A8-486F-AD90-84F89738321D"
|
||||||
|
result = core.extract_uuids_from_text(text)
|
||||||
|
assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d"
|
||||||
|
|
||||||
|
def test_returns_empty_for_no_uuids(self):
|
||||||
|
"""Test that empty list is returned when no UUIDs found."""
|
||||||
|
text = "Create an email agent that sends notifications"
|
||||||
|
result = core.extract_uuids_from_text(text)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_ignores_invalid_uuids(self):
|
||||||
|
"""Test that invalid UUID-like strings are ignored."""
|
||||||
|
text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc"
|
||||||
|
result = core.extract_uuids_from_text(text)
|
||||||
|
# UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLibraryAgentById:
|
||||||
|
"""Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_agent_when_found_by_graph_id(self):
|
||||||
|
"""Test that agent is returned when found by graph_id."""
|
||||||
|
mock_agent = MagicMock()
|
||||||
|
mock_agent.graph_id = "agent-123"
|
||||||
|
mock_agent.graph_version = 1
|
||||||
|
mock_agent.name = "Test Agent"
|
||||||
|
mock_agent.description = "Test description"
|
||||||
|
mock_agent.input_schema = {"properties": {}}
|
||||||
|
mock_agent.output_schema = {"properties": {}}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"get_library_agent_by_graph_id",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_agent,
|
||||||
|
):
|
||||||
|
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["graph_id"] == "agent-123"
|
||||||
|
assert result["name"] == "Test Agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_falls_back_to_library_agent_id(self):
|
||||||
|
"""Test that lookup falls back to library agent ID when graph_id not found."""
|
||||||
|
mock_agent = MagicMock()
|
||||||
|
mock_agent.graph_id = "graph-456" # Different from the lookup ID
|
||||||
|
mock_agent.graph_version = 1
|
||||||
|
mock_agent.name = "Library Agent"
|
||||||
|
mock_agent.description = "Found by library ID"
|
||||||
|
mock_agent.input_schema = {"properties": {}}
|
||||||
|
mock_agent.output_schema = {"properties": {}}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"get_library_agent_by_graph_id",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None, # Not found by graph_id
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"get_library_agent",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_agent, # Found by library ID
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await core.get_library_agent_by_id("user-123", "library-id-123")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["graph_id"] == "graph-456"
|
||||||
|
assert result["name"] == "Library Agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_when_not_found_by_either_method(self):
|
||||||
|
"""Test that None is returned when agent not found by either method."""
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"get_library_agent_by_graph_id",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"get_library_agent",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=core.NotFoundError("Not found"),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await core.get_library_agent_by_id("user-123", "nonexistent")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_none_on_exception(self):
|
||||||
|
"""Test that None is returned when exception occurs in both lookups."""
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"get_library_agent_by_graph_id",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=Exception("Database error"),
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"get_library_agent",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=Exception("Database error"),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_alias_works(self):
|
||||||
|
"""Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
|
||||||
|
assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAllRelevantAgentsWithUuids:
|
||||||
|
"""Test UUID extraction in get_all_relevant_agents_for_generation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetches_explicitly_mentioned_agents(self):
|
||||||
|
"""Test that agents mentioned by UUID are fetched directly."""
|
||||||
|
mock_agent = MagicMock()
|
||||||
|
mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d"
|
||||||
|
mock_agent.graph_version = 1
|
||||||
|
mock_agent.name = "Mentioned Agent"
|
||||||
|
mock_agent.description = "Explicitly mentioned"
|
||||||
|
mock_agent.input_schema = {}
|
||||||
|
mock_agent.output_schema = {}
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.agents = []
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"get_library_agent_by_graph_id",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_agent,
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
core.library_db,
|
||||||
|
"list_library_agents",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_response,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await core.get_all_relevant_agents_for_generation(
|
||||||
|
user_id="user-123",
|
||||||
|
search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
|
||||||
|
include_marketplace=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d"
|
||||||
|
assert result[0].get("name") == "Mentioned Agent"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
@@ -151,15 +151,20 @@ class TestDecomposeGoalExternal:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_decompose_goal_handles_http_error(self):
|
async def test_decompose_goal_handles_http_error(self):
|
||||||
"""Test decomposition handles HTTP errors gracefully."""
|
"""Test decomposition handles HTTP errors gracefully."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 500
|
||||||
mock_client = AsyncMock()
|
mock_client = AsyncMock()
|
||||||
mock_client.post.side_effect = httpx.HTTPStatusError(
|
mock_client.post.side_effect = httpx.HTTPStatusError(
|
||||||
"Server error", request=MagicMock(), response=MagicMock()
|
"Server error", request=MagicMock(), response=mock_response
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(service, "_get_client", return_value=mock_client):
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
result = await service.decompose_goal_external("Build a chatbot")
|
result = await service.decompose_goal_external("Build a chatbot")
|
||||||
|
|
||||||
assert result is None
|
assert result is not None
|
||||||
|
assert result.get("type") == "error"
|
||||||
|
assert result.get("error_type") == "http_error"
|
||||||
|
assert "Server error" in result.get("error", "")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_decompose_goal_handles_request_error(self):
|
async def test_decompose_goal_handles_request_error(self):
|
||||||
@@ -170,7 +175,10 @@ class TestDecomposeGoalExternal:
|
|||||||
with patch.object(service, "_get_client", return_value=mock_client):
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
result = await service.decompose_goal_external("Build a chatbot")
|
result = await service.decompose_goal_external("Build a chatbot")
|
||||||
|
|
||||||
assert result is None
|
assert result is not None
|
||||||
|
assert result.get("type") == "error"
|
||||||
|
assert result.get("error_type") == "connection_error"
|
||||||
|
assert "Connection failed" in result.get("error", "")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_decompose_goal_handles_service_error(self):
|
async def test_decompose_goal_handles_service_error(self):
|
||||||
@@ -179,6 +187,7 @@ class TestDecomposeGoalExternal:
|
|||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": "Internal error",
|
"error": "Internal error",
|
||||||
|
"error_type": "internal_error",
|
||||||
}
|
}
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
@@ -188,7 +197,10 @@ class TestDecomposeGoalExternal:
|
|||||||
with patch.object(service, "_get_client", return_value=mock_client):
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
result = await service.decompose_goal_external("Build a chatbot")
|
result = await service.decompose_goal_external("Build a chatbot")
|
||||||
|
|
||||||
assert result is None
|
assert result is not None
|
||||||
|
assert result.get("type") == "error"
|
||||||
|
assert result.get("error") == "Internal error"
|
||||||
|
assert result.get("error_type") == "internal_error"
|
||||||
|
|
||||||
|
|
||||||
class TestGenerateAgentExternal:
|
class TestGenerateAgentExternal:
|
||||||
@@ -236,7 +248,10 @@ class TestGenerateAgentExternal:
|
|||||||
with patch.object(service, "_get_client", return_value=mock_client):
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
result = await service.generate_agent_external({"steps": []})
|
result = await service.generate_agent_external({"steps": []})
|
||||||
|
|
||||||
assert result is None
|
assert result is not None
|
||||||
|
assert result.get("type") == "error"
|
||||||
|
assert result.get("error_type") == "connection_error"
|
||||||
|
assert "Connection failed" in result.get("error", "")
|
||||||
|
|
||||||
|
|
||||||
class TestGenerateAgentPatchExternal:
|
class TestGenerateAgentPatchExternal:
|
||||||
@@ -418,5 +433,139 @@ class TestGetBlocksExternal:
|
|||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestLibraryAgentsPassthrough:
|
||||||
|
"""Test that library_agents are passed correctly in all requests."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset client singleton before each test."""
|
||||||
|
service._settings = None
|
||||||
|
service._client = None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_passes_library_agents(self):
|
||||||
|
"""Test that library_agents are included in decompose goal payload."""
|
||||||
|
library_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-123",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Email Sender",
|
||||||
|
"description": "Sends emails",
|
||||||
|
"input_schema": {"properties": {"to": {"type": "string"}}},
|
||||||
|
"output_schema": {"properties": {"sent": {"type": "boolean"}}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": ["Step 1"],
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
await service.decompose_goal_external(
|
||||||
|
"Send an email",
|
||||||
|
library_agents=library_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify library_agents was passed in the payload
|
||||||
|
call_args = mock_client.post.call_args
|
||||||
|
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_agent_passes_library_agents(self):
|
||||||
|
"""Test that library_agents are included in generate agent payload."""
|
||||||
|
library_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-456",
|
||||||
|
"graph_version": 2,
|
||||||
|
"name": "Data Fetcher",
|
||||||
|
"description": "Fetches data from API",
|
||||||
|
"input_schema": {"properties": {"url": {"type": "string"}}},
|
||||||
|
"output_schema": {"properties": {"data": {"type": "object"}}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"agent_json": {"name": "Test Agent", "nodes": []},
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
await service.generate_agent_external(
|
||||||
|
{"steps": ["Step 1"]},
|
||||||
|
library_agents=library_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify library_agents was passed in the payload
|
||||||
|
call_args = mock_client.post.call_args
|
||||||
|
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_agent_patch_passes_library_agents(self):
|
||||||
|
"""Test that library_agents are included in patch generation payload."""
|
||||||
|
library_agents = [
|
||||||
|
{
|
||||||
|
"graph_id": "agent-789",
|
||||||
|
"graph_version": 1,
|
||||||
|
"name": "Slack Notifier",
|
||||||
|
"description": "Sends Slack messages",
|
||||||
|
"input_schema": {"properties": {"message": {"type": "string"}}},
|
||||||
|
"output_schema": {"properties": {"success": {"type": "boolean"}}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"agent_json": {"name": "Updated Agent", "nodes": []},
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
await service.generate_agent_patch_external(
|
||||||
|
"Add error handling",
|
||||||
|
{"name": "Original Agent", "nodes": []},
|
||||||
|
library_agents=library_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify library_agents was passed in the payload
|
||||||
|
call_args = mock_client.post.call_args
|
||||||
|
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decompose_goal_without_library_agents(self):
|
||||||
|
"""Test that decompose goal works without library_agents."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"success": True,
|
||||||
|
"type": "instructions",
|
||||||
|
"steps": ["Step 1"],
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
|
||||||
|
with patch.object(service, "_get_client", return_value=mock_client):
|
||||||
|
await service.decompose_goal_external("Build a workflow")
|
||||||
|
|
||||||
|
# Verify library_agents was NOT passed when not provided
|
||||||
|
call_args = mock_client.post.call_args
|
||||||
|
assert "library_agents" not in call_args[1]["json"]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
|
|||||||
@@ -43,19 +43,24 @@ faker = Faker()
|
|||||||
# Constants for data generation limits (reduced for E2E tests)
|
# Constants for data generation limits (reduced for E2E tests)
|
||||||
NUM_USERS = 15
|
NUM_USERS = 15
|
||||||
NUM_AGENT_BLOCKS = 30
|
NUM_AGENT_BLOCKS = 30
|
||||||
MIN_GRAPHS_PER_USER = 15
|
MIN_GRAPHS_PER_USER = 25
|
||||||
MAX_GRAPHS_PER_USER = 15
|
MAX_GRAPHS_PER_USER = 25
|
||||||
MIN_NODES_PER_GRAPH = 3
|
MIN_NODES_PER_GRAPH = 3
|
||||||
MAX_NODES_PER_GRAPH = 6
|
MAX_NODES_PER_GRAPH = 6
|
||||||
MIN_PRESETS_PER_USER = 2
|
MIN_PRESETS_PER_USER = 2
|
||||||
MAX_PRESETS_PER_USER = 3
|
MAX_PRESETS_PER_USER = 3
|
||||||
MIN_AGENTS_PER_USER = 15
|
MIN_AGENTS_PER_USER = 25
|
||||||
MAX_AGENTS_PER_USER = 15
|
MAX_AGENTS_PER_USER = 25
|
||||||
MIN_EXECUTIONS_PER_GRAPH = 2
|
MIN_EXECUTIONS_PER_GRAPH = 2
|
||||||
MAX_EXECUTIONS_PER_GRAPH = 8
|
MAX_EXECUTIONS_PER_GRAPH = 8
|
||||||
MIN_REVIEWS_PER_VERSION = 2
|
MIN_REVIEWS_PER_VERSION = 2
|
||||||
MAX_REVIEWS_PER_VERSION = 5
|
MAX_REVIEWS_PER_VERSION = 5
|
||||||
|
|
||||||
|
# Guaranteed minimums for marketplace tests (deterministic)
|
||||||
|
GUARANTEED_FEATURED_AGENTS = 8
|
||||||
|
GUARANTEED_FEATURED_CREATORS = 5
|
||||||
|
GUARANTEED_TOP_AGENTS = 10
|
||||||
|
|
||||||
|
|
||||||
def get_image():
|
def get_image():
|
||||||
"""Generate a consistent image URL using picsum.photos service."""
|
"""Generate a consistent image URL using picsum.photos service."""
|
||||||
@@ -385,7 +390,7 @@ class TestDataCreator:
|
|||||||
|
|
||||||
library_agents = []
|
library_agents = []
|
||||||
for user in self.users:
|
for user in self.users:
|
||||||
num_agents = 10 # Create exactly 10 agents per user
|
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
|
||||||
|
|
||||||
# Get available graphs for this user
|
# Get available graphs for this user
|
||||||
user_graphs = [
|
user_graphs = [
|
||||||
@@ -507,14 +512,17 @@ class TestDataCreator:
|
|||||||
existing_profiles, min(num_creators, len(existing_profiles))
|
existing_profiles, min(num_creators, len(existing_profiles))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mark about 50% of creators as featured (more for testing)
|
# Guarantee at least GUARANTEED_FEATURED_CREATORS featured creators
|
||||||
num_featured = max(2, int(num_creators * 0.5))
|
num_featured = max(GUARANTEED_FEATURED_CREATORS, int(num_creators * 0.5))
|
||||||
num_featured = min(
|
num_featured = min(
|
||||||
num_featured, len(selected_profiles)
|
num_featured, len(selected_profiles)
|
||||||
) # Don't exceed available profiles
|
) # Don't exceed available profiles
|
||||||
featured_profile_ids = set(
|
featured_profile_ids = set(
|
||||||
random.sample([p.id for p in selected_profiles], num_featured)
|
random.sample([p.id for p in selected_profiles], num_featured)
|
||||||
)
|
)
|
||||||
|
print(
|
||||||
|
f"🎯 Creating {num_featured} featured creators (min: {GUARANTEED_FEATURED_CREATORS})"
|
||||||
|
)
|
||||||
|
|
||||||
for profile in selected_profiles:
|
for profile in selected_profiles:
|
||||||
try:
|
try:
|
||||||
@@ -545,21 +553,25 @@ class TestDataCreator:
|
|||||||
return profiles
|
return profiles
|
||||||
|
|
||||||
async def create_test_store_submissions(self) -> List[Dict[str, Any]]:
|
async def create_test_store_submissions(self) -> List[Dict[str, Any]]:
|
||||||
"""Create test store submissions using the API function."""
|
"""Create test store submissions using the API function.
|
||||||
|
|
||||||
|
DETERMINISTIC: Guarantees minimum featured agents for E2E tests.
|
||||||
|
"""
|
||||||
print("Creating test store submissions...")
|
print("Creating test store submissions...")
|
||||||
|
|
||||||
submissions = []
|
submissions = []
|
||||||
approved_submissions = []
|
approved_submissions = []
|
||||||
|
featured_count = 0
|
||||||
|
submission_counter = 0
|
||||||
|
|
||||||
# Create a special test submission for test123@gmail.com
|
# Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
|
||||||
test_user = next(
|
test_user = next(
|
||||||
(user for user in self.users if user["email"] == "test123@gmail.com"), None
|
(user for user in self.users if user["email"] == "test123@gmail.com"), None
|
||||||
)
|
)
|
||||||
if test_user:
|
if test_user and self.agent_graphs:
|
||||||
# Special test data for consistent testing
|
|
||||||
test_submission_data = {
|
test_submission_data = {
|
||||||
"user_id": test_user["id"],
|
"user_id": test_user["id"],
|
||||||
"agent_id": self.agent_graphs[0]["id"], # Use first available graph
|
"agent_id": self.agent_graphs[0]["id"],
|
||||||
"agent_version": 1,
|
"agent_version": 1,
|
||||||
"slug": "test-agent-submission",
|
"slug": "test-agent-submission",
|
||||||
"name": "Test Agent Submission",
|
"name": "Test Agent Submission",
|
||||||
@@ -580,10 +592,8 @@ class TestDataCreator:
|
|||||||
submissions.append(test_submission.model_dump())
|
submissions.append(test_submission.model_dump())
|
||||||
print("✅ Created special test store submission for test123@gmail.com")
|
print("✅ Created special test store submission for test123@gmail.com")
|
||||||
|
|
||||||
# Randomly approve, reject, or leave pending the test submission
|
# ALWAYS approve and feature the test submission
|
||||||
if test_submission.store_listing_version_id:
|
if test_submission.store_listing_version_id:
|
||||||
random_value = random.random()
|
|
||||||
if random_value < 0.4: # 40% chance to approve
|
|
||||||
approved_submission = await review_store_submission(
|
approved_submission = await review_store_submission(
|
||||||
store_listing_version_id=test_submission.store_listing_version_id,
|
store_listing_version_id=test_submission.store_listing_version_id,
|
||||||
is_approved=True,
|
is_approved=True,
|
||||||
@@ -594,23 +604,12 @@ class TestDataCreator:
|
|||||||
approved_submissions.append(approved_submission.model_dump())
|
approved_submissions.append(approved_submission.model_dump())
|
||||||
print("✅ Approved test store submission")
|
print("✅ Approved test store submission")
|
||||||
|
|
||||||
# Mark approved submission as featured
|
|
||||||
await prisma.storelistingversion.update(
|
await prisma.storelistingversion.update(
|
||||||
where={"id": test_submission.store_listing_version_id},
|
where={"id": test_submission.store_listing_version_id},
|
||||||
data={"isFeatured": True},
|
data={"isFeatured": True},
|
||||||
)
|
)
|
||||||
|
featured_count += 1
|
||||||
print("🌟 Marked test agent as FEATURED")
|
print("🌟 Marked test agent as FEATURED")
|
||||||
elif random_value < 0.7: # 30% chance to reject (40% to 70%)
|
|
||||||
await review_store_submission(
|
|
||||||
store_listing_version_id=test_submission.store_listing_version_id,
|
|
||||||
is_approved=False,
|
|
||||||
external_comments="Test submission rejected - needs improvements",
|
|
||||||
internal_comments="Auto-rejected test submission for E2E testing",
|
|
||||||
reviewer_id=test_user["id"],
|
|
||||||
)
|
|
||||||
print("❌ Rejected test store submission")
|
|
||||||
else: # 30% chance to leave pending (70% to 100%)
|
|
||||||
print("⏳ Left test submission pending for review")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating test store submission: {e}")
|
print(f"Error creating test store submission: {e}")
|
||||||
@@ -620,7 +619,6 @@ class TestDataCreator:
|
|||||||
|
|
||||||
# Create regular submissions for all users
|
# Create regular submissions for all users
|
||||||
for user in self.users:
|
for user in self.users:
|
||||||
# Get available graphs for this specific user
|
|
||||||
user_graphs = [
|
user_graphs = [
|
||||||
g for g in self.agent_graphs if g.get("userId") == user["id"]
|
g for g in self.agent_graphs if g.get("userId") == user["id"]
|
||||||
]
|
]
|
||||||
@@ -631,18 +629,17 @@ class TestDataCreator:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Create exactly 4 store submissions per user
|
|
||||||
for submission_index in range(4):
|
for submission_index in range(4):
|
||||||
graph = random.choice(user_graphs)
|
graph = random.choice(user_graphs)
|
||||||
|
submission_counter += 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print(
|
print(
|
||||||
f"Creating store submission for user {user['id']} with graph {graph['id']} (owner: {graph.get('userId')})"
|
f"Creating store submission for user {user['id']} with graph {graph['id']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the API function to create store submission with correct parameters
|
|
||||||
submission = await create_store_submission(
|
submission = await create_store_submission(
|
||||||
user_id=user["id"], # Must match graph's userId
|
user_id=user["id"],
|
||||||
agent_id=graph["id"],
|
agent_id=graph["id"],
|
||||||
agent_version=graph.get("version", 1),
|
agent_version=graph.get("version", 1),
|
||||||
slug=faker.slug(),
|
slug=faker.slug(),
|
||||||
@@ -651,22 +648,24 @@ class TestDataCreator:
|
|||||||
video_url=get_video_url() if random.random() < 0.3 else None,
|
video_url=get_video_url() if random.random() < 0.3 else None,
|
||||||
image_urls=[get_image() for _ in range(3)],
|
image_urls=[get_image() for _ in range(3)],
|
||||||
description=faker.text(),
|
description=faker.text(),
|
||||||
categories=[
|
categories=[get_category()],
|
||||||
get_category()
|
|
||||||
], # Single category from predefined list
|
|
||||||
changes_summary="Initial E2E test submission",
|
changes_summary="Initial E2E test submission",
|
||||||
)
|
)
|
||||||
submissions.append(submission.model_dump())
|
submissions.append(submission.model_dump())
|
||||||
print(f"✅ Created store submission: {submission.name}")
|
print(f"✅ Created store submission: {submission.name}")
|
||||||
|
|
||||||
# Randomly approve, reject, or leave pending the submission
|
|
||||||
if submission.store_listing_version_id:
|
if submission.store_listing_version_id:
|
||||||
random_value = random.random()
|
# DETERMINISTIC: First N submissions are always approved
|
||||||
if random_value < 0.4: # 40% chance to approve
|
# First GUARANTEED_FEATURED_AGENTS of those are always featured
|
||||||
try:
|
should_approve = (
|
||||||
# Pick a random user as the reviewer (admin)
|
submission_counter <= GUARANTEED_TOP_AGENTS
|
||||||
reviewer_id = random.choice(self.users)["id"]
|
or random.random() < 0.4
|
||||||
|
)
|
||||||
|
should_feature = featured_count < GUARANTEED_FEATURED_AGENTS
|
||||||
|
|
||||||
|
if should_approve:
|
||||||
|
try:
|
||||||
|
reviewer_id = random.choice(self.users)["id"]
|
||||||
approved_submission = await review_store_submission(
|
approved_submission = await review_store_submission(
|
||||||
store_listing_version_id=submission.store_listing_version_id,
|
store_listing_version_id=submission.store_listing_version_id,
|
||||||
is_approved=True,
|
is_approved=True,
|
||||||
@@ -681,16 +680,7 @@ class TestDataCreator:
|
|||||||
f"✅ Approved store submission: {submission.name}"
|
f"✅ Approved store submission: {submission.name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mark some agents as featured during creation (30% chance)
|
if should_feature:
|
||||||
# More likely for creators and first submissions
|
|
||||||
is_creator = user["id"] in [
|
|
||||||
p.get("userId") for p in self.profiles
|
|
||||||
]
|
|
||||||
feature_chance = (
|
|
||||||
0.5 if is_creator else 0.2
|
|
||||||
) # 50% for creators, 20% for others
|
|
||||||
|
|
||||||
if random.random() < feature_chance:
|
|
||||||
try:
|
try:
|
||||||
await prisma.storelistingversion.update(
|
await prisma.storelistingversion.update(
|
||||||
where={
|
where={
|
||||||
@@ -698,8 +688,25 @@ class TestDataCreator:
|
|||||||
},
|
},
|
||||||
data={"isFeatured": True},
|
data={"isFeatured": True},
|
||||||
)
|
)
|
||||||
|
featured_count += 1
|
||||||
print(
|
print(
|
||||||
f"🌟 Marked agent as FEATURED: {submission.name}"
|
f"🌟 Marked agent as FEATURED ({featured_count}/{GUARANTEED_FEATURED_AGENTS}): {submission.name}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"Warning: Could not mark submission as featured: {e}"
|
||||||
|
)
|
||||||
|
elif random.random() < 0.2:
|
||||||
|
try:
|
||||||
|
await prisma.storelistingversion.update(
|
||||||
|
where={
|
||||||
|
"id": submission.store_listing_version_id
|
||||||
|
},
|
||||||
|
data={"isFeatured": True},
|
||||||
|
)
|
||||||
|
featured_count += 1
|
||||||
|
print(
|
||||||
|
f"🌟 Marked agent as FEATURED (bonus): {submission.name}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
print(
|
||||||
@@ -710,11 +717,9 @@ class TestDataCreator:
|
|||||||
print(
|
print(
|
||||||
f"Warning: Could not approve submission {submission.name}: {e}"
|
f"Warning: Could not approve submission {submission.name}: {e}"
|
||||||
)
|
)
|
||||||
elif random_value < 0.7: # 30% chance to reject (40% to 70%)
|
elif random.random() < 0.5:
|
||||||
try:
|
try:
|
||||||
# Pick a random user as the reviewer (admin)
|
|
||||||
reviewer_id = random.choice(self.users)["id"]
|
reviewer_id = random.choice(self.users)["id"]
|
||||||
|
|
||||||
await review_store_submission(
|
await review_store_submission(
|
||||||
store_listing_version_id=submission.store_listing_version_id,
|
store_listing_version_id=submission.store_listing_version_id,
|
||||||
is_approved=False,
|
is_approved=False,
|
||||||
@@ -729,7 +734,7 @@ class TestDataCreator:
|
|||||||
print(
|
print(
|
||||||
f"Warning: Could not reject submission {submission.name}: {e}"
|
f"Warning: Could not reject submission {submission.name}: {e}"
|
||||||
)
|
)
|
||||||
else: # 30% chance to leave pending (70% to 100%)
|
else:
|
||||||
print(
|
print(
|
||||||
f"⏳ Left submission pending for review: {submission.name}"
|
f"⏳ Left submission pending for review: {submission.name}"
|
||||||
)
|
)
|
||||||
@@ -743,9 +748,13 @@ class TestDataCreator:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
print("\n📊 Store Submissions Summary:")
|
||||||
|
print(f" Created: {len(submissions)}")
|
||||||
|
print(f" Approved: {len(approved_submissions)}")
|
||||||
print(
|
print(
|
||||||
f"Created {len(submissions)} store submissions, approved {len(approved_submissions)}"
|
f" Featured: {featured_count} (guaranteed min: {GUARANTEED_FEATURED_AGENTS})"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store_submissions = submissions
|
self.store_submissions = submissions
|
||||||
return submissions
|
return submissions
|
||||||
|
|
||||||
@@ -825,12 +834,15 @@ class TestDataCreator:
|
|||||||
print(f"✅ Agent blocks available: {len(self.agent_blocks)}")
|
print(f"✅ Agent blocks available: {len(self.agent_blocks)}")
|
||||||
print(f"✅ Agent graphs created: {len(self.agent_graphs)}")
|
print(f"✅ Agent graphs created: {len(self.agent_graphs)}")
|
||||||
print(f"✅ Library agents created: {len(self.library_agents)}")
|
print(f"✅ Library agents created: {len(self.library_agents)}")
|
||||||
print(f"✅ Creator profiles updated: {len(self.profiles)} (some featured)")
|
print(f"✅ Creator profiles updated: {len(self.profiles)}")
|
||||||
print(
|
print(f"✅ Store submissions created: {len(self.store_submissions)}")
|
||||||
f"✅ Store submissions created: {len(self.store_submissions)} (some marked as featured during creation)"
|
|
||||||
)
|
|
||||||
print(f"✅ API keys created: {len(self.api_keys)}")
|
print(f"✅ API keys created: {len(self.api_keys)}")
|
||||||
print(f"✅ Presets created: {len(self.presets)}")
|
print(f"✅ Presets created: {len(self.presets)}")
|
||||||
|
print("\n🎯 Deterministic Guarantees:")
|
||||||
|
print(f" • Featured agents: >= {GUARANTEED_FEATURED_AGENTS}")
|
||||||
|
print(f" • Featured creators: >= {GUARANTEED_FEATURED_CREATORS}")
|
||||||
|
print(f" • Top agents (approved): >= {GUARANTEED_TOP_AGENTS}")
|
||||||
|
print(f" • Library agents per user: >= {MIN_AGENTS_PER_USER}")
|
||||||
print("\n🚀 Your E2E test database is ready to use!")
|
print("\n🚀 Your E2E test database is ready to use!")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -34,3 +34,6 @@ NEXT_PUBLIC_PREVIEW_STEALING_DEV=
|
|||||||
# PostHog Analytics
|
# PostHog Analytics
|
||||||
NEXT_PUBLIC_POSTHOG_KEY=
|
NEXT_PUBLIC_POSTHOG_KEY=
|
||||||
NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com
|
NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com
|
||||||
|
|
||||||
|
# OpenAI (for voice transcription)
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
|||||||
76
autogpt_platform/frontend/CLAUDE.md
Normal file
76
autogpt_platform/frontend/CLAUDE.md
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# CLAUDE.md - Frontend
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code when working with the frontend.
|
||||||
|
|
||||||
|
## Essential Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies
|
||||||
|
pnpm i
|
||||||
|
|
||||||
|
# Generate API client from OpenAPI spec
|
||||||
|
pnpm generate:api
|
||||||
|
|
||||||
|
# Start development server
|
||||||
|
pnpm dev
|
||||||
|
|
||||||
|
# Run E2E tests
|
||||||
|
pnpm test
|
||||||
|
|
||||||
|
# Run Storybook for component development
|
||||||
|
pnpm storybook
|
||||||
|
|
||||||
|
# Build production
|
||||||
|
pnpm build
|
||||||
|
|
||||||
|
# Format and lint
|
||||||
|
pnpm format
|
||||||
|
|
||||||
|
# Type checking
|
||||||
|
pnpm types
|
||||||
|
```
|
||||||
|
|
||||||
|
### Code Style
|
||||||
|
|
||||||
|
- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI`
|
||||||
|
- Use function declarations (not arrow functions) for components/handlers
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
- **Framework**: Next.js 15 App Router (client-first approach)
|
||||||
|
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
||||||
|
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
||||||
|
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
||||||
|
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||||
|
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
||||||
|
- **Icons**: Phosphor Icons only
|
||||||
|
- **Feature Flags**: LaunchDarkly integration
|
||||||
|
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
||||||
|
- **Testing**: Playwright for E2E, Storybook for component development
|
||||||
|
|
||||||
|
## Environment Configuration
|
||||||
|
|
||||||
|
`.env.default` (defaults) → `.env` (user overrides)
|
||||||
|
|
||||||
|
## Feature Development
|
||||||
|
|
||||||
|
See @CONTRIBUTING.md for complete patterns. Quick reference:
|
||||||
|
|
||||||
|
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||||
|
- Extract component logic into custom hooks grouped by concern, not by component. Each hook should represent a cohesive domain of functionality (e.g., useSearch, useFilters, usePagination) rather than bundling all state into one useComponentState hook.
|
||||||
|
- Put each hook in its own `.ts` file
|
||||||
|
- Put sub-components in local `components/` folder
|
||||||
|
- Component props should be `type Props = { ... }` (not exported) unless it needs to be used outside the component
|
||||||
|
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||||
|
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||||
|
- Never use `src/components/__legacy__/*`
|
||||||
|
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||||
|
- Regenerate with `pnpm generate:api`
|
||||||
|
- Pattern: `use{Method}{Version}{OperationName}`
|
||||||
|
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||||
|
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||||
|
6. **Code conventions**:
|
||||||
|
- Use function declarations (not arrow functions) for components/handlers
|
||||||
|
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
|
||||||
|
- Do not type hook returns, let Typescript infer as much as possible
|
||||||
|
- Never type with `any` unless a variable/attribute can ACTUALLY be of any type
|
||||||
@@ -73,9 +73,9 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const reset = () => {
|
const reset = () => {
|
||||||
|
// Only reset the offset - keep existing sessions visible during refetch
|
||||||
|
// The effect will replace sessions when new data arrives at offset 0
|
||||||
setOffset(0);
|
setOffset(0);
|
||||||
setAccumulatedSessions([]);
|
|
||||||
setTotalCount(null);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -5912,6 +5912,40 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/api/workspace/files/{file_id}/download": {
|
||||||
|
"get": {
|
||||||
|
"tags": ["workspace"],
|
||||||
|
"summary": "Download file by ID",
|
||||||
|
"description": "Download a file by its ID.\n\nReturns the file content directly or redirects to a signed URL for GCS.",
|
||||||
|
"operationId": "getWorkspaceDownload file by id",
|
||||||
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "file_id",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"schema": { "type": "string", "title": "File Id" }
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful Response",
|
||||||
|
"content": { "application/json": { "schema": {} } }
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Validation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/health": {
|
"/health": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": ["health"],
|
"tags": ["health"],
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import {
|
import {
|
||||||
ApiError,
|
ApiError,
|
||||||
|
getServerAuthToken,
|
||||||
makeAuthenticatedFileUpload,
|
makeAuthenticatedFileUpload,
|
||||||
makeAuthenticatedRequest,
|
makeAuthenticatedRequest,
|
||||||
} from "@/lib/autogpt-server-api/helpers";
|
} from "@/lib/autogpt-server-api/helpers";
|
||||||
@@ -15,6 +16,69 @@ function buildBackendUrl(path: string[], queryString: string): string {
|
|||||||
return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
|
return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if this is a workspace file download request that needs binary response handling.
|
||||||
|
*/
|
||||||
|
function isWorkspaceDownloadRequest(path: string[]): boolean {
|
||||||
|
// Match pattern: api/workspace/files/{id}/download (5 segments)
|
||||||
|
return (
|
||||||
|
path.length == 5 &&
|
||||||
|
path[0] === "api" &&
|
||||||
|
path[1] === "workspace" &&
|
||||||
|
path[2] === "files" &&
|
||||||
|
path[path.length - 1] === "download"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle workspace file download requests with proper binary response streaming.
|
||||||
|
*/
|
||||||
|
async function handleWorkspaceDownload(
|
||||||
|
req: NextRequest,
|
||||||
|
backendUrl: string,
|
||||||
|
): Promise<NextResponse> {
|
||||||
|
const token = await getServerAuthToken();
|
||||||
|
|
||||||
|
const headers: Record<string, string> = {};
|
||||||
|
if (token && token !== "no-token-found") {
|
||||||
|
headers["Authorization"] = `Bearer ${token}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await fetch(backendUrl, {
|
||||||
|
method: "GET",
|
||||||
|
headers,
|
||||||
|
redirect: "follow", // Follow redirects to signed URLs
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
return NextResponse.json(
|
||||||
|
{ error: `Failed to download file: ${response.statusText}` },
|
||||||
|
{ status: response.status },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the content type from the backend response
|
||||||
|
const contentType =
|
||||||
|
response.headers.get("Content-Type") || "application/octet-stream";
|
||||||
|
const contentDisposition = response.headers.get("Content-Disposition");
|
||||||
|
|
||||||
|
// Stream the response body
|
||||||
|
const responseHeaders: Record<string, string> = {
|
||||||
|
"Content-Type": contentType,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (contentDisposition) {
|
||||||
|
responseHeaders["Content-Disposition"] = contentDisposition;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the binary content
|
||||||
|
const arrayBuffer = await response.arrayBuffer();
|
||||||
|
return new NextResponse(arrayBuffer, {
|
||||||
|
status: 200,
|
||||||
|
headers: responseHeaders,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
async function handleJsonRequest(
|
async function handleJsonRequest(
|
||||||
req: NextRequest,
|
req: NextRequest,
|
||||||
method: string,
|
method: string,
|
||||||
@@ -180,6 +244,11 @@ async function handler(
|
|||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
// Handle workspace file downloads separately (binary response)
|
||||||
|
if (method === "GET" && isWorkspaceDownloadRequest(path)) {
|
||||||
|
return await handleWorkspaceDownload(req, backendUrl);
|
||||||
|
}
|
||||||
|
|
||||||
if (method === "GET" || method === "DELETE") {
|
if (method === "GET" || method === "DELETE") {
|
||||||
responseBody = await handleGetDeleteRequest(method, backendUrl, req);
|
responseBody = await handleGetDeleteRequest(method, backendUrl, req);
|
||||||
} else if (contentType?.includes("application/json")) {
|
} else if (contentType?.includes("application/json")) {
|
||||||
|
|||||||
77
autogpt_platform/frontend/src/app/api/transcribe/route.ts
Normal file
77
autogpt_platform/frontend/src/app/api/transcribe/route.ts
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
||||||
|
import { NextRequest, NextResponse } from "next/server";
|
||||||
|
|
||||||
|
const WHISPER_API_URL = "https://api.openai.com/v1/audio/transcriptions";
|
||||||
|
const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB - Whisper's limit
|
||||||
|
|
||||||
|
function getExtensionFromMimeType(mimeType: string): string {
|
||||||
|
const subtype = mimeType.split("/")[1]?.split(";")[0];
|
||||||
|
return subtype || "webm";
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function POST(request: NextRequest) {
|
||||||
|
const token = await getServerAuthToken();
|
||||||
|
|
||||||
|
if (!token || token === "no-token-found") {
|
||||||
|
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
|
||||||
|
}
|
||||||
|
|
||||||
|
const apiKey = process.env.OPENAI_API_KEY;
|
||||||
|
|
||||||
|
if (!apiKey) {
|
||||||
|
return NextResponse.json(
|
||||||
|
{ error: "OpenAI API key not configured" },
|
||||||
|
{ status: 401 },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const formData = await request.formData();
|
||||||
|
const audioFile = formData.get("audio");
|
||||||
|
|
||||||
|
if (!audioFile || !(audioFile instanceof Blob)) {
|
||||||
|
return NextResponse.json(
|
||||||
|
{ error: "No audio file provided" },
|
||||||
|
{ status: 400 },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (audioFile.size > MAX_FILE_SIZE) {
|
||||||
|
return NextResponse.json(
|
||||||
|
{ error: "File too large. Maximum size is 25MB." },
|
||||||
|
{ status: 413 },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const ext = getExtensionFromMimeType(audioFile.type);
|
||||||
|
const whisperFormData = new FormData();
|
||||||
|
whisperFormData.append("file", audioFile, `recording.${ext}`);
|
||||||
|
whisperFormData.append("model", "whisper-1");
|
||||||
|
|
||||||
|
const response = await fetch(WHISPER_API_URL, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${apiKey}`,
|
||||||
|
},
|
||||||
|
body: whisperFormData,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorData = await response.json().catch(() => ({}));
|
||||||
|
console.error("Whisper API error:", errorData);
|
||||||
|
return NextResponse.json(
|
||||||
|
{ error: errorData.error?.message || "Transcription failed" },
|
||||||
|
{ status: response.status },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await response.json();
|
||||||
|
return NextResponse.json({ text: result.text });
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Transcription error:", error);
|
||||||
|
return NextResponse.json(
|
||||||
|
{ error: "Failed to process audio" },
|
||||||
|
{ status: 500 },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,7 +1,14 @@
|
|||||||
import { Button } from "@/components/atoms/Button/Button";
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { ArrowUpIcon, StopIcon } from "@phosphor-icons/react";
|
import {
|
||||||
|
ArrowUpIcon,
|
||||||
|
CircleNotchIcon,
|
||||||
|
MicrophoneIcon,
|
||||||
|
StopIcon,
|
||||||
|
} from "@phosphor-icons/react";
|
||||||
|
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||||
import { useChatInput } from "./useChatInput";
|
import { useChatInput } from "./useChatInput";
|
||||||
|
import { useVoiceRecording } from "./useVoiceRecording";
|
||||||
|
|
||||||
export interface Props {
|
export interface Props {
|
||||||
onSend: (message: string) => void;
|
onSend: (message: string) => void;
|
||||||
@@ -21,22 +28,49 @@ export function ChatInput({
|
|||||||
className,
|
className,
|
||||||
}: Props) {
|
}: Props) {
|
||||||
const inputId = "chat-input";
|
const inputId = "chat-input";
|
||||||
const { value, handleKeyDown, handleSubmit, handleChange, hasMultipleLines } =
|
const {
|
||||||
useChatInput({
|
value,
|
||||||
|
setValue,
|
||||||
|
handleKeyDown: baseHandleKeyDown,
|
||||||
|
handleSubmit,
|
||||||
|
handleChange,
|
||||||
|
hasMultipleLines,
|
||||||
|
} = useChatInput({
|
||||||
onSend,
|
onSend,
|
||||||
disabled: disabled || isStreaming,
|
disabled: disabled || isStreaming,
|
||||||
maxRows: 4,
|
maxRows: 4,
|
||||||
inputId,
|
inputId,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const {
|
||||||
|
isRecording,
|
||||||
|
isTranscribing,
|
||||||
|
elapsedTime,
|
||||||
|
toggleRecording,
|
||||||
|
handleKeyDown,
|
||||||
|
showMicButton,
|
||||||
|
isInputDisabled,
|
||||||
|
audioStream,
|
||||||
|
} = useVoiceRecording({
|
||||||
|
setValue,
|
||||||
|
disabled: disabled || isStreaming,
|
||||||
|
isStreaming,
|
||||||
|
value,
|
||||||
|
baseHandleKeyDown,
|
||||||
|
inputId,
|
||||||
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
|
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
|
||||||
<div className="relative">
|
<div className="relative">
|
||||||
<div
|
<div
|
||||||
id={`${inputId}-wrapper`}
|
id={`${inputId}-wrapper`}
|
||||||
className={cn(
|
className={cn(
|
||||||
"relative overflow-hidden border border-neutral-200 bg-white shadow-sm",
|
"relative overflow-hidden border bg-white shadow-sm",
|
||||||
"focus-within:border-zinc-400 focus-within:ring-1 focus-within:ring-zinc-400",
|
"focus-within:ring-1",
|
||||||
|
isRecording
|
||||||
|
? "border-red-400 focus-within:border-red-400 focus-within:ring-red-400"
|
||||||
|
: "border-neutral-200 focus-within:border-zinc-400 focus-within:ring-zinc-400",
|
||||||
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
|
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
@@ -46,22 +80,67 @@ export function ChatInput({
|
|||||||
value={value}
|
value={value}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
onKeyDown={handleKeyDown}
|
onKeyDown={handleKeyDown}
|
||||||
placeholder={placeholder}
|
placeholder={
|
||||||
disabled={disabled || isStreaming}
|
isTranscribing
|
||||||
|
? "Transcribing..."
|
||||||
|
: isRecording
|
||||||
|
? ""
|
||||||
|
: placeholder
|
||||||
|
}
|
||||||
|
disabled={isInputDisabled}
|
||||||
rows={1}
|
rows={1}
|
||||||
className={cn(
|
className={cn(
|
||||||
"w-full resize-none overflow-y-auto border-0 bg-transparent text-[1rem] leading-6 text-black",
|
"w-full resize-none overflow-y-auto border-0 bg-transparent text-[1rem] leading-6 text-black",
|
||||||
"placeholder:text-zinc-400",
|
"placeholder:text-zinc-400",
|
||||||
"focus:outline-none focus:ring-0",
|
"focus:outline-none focus:ring-0",
|
||||||
"disabled:text-zinc-500",
|
"disabled:text-zinc-500",
|
||||||
hasMultipleLines ? "pb-6 pl-4 pr-4 pt-2" : "pb-4 pl-4 pr-14 pt-4",
|
hasMultipleLines
|
||||||
|
? "pb-6 pl-4 pr-4 pt-2"
|
||||||
|
: showMicButton
|
||||||
|
? "pb-4 pl-14 pr-14 pt-4"
|
||||||
|
: "pb-4 pl-4 pr-14 pt-4",
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
|
{isRecording && !value && (
|
||||||
|
<div className="pointer-events-none absolute inset-0 flex items-center justify-center">
|
||||||
|
<RecordingIndicator
|
||||||
|
elapsedTime={elapsedTime}
|
||||||
|
audioStream={audioStream}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<span id="chat-input-hint" className="sr-only">
|
<span id="chat-input-hint" className="sr-only">
|
||||||
Press Enter to send, Shift+Enter for new line
|
Press Enter to send, Shift+Enter for new line, Space to record voice
|
||||||
</span>
|
</span>
|
||||||
|
|
||||||
|
{showMicButton && (
|
||||||
|
<div className="absolute bottom-[7px] left-2 flex items-center gap-1">
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="icon"
|
||||||
|
size="icon"
|
||||||
|
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
||||||
|
onClick={toggleRecording}
|
||||||
|
disabled={disabled || isTranscribing}
|
||||||
|
className={cn(
|
||||||
|
isRecording
|
||||||
|
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
|
||||||
|
: isTranscribing
|
||||||
|
? "border-zinc-300 bg-zinc-100 text-zinc-400"
|
||||||
|
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isTranscribing ? (
|
||||||
|
<CircleNotchIcon className="h-4 w-4 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<MicrophoneIcon className="h-4 w-4" weight="bold" />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="absolute bottom-[7px] right-2 flex items-center gap-1">
|
||||||
{isStreaming ? (
|
{isStreaming ? (
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
@@ -69,7 +148,7 @@ export function ChatInput({
|
|||||||
size="icon"
|
size="icon"
|
||||||
aria-label="Stop generating"
|
aria-label="Stop generating"
|
||||||
onClick={onStop}
|
onClick={onStop}
|
||||||
className="absolute bottom-[7px] right-2 border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
|
className="border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
|
||||||
>
|
>
|
||||||
<StopIcon className="h-4 w-4" weight="bold" />
|
<StopIcon className="h-4 w-4" weight="bold" />
|
||||||
</Button>
|
</Button>
|
||||||
@@ -80,15 +159,16 @@ export function ChatInput({
|
|||||||
size="icon"
|
size="icon"
|
||||||
aria-label="Send message"
|
aria-label="Send message"
|
||||||
className={cn(
|
className={cn(
|
||||||
"absolute bottom-[7px] right-2 border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
|
"border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
|
||||||
(disabled || !value.trim()) && "opacity-20",
|
(disabled || !value.trim() || isRecording) && "opacity-20",
|
||||||
)}
|
)}
|
||||||
disabled={disabled || !value.trim()}
|
disabled={disabled || !value.trim() || isRecording}
|
||||||
>
|
>
|
||||||
<ArrowUpIcon className="h-4 w-4" weight="bold" />
|
<ArrowUpIcon className="h-4 w-4" weight="bold" />
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
</form>
|
</form>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,142 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
stream: MediaStream | null;
|
||||||
|
barCount?: number;
|
||||||
|
barWidth?: number;
|
||||||
|
barGap?: number;
|
||||||
|
barColor?: string;
|
||||||
|
minBarHeight?: number;
|
||||||
|
maxBarHeight?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function AudioWaveform({
|
||||||
|
stream,
|
||||||
|
barCount = 24,
|
||||||
|
barWidth = 3,
|
||||||
|
barGap = 2,
|
||||||
|
barColor = "#ef4444", // red-500
|
||||||
|
minBarHeight = 4,
|
||||||
|
maxBarHeight = 32,
|
||||||
|
}: Props) {
|
||||||
|
const [bars, setBars] = useState<number[]>(() =>
|
||||||
|
Array(barCount).fill(minBarHeight),
|
||||||
|
);
|
||||||
|
const analyserRef = useRef<AnalyserNode | null>(null);
|
||||||
|
const audioContextRef = useRef<AudioContext | null>(null);
|
||||||
|
const sourceRef = useRef<MediaStreamAudioSourceNode | null>(null);
|
||||||
|
const animationRef = useRef<number | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!stream) {
|
||||||
|
setBars(Array(barCount).fill(minBarHeight));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create audio context and analyser
|
||||||
|
const audioContext = new AudioContext();
|
||||||
|
const analyser = audioContext.createAnalyser();
|
||||||
|
analyser.fftSize = 512;
|
||||||
|
analyser.smoothingTimeConstant = 0.8;
|
||||||
|
|
||||||
|
// Connect the stream to the analyser
|
||||||
|
const source = audioContext.createMediaStreamSource(stream);
|
||||||
|
source.connect(analyser);
|
||||||
|
|
||||||
|
audioContextRef.current = audioContext;
|
||||||
|
analyserRef.current = analyser;
|
||||||
|
sourceRef.current = source;
|
||||||
|
|
||||||
|
const timeData = new Uint8Array(analyser.frequencyBinCount);
|
||||||
|
|
||||||
|
const updateBars = () => {
|
||||||
|
if (!analyserRef.current) return;
|
||||||
|
|
||||||
|
analyserRef.current.getByteTimeDomainData(timeData);
|
||||||
|
|
||||||
|
// Distribute time-domain data across bars
|
||||||
|
// This shows waveform amplitude, making all bars respond to audio
|
||||||
|
const newBars: number[] = [];
|
||||||
|
const samplesPerBar = timeData.length / barCount;
|
||||||
|
|
||||||
|
for (let i = 0; i < barCount; i++) {
|
||||||
|
// Sample waveform data for this bar
|
||||||
|
let maxAmplitude = 0;
|
||||||
|
const startIdx = Math.floor(i * samplesPerBar);
|
||||||
|
const endIdx = Math.floor((i + 1) * samplesPerBar);
|
||||||
|
|
||||||
|
for (let j = startIdx; j < endIdx && j < timeData.length; j++) {
|
||||||
|
// Convert to amplitude (distance from center 128)
|
||||||
|
const amplitude = Math.abs(timeData[j] - 128);
|
||||||
|
maxAmplitude = Math.max(maxAmplitude, amplitude);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map amplitude (0-128) to bar height
|
||||||
|
const normalized = (maxAmplitude / 128) * 255;
|
||||||
|
const height =
|
||||||
|
minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
|
||||||
|
newBars.push(height);
|
||||||
|
}
|
||||||
|
|
||||||
|
setBars(newBars);
|
||||||
|
animationRef.current = requestAnimationFrame(updateBars);
|
||||||
|
};
|
||||||
|
|
||||||
|
updateBars();
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
if (animationRef.current) {
|
||||||
|
cancelAnimationFrame(animationRef.current);
|
||||||
|
}
|
||||||
|
if (sourceRef.current) {
|
||||||
|
sourceRef.current.disconnect();
|
||||||
|
}
|
||||||
|
if (audioContextRef.current) {
|
||||||
|
audioContextRef.current.close();
|
||||||
|
}
|
||||||
|
analyserRef.current = null;
|
||||||
|
audioContextRef.current = null;
|
||||||
|
sourceRef.current = null;
|
||||||
|
};
|
||||||
|
}, [stream, barCount, minBarHeight, maxBarHeight]);
|
||||||
|
|
||||||
|
const totalWidth = barCount * barWidth + (barCount - 1) * barGap;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className="flex items-center justify-center"
|
||||||
|
style={{
|
||||||
|
width: totalWidth,
|
||||||
|
height: maxBarHeight,
|
||||||
|
gap: barGap,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{bars.map((height, i) => {
|
||||||
|
const barHeight = Math.max(minBarHeight, height);
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
key={i}
|
||||||
|
className="relative"
|
||||||
|
style={{
|
||||||
|
width: barWidth,
|
||||||
|
height: maxBarHeight,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className="absolute left-0 rounded-full transition-[height] duration-75"
|
||||||
|
style={{
|
||||||
|
width: barWidth,
|
||||||
|
height: barHeight,
|
||||||
|
top: "50%",
|
||||||
|
transform: "translateY(-50%)",
|
||||||
|
backgroundColor: barColor,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
import { formatElapsedTime } from "../helpers";
|
||||||
|
import { AudioWaveform } from "./AudioWaveform";
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
elapsedTime: number;
|
||||||
|
audioStream: MediaStream | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function RecordingIndicator({ elapsedTime, audioStream }: Props) {
|
||||||
|
return (
|
||||||
|
<div className="flex items-center gap-3">
|
||||||
|
<AudioWaveform
|
||||||
|
stream={audioStream}
|
||||||
|
barCount={20}
|
||||||
|
barWidth={3}
|
||||||
|
barGap={2}
|
||||||
|
barColor="#ef4444"
|
||||||
|
minBarHeight={4}
|
||||||
|
maxBarHeight={24}
|
||||||
|
/>
|
||||||
|
<span className="min-w-[3ch] text-sm font-medium text-red-500">
|
||||||
|
{formatElapsedTime(elapsedTime)}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
export function formatElapsedTime(ms: number): string {
|
||||||
|
const seconds = Math.floor(ms / 1000);
|
||||||
|
const minutes = Math.floor(seconds / 60);
|
||||||
|
const remainingSeconds = seconds % 60;
|
||||||
|
return `${minutes}:${remainingSeconds.toString().padStart(2, "0")}`;
|
||||||
|
}
|
||||||
@@ -6,7 +6,7 @@ import {
|
|||||||
useState,
|
useState,
|
||||||
} from "react";
|
} from "react";
|
||||||
|
|
||||||
interface UseChatInputArgs {
|
interface Args {
|
||||||
onSend: (message: string) => void;
|
onSend: (message: string) => void;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
maxRows?: number;
|
maxRows?: number;
|
||||||
@@ -18,7 +18,7 @@ export function useChatInput({
|
|||||||
disabled = false,
|
disabled = false,
|
||||||
maxRows = 5,
|
maxRows = 5,
|
||||||
inputId = "chat-input",
|
inputId = "chat-input",
|
||||||
}: UseChatInputArgs) {
|
}: Args) {
|
||||||
const [value, setValue] = useState("");
|
const [value, setValue] = useState("");
|
||||||
const [hasMultipleLines, setHasMultipleLines] = useState(false);
|
const [hasMultipleLines, setHasMultipleLines] = useState(false);
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,251 @@
|
|||||||
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
|
import React, {
|
||||||
|
KeyboardEvent,
|
||||||
|
useCallback,
|
||||||
|
useEffect,
|
||||||
|
useRef,
|
||||||
|
useState,
|
||||||
|
} from "react";
|
||||||
|
|
||||||
|
const MAX_RECORDING_DURATION = 2 * 60 * 1000; // 2 minutes in ms
|
||||||
|
|
||||||
|
interface Args {
|
||||||
|
setValue: React.Dispatch<React.SetStateAction<string>>;
|
||||||
|
disabled?: boolean;
|
||||||
|
isStreaming?: boolean;
|
||||||
|
value: string;
|
||||||
|
baseHandleKeyDown: (event: KeyboardEvent<HTMLTextAreaElement>) => void;
|
||||||
|
inputId?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useVoiceRecording({
|
||||||
|
setValue,
|
||||||
|
disabled = false,
|
||||||
|
isStreaming = false,
|
||||||
|
value,
|
||||||
|
baseHandleKeyDown,
|
||||||
|
inputId,
|
||||||
|
}: Args) {
|
||||||
|
const [isRecording, setIsRecording] = useState(false);
|
||||||
|
const [isTranscribing, setIsTranscribing] = useState(false);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
const [elapsedTime, setElapsedTime] = useState(0);
|
||||||
|
|
||||||
|
const mediaRecorderRef = useRef<MediaRecorder | null>(null);
|
||||||
|
const chunksRef = useRef<Blob[]>([]);
|
||||||
|
const timerRef = useRef<NodeJS.Timeout | null>(null);
|
||||||
|
const startTimeRef = useRef<number>(0);
|
||||||
|
const streamRef = useRef<MediaStream | null>(null);
|
||||||
|
const isRecordingRef = useRef(false);
|
||||||
|
|
||||||
|
const isSupported =
|
||||||
|
typeof window !== "undefined" &&
|
||||||
|
!!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
|
||||||
|
|
||||||
|
const clearTimer = useCallback(() => {
|
||||||
|
if (timerRef.current) {
|
||||||
|
clearInterval(timerRef.current);
|
||||||
|
timerRef.current = null;
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const cleanup = useCallback(() => {
|
||||||
|
clearTimer();
|
||||||
|
if (streamRef.current) {
|
||||||
|
streamRef.current.getTracks().forEach((track) => track.stop());
|
||||||
|
streamRef.current = null;
|
||||||
|
}
|
||||||
|
mediaRecorderRef.current = null;
|
||||||
|
chunksRef.current = [];
|
||||||
|
setElapsedTime(0);
|
||||||
|
}, [clearTimer]);
|
||||||
|
|
||||||
|
const handleTranscription = useCallback(
|
||||||
|
(text: string) => {
|
||||||
|
setValue((prev) => {
|
||||||
|
const trimmedPrev = prev.trim();
|
||||||
|
if (trimmedPrev) {
|
||||||
|
return `${trimmedPrev} ${text}`;
|
||||||
|
}
|
||||||
|
return text;
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[setValue],
|
||||||
|
);
|
||||||
|
|
||||||
|
const transcribeAudio = useCallback(
|
||||||
|
async (audioBlob: Blob) => {
|
||||||
|
setIsTranscribing(true);
|
||||||
|
setError(null);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append("audio", audioBlob);
|
||||||
|
|
||||||
|
const response = await fetch("/api/transcribe", {
|
||||||
|
method: "POST",
|
||||||
|
body: formData,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const data = await response.json().catch(() => ({}));
|
||||||
|
throw new Error(data.error || "Transcription failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
if (data.text) {
|
||||||
|
handleTranscription(data.text);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
const message =
|
||||||
|
err instanceof Error ? err.message : "Transcription failed";
|
||||||
|
setError(message);
|
||||||
|
console.error("Transcription error:", err);
|
||||||
|
} finally {
|
||||||
|
setIsTranscribing(false);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[handleTranscription, inputId],
|
||||||
|
);
|
||||||
|
|
||||||
|
const stopRecording = useCallback(() => {
|
||||||
|
if (mediaRecorderRef.current && isRecordingRef.current) {
|
||||||
|
mediaRecorderRef.current.stop();
|
||||||
|
isRecordingRef.current = false;
|
||||||
|
setIsRecording(false);
|
||||||
|
clearTimer();
|
||||||
|
}
|
||||||
|
}, [clearTimer]);
|
||||||
|
|
||||||
|
const startRecording = useCallback(async () => {
|
||||||
|
if (disabled || isRecordingRef.current || isTranscribing) return;
|
||||||
|
|
||||||
|
setError(null);
|
||||||
|
chunksRef.current = [];
|
||||||
|
|
||||||
|
try {
|
||||||
|
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||||
|
streamRef.current = stream;
|
||||||
|
|
||||||
|
const mediaRecorder = new MediaRecorder(stream, {
|
||||||
|
mimeType: MediaRecorder.isTypeSupported("audio/webm")
|
||||||
|
? "audio/webm"
|
||||||
|
: "audio/mp4",
|
||||||
|
});
|
||||||
|
|
||||||
|
mediaRecorderRef.current = mediaRecorder;
|
||||||
|
|
||||||
|
mediaRecorder.ondataavailable = (event) => {
|
||||||
|
if (event.data.size > 0) {
|
||||||
|
chunksRef.current.push(event.data);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
mediaRecorder.onstop = async () => {
|
||||||
|
const audioBlob = new Blob(chunksRef.current, {
|
||||||
|
type: mediaRecorder.mimeType,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Cleanup stream
|
||||||
|
if (streamRef.current) {
|
||||||
|
streamRef.current.getTracks().forEach((track) => track.stop());
|
||||||
|
streamRef.current = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (audioBlob.size > 0) {
|
||||||
|
await transcribeAudio(audioBlob);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
mediaRecorder.start(1000); // Collect data every second
|
||||||
|
isRecordingRef.current = true;
|
||||||
|
setIsRecording(true);
|
||||||
|
startTimeRef.current = Date.now();
|
||||||
|
|
||||||
|
// Start elapsed time timer
|
||||||
|
timerRef.current = setInterval(() => {
|
||||||
|
const elapsed = Date.now() - startTimeRef.current;
|
||||||
|
setElapsedTime(elapsed);
|
||||||
|
|
||||||
|
// Auto-stop at max duration
|
||||||
|
if (elapsed >= MAX_RECORDING_DURATION) {
|
||||||
|
stopRecording();
|
||||||
|
}
|
||||||
|
}, 100);
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Failed to start recording:", err);
|
||||||
|
if (err instanceof DOMException && err.name === "NotAllowedError") {
|
||||||
|
setError("Microphone permission denied");
|
||||||
|
} else {
|
||||||
|
setError("Failed to access microphone");
|
||||||
|
}
|
||||||
|
cleanup();
|
||||||
|
}
|
||||||
|
}, [disabled, isTranscribing, stopRecording, transcribeAudio, cleanup]);
|
||||||
|
|
||||||
|
const toggleRecording = useCallback(() => {
|
||||||
|
if (isRecording) {
|
||||||
|
stopRecording();
|
||||||
|
} else {
|
||||||
|
startRecording();
|
||||||
|
}
|
||||||
|
}, [isRecording, startRecording, stopRecording]);
|
||||||
|
|
||||||
|
const { toast } = useToast();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (error) {
|
||||||
|
toast({
|
||||||
|
title: "Voice recording failed",
|
||||||
|
description: error,
|
||||||
|
variant: "destructive",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [error, toast]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!isTranscribing && inputId) {
|
||||||
|
const inputElement = document.getElementById(inputId);
|
||||||
|
if (inputElement) {
|
||||||
|
inputElement.focus();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [isTranscribing, inputId]);
|
||||||
|
|
||||||
|
const handleKeyDown = useCallback(
|
||||||
|
(event: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
|
if (event.key === " " && !value.trim() && !isTranscribing) {
|
||||||
|
event.preventDefault();
|
||||||
|
toggleRecording();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
baseHandleKeyDown(event);
|
||||||
|
},
|
||||||
|
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
|
||||||
|
);
|
||||||
|
|
||||||
|
const showMicButton = isSupported && !isStreaming;
|
||||||
|
const isInputDisabled = disabled || isStreaming || isTranscribing;
|
||||||
|
|
||||||
|
// Cleanup on unmount
|
||||||
|
useEffect(() => {
|
||||||
|
return () => {
|
||||||
|
cleanup();
|
||||||
|
};
|
||||||
|
}, [cleanup]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
isRecording,
|
||||||
|
isTranscribing,
|
||||||
|
error,
|
||||||
|
elapsedTime,
|
||||||
|
startRecording,
|
||||||
|
stopRecording,
|
||||||
|
toggleRecording,
|
||||||
|
isSupported,
|
||||||
|
handleKeyDown,
|
||||||
|
showMicButton,
|
||||||
|
isInputDisabled,
|
||||||
|
audioStream: streamRef.current,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
|
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
import { EyeSlash } from "@phosphor-icons/react";
|
||||||
import React from "react";
|
import React from "react";
|
||||||
import ReactMarkdown from "react-markdown";
|
import ReactMarkdown from "react-markdown";
|
||||||
import remarkGfm from "remark-gfm";
|
import remarkGfm from "remark-gfm";
|
||||||
@@ -29,12 +31,88 @@ interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
|
|||||||
type?: string;
|
type?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a workspace:// URL to a proxy URL that routes through Next.js to the backend.
|
||||||
|
* workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
|
||||||
|
*
|
||||||
|
* Uses the generated API URL helper and routes through the Next.js proxy
|
||||||
|
* which handles authentication and proper backend routing.
|
||||||
|
*/
|
||||||
|
/**
|
||||||
|
* URL transformer for ReactMarkdown.
|
||||||
|
* Converts workspace:// URLs to proxy URLs that route through Next.js to the backend.
|
||||||
|
* workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
|
||||||
|
*
|
||||||
|
* This is needed because ReactMarkdown sanitizes URLs and only allows
|
||||||
|
* http, https, mailto, and tel protocols by default.
|
||||||
|
*/
|
||||||
|
function resolveWorkspaceUrl(src: string): string {
|
||||||
|
if (src.startsWith("workspace://")) {
|
||||||
|
const fileId = src.replace("workspace://", "");
|
||||||
|
// Use the generated API URL helper to get the correct path
|
||||||
|
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
|
||||||
|
// Route through the Next.js proxy (same pattern as customMutator for client-side)
|
||||||
|
return `/api/proxy${apiPath}`;
|
||||||
|
}
|
||||||
|
return src;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if the image URL is a workspace file (AI cannot see these yet).
|
||||||
|
* After URL transformation, workspace files have URLs like /api/proxy/api/workspace/files/...
|
||||||
|
*/
|
||||||
|
function isWorkspaceImage(src: string | undefined): boolean {
|
||||||
|
return src?.includes("/workspace/files/") ?? false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Custom image component that shows an indicator when the AI cannot see the image.
|
||||||
|
* Note: src is already transformed by urlTransform, so workspace:// is now /api/workspace/...
|
||||||
|
*/
|
||||||
|
function MarkdownImage(props: Record<string, unknown>) {
|
||||||
|
const src = props.src as string | undefined;
|
||||||
|
const alt = props.alt as string | undefined;
|
||||||
|
|
||||||
|
const aiCannotSee = isWorkspaceImage(src);
|
||||||
|
|
||||||
|
// If no src, show a placeholder
|
||||||
|
if (!src) {
|
||||||
|
return (
|
||||||
|
<span className="my-2 inline-block rounded border border-amber-200 bg-amber-50 px-2 py-1 text-sm text-amber-700">
|
||||||
|
[Image: {alt || "missing src"}]
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<span className="relative my-2 inline-block">
|
||||||
|
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||||
|
<img
|
||||||
|
src={src}
|
||||||
|
alt={alt || "Image"}
|
||||||
|
className="h-auto max-w-full rounded-md border border-zinc-200"
|
||||||
|
loading="lazy"
|
||||||
|
/>
|
||||||
|
{aiCannotSee && (
|
||||||
|
<span
|
||||||
|
className="absolute bottom-2 right-2 flex items-center gap-1 rounded bg-black/70 px-2 py-1 text-xs text-white"
|
||||||
|
title="The AI cannot see this image"
|
||||||
|
>
|
||||||
|
<EyeSlash size={14} />
|
||||||
|
<span>AI cannot see this image</span>
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export function MarkdownContent({ content, className }: MarkdownContentProps) {
|
export function MarkdownContent({ content, className }: MarkdownContentProps) {
|
||||||
return (
|
return (
|
||||||
<div className={cn("markdown-content", className)}>
|
<div className={cn("markdown-content", className)}>
|
||||||
<ReactMarkdown
|
<ReactMarkdown
|
||||||
skipHtml={true}
|
skipHtml={true}
|
||||||
remarkPlugins={[remarkGfm]}
|
remarkPlugins={[remarkGfm]}
|
||||||
|
urlTransform={resolveWorkspaceUrl}
|
||||||
components={{
|
components={{
|
||||||
code: ({ children, className, ...props }: CodeProps) => {
|
code: ({ children, className, ...props }: CodeProps) => {
|
||||||
const isInline = !className?.includes("language-");
|
const isInline = !className?.includes("language-");
|
||||||
@@ -206,6 +284,9 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
|
|||||||
{children}
|
{children}
|
||||||
</td>
|
</td>
|
||||||
),
|
),
|
||||||
|
img: ({ src, alt, ...props }) => (
|
||||||
|
<MarkdownImage src={src} alt={alt} {...props} />
|
||||||
|
),
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{content}
|
{content}
|
||||||
|
|||||||
@@ -30,13 +30,94 @@ export function getErrorMessage(result: unknown): string {
|
|||||||
}
|
}
|
||||||
if (typeof result === "object" && result !== null) {
|
if (typeof result === "object" && result !== null) {
|
||||||
const response = result as Record<string, unknown>;
|
const response = result as Record<string, unknown>;
|
||||||
if (response.error) return stripInternalReasoning(String(response.error));
|
|
||||||
if (response.message)
|
if (response.message)
|
||||||
return stripInternalReasoning(String(response.message));
|
return stripInternalReasoning(String(response.message));
|
||||||
|
if (response.error) return stripInternalReasoning(String(response.error));
|
||||||
}
|
}
|
||||||
return "An error occurred";
|
return "An error occurred";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a value is a workspace file reference.
|
||||||
|
*/
|
||||||
|
function isWorkspaceRef(value: unknown): value is string {
|
||||||
|
return typeof value === "string" && value.startsWith("workspace://");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a workspace reference appears to be an image based on common patterns.
|
||||||
|
* Since workspace refs don't have extensions, we check the context or assume image
|
||||||
|
* for certain block types.
|
||||||
|
*
|
||||||
|
* TODO: Replace keyword matching with MIME type encoded in workspace ref.
|
||||||
|
* e.g., workspace://abc123#image/png or workspace://abc123#video/mp4
|
||||||
|
* This would let frontend render correctly without fragile keyword matching.
|
||||||
|
*/
|
||||||
|
function isLikelyImageRef(value: string, outputKey?: string): boolean {
|
||||||
|
if (!isWorkspaceRef(value)) return false;
|
||||||
|
|
||||||
|
// Check output key name for video-related hints (these are NOT images)
|
||||||
|
const videoKeywords = ["video", "mp4", "mov", "avi", "webm", "movie", "clip"];
|
||||||
|
if (outputKey) {
|
||||||
|
const lowerKey = outputKey.toLowerCase();
|
||||||
|
if (videoKeywords.some((kw) => lowerKey.includes(kw))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check output key name for image-related hints
|
||||||
|
const imageKeywords = [
|
||||||
|
"image",
|
||||||
|
"img",
|
||||||
|
"photo",
|
||||||
|
"picture",
|
||||||
|
"thumbnail",
|
||||||
|
"avatar",
|
||||||
|
"icon",
|
||||||
|
"screenshot",
|
||||||
|
];
|
||||||
|
if (outputKey) {
|
||||||
|
const lowerKey = outputKey.toLowerCase();
|
||||||
|
if (imageKeywords.some((kw) => lowerKey.includes(kw))) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to treating workspace refs as potential images
|
||||||
|
// since that's the most common case for generated content
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Format a single output value, converting workspace refs to markdown images.
|
||||||
|
*/
|
||||||
|
function formatOutputValue(value: unknown, outputKey?: string): string {
|
||||||
|
if (isWorkspaceRef(value) && isLikelyImageRef(value, outputKey)) {
|
||||||
|
// Format as markdown image
|
||||||
|
return ``;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof value === "string") {
|
||||||
|
// Check for data URIs (images)
|
||||||
|
if (value.startsWith("data:image/")) {
|
||||||
|
return ``;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Array.isArray(value)) {
|
||||||
|
return value
|
||||||
|
.map((item, idx) => formatOutputValue(item, `${outputKey}_${idx}`))
|
||||||
|
.join("\n\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof value === "object" && value !== null) {
|
||||||
|
return JSON.stringify(value, null, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
return String(value);
|
||||||
|
}
|
||||||
|
|
||||||
function getToolCompletionPhrase(toolName: string): string {
|
function getToolCompletionPhrase(toolName: string): string {
|
||||||
const toolCompletionPhrases: Record<string, string> = {
|
const toolCompletionPhrases: Record<string, string> = {
|
||||||
add_understanding: "Updated your business information",
|
add_understanding: "Updated your business information",
|
||||||
@@ -127,10 +208,26 @@ export function formatToolResponse(result: unknown, toolName: string): string {
|
|||||||
|
|
||||||
case "block_output":
|
case "block_output":
|
||||||
const blockName = (response.block_name as string) || "Block";
|
const blockName = (response.block_name as string) || "Block";
|
||||||
const outputs = response.outputs as Record<string, unknown> | undefined;
|
const outputs = response.outputs as Record<string, unknown[]> | undefined;
|
||||||
if (outputs && Object.keys(outputs).length > 0) {
|
if (outputs && Object.keys(outputs).length > 0) {
|
||||||
const outputKeys = Object.keys(outputs);
|
const formattedOutputs: string[] = [];
|
||||||
return `${blockName} executed successfully. Outputs: ${outputKeys.join(", ")}`;
|
|
||||||
|
for (const [key, values] of Object.entries(outputs)) {
|
||||||
|
if (!Array.isArray(values) || values.length === 0) continue;
|
||||||
|
|
||||||
|
// Format each value in the output array
|
||||||
|
for (const value of values) {
|
||||||
|
const formatted = formatOutputValue(value, key);
|
||||||
|
if (formatted) {
|
||||||
|
formattedOutputs.push(formatted);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (formattedOutputs.length > 0) {
|
||||||
|
return `${blockName} executed successfully.\n\n${formattedOutputs.join("\n\n")}`;
|
||||||
|
}
|
||||||
|
return `${blockName} executed successfully.`;
|
||||||
}
|
}
|
||||||
return `${blockName} executed successfully.`;
|
return `${blockName} executed successfully.`;
|
||||||
|
|
||||||
@@ -266,8 +363,8 @@ export function formatToolResponse(result: unknown, toolName: string): string {
|
|||||||
|
|
||||||
case "error":
|
case "error":
|
||||||
const errorMsg =
|
const errorMsg =
|
||||||
(response.error as string) || response.message || "An error occurred";
|
(response.message as string) || response.error || "An error occurred";
|
||||||
return `Error: ${errorMsg}`;
|
return stripInternalReasoning(String(errorMsg));
|
||||||
|
|
||||||
case "no_results":
|
case "no_results":
|
||||||
const suggestions = (response.suggestions as string[]) || [];
|
const suggestions = (response.suggestions as string[]) || [];
|
||||||
|
|||||||
@@ -516,7 +516,7 @@ export type GraphValidationErrorResponse = {
|
|||||||
|
|
||||||
/* *** LIBRARY *** */
|
/* *** LIBRARY *** */
|
||||||
|
|
||||||
/* Mirror of backend/server/v2/library/model.py:LibraryAgent */
|
/* Mirror of backend/api/features/library/model.py:LibraryAgent */
|
||||||
export type LibraryAgent = {
|
export type LibraryAgent = {
|
||||||
id: LibraryAgentID;
|
id: LibraryAgentID;
|
||||||
graph_id: GraphID;
|
graph_id: GraphID;
|
||||||
@@ -616,7 +616,7 @@ export enum LibraryAgentSortEnum {
|
|||||||
|
|
||||||
/* *** CREDENTIALS *** */
|
/* *** CREDENTIALS *** */
|
||||||
|
|
||||||
/* Mirror of backend/server/integrations/router.py:CredentialsMetaResponse */
|
/* Mirror of backend/api/features/integrations/router.py:CredentialsMetaResponse */
|
||||||
export type CredentialsMetaResponse = {
|
export type CredentialsMetaResponse = {
|
||||||
id: string;
|
id: string;
|
||||||
provider: CredentialsProviderName;
|
provider: CredentialsProviderName;
|
||||||
@@ -628,13 +628,13 @@ export type CredentialsMetaResponse = {
|
|||||||
is_system?: boolean;
|
is_system?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Mirror of backend/server/integrations/router.py:CredentialsDeletionResponse */
|
/* Mirror of backend/api/features/integrations/router.py:CredentialsDeletionResponse */
|
||||||
export type CredentialsDeleteResponse = {
|
export type CredentialsDeleteResponse = {
|
||||||
deleted: true;
|
deleted: true;
|
||||||
revoked: boolean | null;
|
revoked: boolean | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Mirror of backend/server/integrations/router.py:CredentialsDeletionNeedsConfirmationResponse */
|
/* Mirror of backend/api/features/integrations/router.py:CredentialsDeletionNeedsConfirmationResponse */
|
||||||
export type CredentialsDeleteNeedConfirmationResponse = {
|
export type CredentialsDeleteNeedConfirmationResponse = {
|
||||||
deleted: false;
|
deleted: false;
|
||||||
need_confirmation: true;
|
need_confirmation: true;
|
||||||
@@ -888,7 +888,7 @@ export type Schedule = {
|
|||||||
|
|
||||||
export type ScheduleID = Brand<string, "ScheduleID">;
|
export type ScheduleID = Brand<string, "ScheduleID">;
|
||||||
|
|
||||||
/* Mirror of backend/server/routers/v1.py:ScheduleCreationRequest */
|
/* Mirror of backend/api/features/v1.py:ScheduleCreationRequest */
|
||||||
export type ScheduleCreatable = {
|
export type ScheduleCreatable = {
|
||||||
graph_id: GraphID;
|
graph_id: GraphID;
|
||||||
graph_version: number;
|
graph_version: number;
|
||||||
|
|||||||
@@ -59,12 +59,13 @@ test.describe("Library", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
test("pagination works correctly", async ({ page }, testInfo) => {
|
test("pagination works correctly", async ({ page }, testInfo) => {
|
||||||
test.setTimeout(testInfo.timeout * 3); // Increase timeout for pagination operations
|
test.setTimeout(testInfo.timeout * 3);
|
||||||
await page.goto("/library");
|
await page.goto("/library");
|
||||||
|
|
||||||
|
const PAGE_SIZE = 20;
|
||||||
const paginationResult = await libraryPage.testPagination();
|
const paginationResult = await libraryPage.testPagination();
|
||||||
|
|
||||||
if (paginationResult.initialCount >= 10) {
|
if (paginationResult.initialCount >= PAGE_SIZE) {
|
||||||
expect(paginationResult.finalCount).toBeGreaterThanOrEqual(
|
expect(paginationResult.finalCount).toBeGreaterThanOrEqual(
|
||||||
paginationResult.initialCount,
|
paginationResult.initialCount,
|
||||||
);
|
);
|
||||||
@@ -133,7 +134,10 @@ test.describe("Library", () => {
|
|||||||
test.expect(clearedSearchValue).toBe("");
|
test.expect(clearedSearchValue).toBe("");
|
||||||
});
|
});
|
||||||
|
|
||||||
test("pagination while searching works correctly", async ({ page }) => {
|
test("pagination while searching works correctly", async ({
|
||||||
|
page,
|
||||||
|
}, testInfo) => {
|
||||||
|
test.setTimeout(testInfo.timeout * 3);
|
||||||
await page.goto("/library");
|
await page.goto("/library");
|
||||||
|
|
||||||
const allAgents = await libraryPage.getAgents();
|
const allAgents = await libraryPage.getAgents();
|
||||||
@@ -152,9 +156,10 @@ test.describe("Library", () => {
|
|||||||
);
|
);
|
||||||
expect(matchingResults.length).toEqual(initialSearchResults.length);
|
expect(matchingResults.length).toEqual(initialSearchResults.length);
|
||||||
|
|
||||||
|
const PAGE_SIZE = 20;
|
||||||
const searchPaginationResult = await libraryPage.testPagination();
|
const searchPaginationResult = await libraryPage.testPagination();
|
||||||
|
|
||||||
if (searchPaginationResult.initialCount >= 10) {
|
if (searchPaginationResult.initialCount >= PAGE_SIZE) {
|
||||||
expect(searchPaginationResult.finalCount).toBeGreaterThanOrEqual(
|
expect(searchPaginationResult.finalCount).toBeGreaterThanOrEqual(
|
||||||
searchPaginationResult.initialCount,
|
searchPaginationResult.initialCount,
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -69,9 +69,12 @@ test.describe("Marketplace Creator Page – Basic Functionality", () => {
|
|||||||
await marketplacePage.getFirstCreatorProfile(page);
|
await marketplacePage.getFirstCreatorProfile(page);
|
||||||
await firstCreatorProfile.click();
|
await firstCreatorProfile.click();
|
||||||
await page.waitForURL("**/marketplace/creator/**");
|
await page.waitForURL("**/marketplace/creator/**");
|
||||||
|
await page.waitForLoadState("networkidle").catch(() => {});
|
||||||
|
|
||||||
const firstAgent = page
|
const firstAgent = page
|
||||||
.locator('[data-testid="store-card"]:visible')
|
.locator('[data-testid="store-card"]:visible')
|
||||||
.first();
|
.first();
|
||||||
|
await firstAgent.waitFor({ state: "visible", timeout: 30000 });
|
||||||
|
|
||||||
await firstAgent.click();
|
await firstAgent.click();
|
||||||
await page.waitForURL("**/marketplace/agent/**");
|
await page.waitForURL("**/marketplace/agent/**");
|
||||||
|
|||||||
@@ -77,7 +77,6 @@ test.describe("Marketplace – Basic Functionality", () => {
|
|||||||
|
|
||||||
const firstFeaturedAgent =
|
const firstFeaturedAgent =
|
||||||
await marketplacePage.getFirstFeaturedAgent(page);
|
await marketplacePage.getFirstFeaturedAgent(page);
|
||||||
await firstFeaturedAgent.waitFor({ state: "visible" });
|
|
||||||
await firstFeaturedAgent.click();
|
await firstFeaturedAgent.click();
|
||||||
await page.waitForURL("**/marketplace/agent/**");
|
await page.waitForURL("**/marketplace/agent/**");
|
||||||
await matchesUrl(page, /\/marketplace\/agent\/.+/);
|
await matchesUrl(page, /\/marketplace\/agent\/.+/);
|
||||||
@@ -116,7 +115,15 @@ test.describe("Marketplace – Basic Functionality", () => {
|
|||||||
const searchTerm = page.getByText("DummyInput").first();
|
const searchTerm = page.getByText("DummyInput").first();
|
||||||
await isVisible(searchTerm);
|
await isVisible(searchTerm);
|
||||||
|
|
||||||
await page.waitForTimeout(10000);
|
await page.waitForLoadState("networkidle").catch(() => {});
|
||||||
|
|
||||||
|
await page
|
||||||
|
.waitForFunction(
|
||||||
|
() =>
|
||||||
|
document.querySelectorAll('[data-testid="store-card"]').length > 0,
|
||||||
|
{ timeout: 15000 },
|
||||||
|
)
|
||||||
|
.catch(() => console.log("No search results appeared within timeout"));
|
||||||
|
|
||||||
const results = await marketplacePage.getSearchResultsCount(page);
|
const results = await marketplacePage.getSearchResultsCount(page);
|
||||||
expect(results).toBeGreaterThan(0);
|
expect(results).toBeGreaterThan(0);
|
||||||
|
|||||||
@@ -300,21 +300,27 @@ export class LibraryPage extends BasePage {
|
|||||||
async scrollToLoadMore(): Promise<void> {
|
async scrollToLoadMore(): Promise<void> {
|
||||||
console.log(`scrolling to load more agents`);
|
console.log(`scrolling to load more agents`);
|
||||||
|
|
||||||
// Get initial agent count
|
const initialCount = await this.getAgentCountByListLength();
|
||||||
const initialCount = await this.getAgentCount();
|
console.log(`Initial agent count (DOM cards): ${initialCount}`);
|
||||||
console.log(`Initial agent count: ${initialCount}`);
|
|
||||||
|
|
||||||
// Scroll down to trigger pagination
|
|
||||||
await this.scrollToBottom();
|
await this.scrollToBottom();
|
||||||
|
|
||||||
// Wait for potential new agents to load
|
await this.page
|
||||||
await this.page.waitForTimeout(2000);
|
.waitForLoadState("networkidle", { timeout: 10000 })
|
||||||
|
.catch(() => console.log("Network idle timeout, continuing..."));
|
||||||
|
|
||||||
// Check if more agents loaded
|
await this.page
|
||||||
const newCount = await this.getAgentCount();
|
.waitForFunction(
|
||||||
console.log(`New agent count after scroll: ${newCount}`);
|
(prevCount) =>
|
||||||
|
document.querySelectorAll('[data-testid="library-agent-card"]')
|
||||||
|
.length > prevCount,
|
||||||
|
initialCount,
|
||||||
|
{ timeout: 5000 },
|
||||||
|
)
|
||||||
|
.catch(() => {});
|
||||||
|
|
||||||
return;
|
const newCount = await this.getAgentCountByListLength();
|
||||||
|
console.log(`New agent count after scroll (DOM cards): ${newCount}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
async testPagination(): Promise<{
|
async testPagination(): Promise<{
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ export class MarketplacePage extends BasePage {
|
|||||||
|
|
||||||
async goto(page: Page) {
|
async goto(page: Page) {
|
||||||
await page.goto("/marketplace");
|
await page.goto("/marketplace");
|
||||||
|
await page.waitForLoadState("networkidle").catch(() => {});
|
||||||
}
|
}
|
||||||
|
|
||||||
async getMarketplaceTitle(page: Page) {
|
async getMarketplaceTitle(page: Page) {
|
||||||
@@ -109,16 +110,24 @@ export class MarketplacePage extends BasePage {
|
|||||||
|
|
||||||
async getFirstFeaturedAgent(page: Page) {
|
async getFirstFeaturedAgent(page: Page) {
|
||||||
const { getId } = getSelectors(page);
|
const { getId } = getSelectors(page);
|
||||||
return getId("featured-store-card").first();
|
const card = getId("featured-store-card").first();
|
||||||
|
await card.waitFor({ state: "visible", timeout: 30000 });
|
||||||
|
return card;
|
||||||
}
|
}
|
||||||
|
|
||||||
async getFirstTopAgent() {
|
async getFirstTopAgent() {
|
||||||
return this.page.locator('[data-testid="store-card"]:visible').first();
|
const card = this.page
|
||||||
|
.locator('[data-testid="store-card"]:visible')
|
||||||
|
.first();
|
||||||
|
await card.waitFor({ state: "visible", timeout: 30000 });
|
||||||
|
return card;
|
||||||
}
|
}
|
||||||
|
|
||||||
async getFirstCreatorProfile(page: Page) {
|
async getFirstCreatorProfile(page: Page) {
|
||||||
const { getId } = getSelectors(page);
|
const { getId } = getSelectors(page);
|
||||||
return getId("creator-card").first();
|
const card = getId("creator-card").first();
|
||||||
|
await card.waitFor({ state: "visible", timeout: 30000 });
|
||||||
|
return card;
|
||||||
}
|
}
|
||||||
|
|
||||||
async getSearchResultsCount(page: Page) {
|
async getSearchResultsCount(page: Page) {
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
|||||||
| [Block Installation](block-integrations/basic.md#block-installation) | Given a code string, this block allows the verification and installation of a block code into the system |
|
| [Block Installation](block-integrations/basic.md#block-installation) | Given a code string, this block allows the verification and installation of a block code into the system |
|
||||||
| [Concatenate Lists](block-integrations/basic.md#concatenate-lists) | Concatenates multiple lists into a single list |
|
| [Concatenate Lists](block-integrations/basic.md#concatenate-lists) | Concatenates multiple lists into a single list |
|
||||||
| [Dictionary Is Empty](block-integrations/basic.md#dictionary-is-empty) | Checks if a dictionary is empty |
|
| [Dictionary Is Empty](block-integrations/basic.md#dictionary-is-empty) | Checks if a dictionary is empty |
|
||||||
| [File Store](block-integrations/basic.md#file-store) | Stores the input file in the temporary directory |
|
| [File Store](block-integrations/basic.md#file-store) | Downloads and stores a file from a URL, data URI, or local path |
|
||||||
| [Find In Dictionary](block-integrations/basic.md#find-in-dictionary) | A block that looks up a value in a dictionary, list, or object by key or index and returns the corresponding value |
|
| [Find In Dictionary](block-integrations/basic.md#find-in-dictionary) | A block that looks up a value in a dictionary, list, or object by key or index and returns the corresponding value |
|
||||||
| [Find In List](block-integrations/basic.md#find-in-list) | Finds the index of the value in the list |
|
| [Find In List](block-integrations/basic.md#find-in-list) | Finds the index of the value in the list |
|
||||||
| [Get All Memories](block-integrations/basic.md#get-all-memories) | Retrieve all memories from Mem0 with optional conversation filtering |
|
| [Get All Memories](block-integrations/basic.md#get-all-memories) | Retrieve all memories from Mem0 with optional conversation filtering |
|
||||||
|
|||||||
@@ -709,7 +709,7 @@ This is useful for conditional logic where you need to verify if data was return
|
|||||||
## File Store
|
## File Store
|
||||||
|
|
||||||
### What it is
|
### What it is
|
||||||
Stores the input file in the temporary directory.
|
Downloads and stores a file from a URL, data URI, or local path. Use this to fetch images, documents, or other files for processing. In CoPilot: saves to workspace (use list_workspace_files to see it). In graphs: outputs a data URI to pass to other blocks.
|
||||||
|
|
||||||
### How it works
|
### How it works
|
||||||
<!-- MANUAL: how_it_works -->
|
<!-- MANUAL: how_it_works -->
|
||||||
@@ -722,15 +722,15 @@ The block outputs a file path that other blocks can use to access the stored fil
|
|||||||
|
|
||||||
| Input | Description | Type | Required |
|
| Input | Description | Type | Required |
|
||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| file_in | The file to store in the temporary directory, it can be a URL, data URI, or local path. | str (file) | Yes |
|
| file_in | The file to download and store. Can be a URL (https://...), data URI, or local path. | str (file) | Yes |
|
||||||
| base_64 | Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks). | bool | No |
|
| base_64 | Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks). | bool | No |
|
||||||
|
|
||||||
### Outputs
|
### Outputs
|
||||||
|
|
||||||
| Output | Description | Type |
|
| Output | Description | Type |
|
||||||
|--------|-------------|------|
|
|--------|-------------|------|
|
||||||
| error | Error message if the operation failed | str |
|
| error | Error message if the operation failed | str |
|
||||||
| file_out | The relative path to the stored file in the temporary directory. | str (file) |
|
| file_out | Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks. | str (file) |
|
||||||
|
|
||||||
### Possible use case
|
### Possible use case
|
||||||
<!-- MANUAL: use_case -->
|
<!-- MANUAL: use_case -->
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ The result routes data to yes_output or no_output, enabling intelligent branchin
|
|||||||
| condition | A plaintext English description of the condition to evaluate | str | Yes |
|
| condition | A plaintext English description of the condition to evaluate | str | Yes |
|
||||||
| yes_value | (Optional) Value to output if the condition is true. If not provided, input_value will be used. | Yes Value | No |
|
| yes_value | (Optional) Value to output if the condition is true. If not provided, input_value will be used. | Yes Value | No |
|
||||||
| no_value | (Optional) Value to output if the condition is false. If not provided, input_value will be used. | No Value | No |
|
| no_value | (Optional) Value to output if the condition is false. If not provided, input_value will be used. | No Value | No |
|
||||||
| model | The language model to use for evaluating the condition. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-7-sonnet-20250219" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
| model | The language model to use for evaluating the condition. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
||||||
|
|
||||||
### Outputs
|
### Outputs
|
||||||
|
|
||||||
@@ -103,7 +103,7 @@ The block sends the entire conversation history to the chosen LLM, including sys
|
|||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| prompt | The prompt to send to the language model. | str | No |
|
| prompt | The prompt to send to the language model. | str | No |
|
||||||
| messages | List of messages in the conversation. | List[Any] | Yes |
|
| messages | List of messages in the conversation. | List[Any] | Yes |
|
||||||
| model | The language model to use for the conversation. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-7-sonnet-20250219" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
| model | The language model to use for the conversation. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
||||||
| max_tokens | The maximum number of tokens to generate in the chat completion. | int | No |
|
| max_tokens | The maximum number of tokens to generate in the chat completion. | int | No |
|
||||||
| ollama_host | Ollama host for local models | str | No |
|
| ollama_host | Ollama host for local models | str | No |
|
||||||
|
|
||||||
@@ -257,7 +257,7 @@ The block formulates a prompt based on the given focus or source data, sends it
|
|||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| focus | The focus of the list to generate. | str | No |
|
| focus | The focus of the list to generate. | str | No |
|
||||||
| source_data | The data to generate the list from. | str | No |
|
| source_data | The data to generate the list from. | str | No |
|
||||||
| model | The language model to use for generating the list. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-7-sonnet-20250219" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
| model | The language model to use for generating the list. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
||||||
| max_retries | Maximum number of retries for generating a valid list. | int | No |
|
| max_retries | Maximum number of retries for generating a valid list. | int | No |
|
||||||
| force_json_output | Whether to force the LLM to produce a JSON-only response. This can increase the block's reliability, but may also reduce the quality of the response because it prohibits the LLM from reasoning before providing its JSON response. | bool | No |
|
| force_json_output | Whether to force the LLM to produce a JSON-only response. This can increase the block's reliability, but may also reduce the quality of the response because it prohibits the LLM from reasoning before providing its JSON response. | bool | No |
|
||||||
| max_tokens | The maximum number of tokens to generate in the chat completion. | int | No |
|
| max_tokens | The maximum number of tokens to generate in the chat completion. | int | No |
|
||||||
@@ -424,7 +424,7 @@ The block sends the input prompt to a chosen LLM, along with any system prompts
|
|||||||
| prompt | The prompt to send to the language model. | str | Yes |
|
| prompt | The prompt to send to the language model. | str | Yes |
|
||||||
| expected_format | Expected format of the response. If provided, the response will be validated against this format. The keys should be the expected fields in the response, and the values should be the description of the field. | Dict[str, str] | Yes |
|
| expected_format | Expected format of the response. If provided, the response will be validated against this format. The keys should be the expected fields in the response, and the values should be the description of the field. | Dict[str, str] | Yes |
|
||||||
| list_result | Whether the response should be a list of objects in the expected format. | bool | No |
|
| list_result | Whether the response should be a list of objects in the expected format. | bool | No |
|
||||||
| model | The language model to use for answering the prompt. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-7-sonnet-20250219" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
| model | The language model to use for answering the prompt. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
||||||
| force_json_output | Whether to force the LLM to produce a JSON-only response. This can increase the block's reliability, but may also reduce the quality of the response because it prohibits the LLM from reasoning before providing its JSON response. | bool | No |
|
| force_json_output | Whether to force the LLM to produce a JSON-only response. This can increase the block's reliability, but may also reduce the quality of the response because it prohibits the LLM from reasoning before providing its JSON response. | bool | No |
|
||||||
| sys_prompt | The system prompt to provide additional context to the model. | str | No |
|
| sys_prompt | The system prompt to provide additional context to the model. | str | No |
|
||||||
| conversation_history | The conversation history to provide context for the prompt. | List[Dict[str, Any]] | No |
|
| conversation_history | The conversation history to provide context for the prompt. | List[Dict[str, Any]] | No |
|
||||||
@@ -464,7 +464,7 @@ The block sends the input prompt to a chosen LLM, processes the response, and re
|
|||||||
| Input | Description | Type | Required |
|
| Input | Description | Type | Required |
|
||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| prompt | The prompt to send to the language model. You can use any of the {keys} from Prompt Values to fill in the prompt with values from the prompt values dictionary by putting them in curly braces. | str | Yes |
|
| prompt | The prompt to send to the language model. You can use any of the {keys} from Prompt Values to fill in the prompt with values from the prompt values dictionary by putting them in curly braces. | str | Yes |
|
||||||
| model | The language model to use for answering the prompt. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-7-sonnet-20250219" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
| model | The language model to use for answering the prompt. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
||||||
| sys_prompt | The system prompt to provide additional context to the model. | str | No |
|
| sys_prompt | The system prompt to provide additional context to the model. | str | No |
|
||||||
| retry | Number of times to retry the LLM call if the response does not match the expected format. | int | No |
|
| retry | Number of times to retry the LLM call if the response does not match the expected format. | int | No |
|
||||||
| prompt_values | Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}. | Dict[str, str] | No |
|
| prompt_values | Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}. | Dict[str, str] | No |
|
||||||
@@ -501,7 +501,7 @@ The block splits the input text into smaller chunks, sends each chunk to an LLM
|
|||||||
| Input | Description | Type | Required |
|
| Input | Description | Type | Required |
|
||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| text | The text to summarize. | str | Yes |
|
| text | The text to summarize. | str | Yes |
|
||||||
| model | The language model to use for summarizing the text. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-7-sonnet-20250219" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
| model | The language model to use for summarizing the text. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
||||||
| focus | The topic to focus on in the summary | str | No |
|
| focus | The topic to focus on in the summary | str | No |
|
||||||
| style | The style of the summary to generate. | "concise" \| "detailed" \| "bullet points" \| "numbered list" | No |
|
| style | The style of the summary to generate. | "concise" \| "detailed" \| "bullet points" \| "numbered list" | No |
|
||||||
| max_tokens | The maximum number of tokens to generate in the chat completion. | int | No |
|
| max_tokens | The maximum number of tokens to generate in the chat completion. | int | No |
|
||||||
@@ -763,7 +763,7 @@ Configure agent_mode_max_iterations to control loop behavior: 0 for single decis
|
|||||||
| Input | Description | Type | Required |
|
| Input | Description | Type | Required |
|
||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| prompt | The prompt to send to the language model. | str | Yes |
|
| prompt | The prompt to send to the language model. | str | Yes |
|
||||||
| model | The language model to use for answering the prompt. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-7-sonnet-20250219" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
| model | The language model to use for answering the prompt. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "gpt-3.5-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-3-pro-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No |
|
||||||
| multiple_tool_calls | Whether to allow multiple tool calls in a single response. | bool | No |
|
| multiple_tool_calls | Whether to allow multiple tool calls in a single response. | bool | No |
|
||||||
| sys_prompt | The system prompt to provide additional context to the model. | str | No |
|
| sys_prompt | The system prompt to provide additional context to the model. | str | No |
|
||||||
| conversation_history | The conversation history to provide context for the prompt. | List[Dict[str, Any]] | No |
|
| conversation_history | The conversation history to provide context for the prompt. | List[Dict[str, Any]] | No |
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ Block to attach an audio file to a video file using moviepy.
|
|||||||
<!-- MANUAL: how_it_works -->
|
<!-- MANUAL: how_it_works -->
|
||||||
This block combines a video file with an audio file using the moviepy library. The audio track is attached to the video, optionally with volume adjustment via the volume parameter (1.0 = original volume).
|
This block combines a video file with an audio file using the moviepy library. The audio track is attached to the video, optionally with volume adjustment via the volume parameter (1.0 = original volume).
|
||||||
|
|
||||||
Input files can be URLs, data URIs, or local paths. The output can be returned as either a file path or base64 data URI.
|
Input files can be URLs, data URIs, or local paths. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions.
|
||||||
<!-- END MANUAL -->
|
<!-- END MANUAL -->
|
||||||
|
|
||||||
### Inputs
|
### Inputs
|
||||||
@@ -22,7 +22,6 @@ Input files can be URLs, data URIs, or local paths. The output can be returned a
|
|||||||
| video_in | Video input (URL, data URI, or local path). | str (file) | Yes |
|
| video_in | Video input (URL, data URI, or local path). | str (file) | Yes |
|
||||||
| audio_in | Audio input (URL, data URI, or local path). | str (file) | Yes |
|
| audio_in | Audio input (URL, data URI, or local path). | str (file) | Yes |
|
||||||
| volume | Volume scale for the newly attached audio track (1.0 = original). | float | No |
|
| volume | Volume scale for the newly attached audio track (1.0 = original). | float | No |
|
||||||
| output_return_type | Return the final output as a relative path or base64 data URI. | "file_path" \| "data_uri" | No |
|
|
||||||
|
|
||||||
### Outputs
|
### Outputs
|
||||||
|
|
||||||
@@ -51,7 +50,7 @@ Block to loop a video to a given duration or number of repeats.
|
|||||||
<!-- MANUAL: how_it_works -->
|
<!-- MANUAL: how_it_works -->
|
||||||
This block extends a video by repeating it to reach a target duration or number of loops. Set duration to specify the total length in seconds, or use n_loops to repeat the video a specific number of times.
|
This block extends a video by repeating it to reach a target duration or number of loops. Set duration to specify the total length in seconds, or use n_loops to repeat the video a specific number of times.
|
||||||
|
|
||||||
The looped video is seamlessly concatenated and can be output as a file path or base64 data URI.
|
The looped video is seamlessly concatenated. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions.
|
||||||
<!-- END MANUAL -->
|
<!-- END MANUAL -->
|
||||||
|
|
||||||
### Inputs
|
### Inputs
|
||||||
@@ -61,7 +60,6 @@ The looped video is seamlessly concatenated and can be output as a file path or
|
|||||||
| video_in | The input video (can be a URL, data URI, or local path). | str (file) | Yes |
|
| video_in | The input video (can be a URL, data URI, or local path). | str (file) | Yes |
|
||||||
| duration | Target duration (in seconds) to loop the video to. If omitted, defaults to no looping. | float | No |
|
| duration | Target duration (in seconds) to loop the video to. If omitted, defaults to no looping. | float | No |
|
||||||
| n_loops | Number of times to repeat the video. If omitted, defaults to 1 (no repeat). | int | No |
|
| n_loops | Number of times to repeat the video. If omitted, defaults to 1 (no repeat). | int | No |
|
||||||
| output_return_type | How to return the output video. Either a relative path or base64 data URI. | "file_path" \| "data_uri" | No |
|
|
||||||
|
|
||||||
### Outputs
|
### Outputs
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ Configure timeouts for DOM settlement and page loading. Variables can be passed
|
|||||||
| Input | Description | Type | Required |
|
| Input | Description | Type | Required |
|
||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| browserbase_project_id | Browserbase project ID (required if using Browserbase) | str | Yes |
|
| browserbase_project_id | Browserbase project ID (required if using Browserbase) | str | Yes |
|
||||||
| model | LLM to use for Stagehand (provider is inferred) | "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "claude-3-7-sonnet-20250219" | No |
|
| model | LLM to use for Stagehand (provider is inferred) | "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "claude-sonnet-4-5-20250929" | No |
|
||||||
| url | URL to navigate to. | str | Yes |
|
| url | URL to navigate to. | str | Yes |
|
||||||
| action | Action to perform. Suggested actions are: click, fill, type, press, scroll, select from dropdown. For multi-step actions, add an entry for each step. | List[str] | Yes |
|
| action | Action to perform. Suggested actions are: click, fill, type, press, scroll, select from dropdown. For multi-step actions, add an entry for each step. | List[str] | Yes |
|
||||||
| variables | Variables to use in the action. Variables contains data you want the action to use. | Dict[str, str] | No |
|
| variables | Variables to use in the action. Variables contains data you want the action to use. | Dict[str, str] | No |
|
||||||
@@ -65,7 +65,7 @@ Supports searching within iframes and configurable timeouts for dynamic content
|
|||||||
| Input | Description | Type | Required |
|
| Input | Description | Type | Required |
|
||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| browserbase_project_id | Browserbase project ID (required if using Browserbase) | str | Yes |
|
| browserbase_project_id | Browserbase project ID (required if using Browserbase) | str | Yes |
|
||||||
| model | LLM to use for Stagehand (provider is inferred) | "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "claude-3-7-sonnet-20250219" | No |
|
| model | LLM to use for Stagehand (provider is inferred) | "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "claude-sonnet-4-5-20250929" | No |
|
||||||
| url | URL to navigate to. | str | Yes |
|
| url | URL to navigate to. | str | Yes |
|
||||||
| instruction | Natural language description of elements or actions to discover. | str | Yes |
|
| instruction | Natural language description of elements or actions to discover. | str | Yes |
|
||||||
| iframes | Whether to search within iframes. If True, Stagehand will search for actions within iframes. | bool | No |
|
| iframes | Whether to search within iframes. If True, Stagehand will search for actions within iframes. | bool | No |
|
||||||
@@ -106,7 +106,7 @@ Use this to explore a page's interactive elements before building automated work
|
|||||||
| Input | Description | Type | Required |
|
| Input | Description | Type | Required |
|
||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| browserbase_project_id | Browserbase project ID (required if using Browserbase) | str | Yes |
|
| browserbase_project_id | Browserbase project ID (required if using Browserbase) | str | Yes |
|
||||||
| model | LLM to use for Stagehand (provider is inferred) | "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "claude-3-7-sonnet-20250219" | No |
|
| model | LLM to use for Stagehand (provider is inferred) | "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "claude-sonnet-4-5-20250929" | No |
|
||||||
| url | URL to navigate to. | str | Yes |
|
| url | URL to navigate to. | str | Yes |
|
||||||
| instruction | Natural language description of elements or actions to discover. | str | Yes |
|
| instruction | Natural language description of elements or actions to discover. | str | Yes |
|
||||||
| iframes | Whether to search within iframes. If True, Stagehand will search for actions within iframes. | bool | No |
|
| iframes | Whether to search within iframes. If True, Stagehand will search for actions within iframes. | bool | No |
|
||||||
|
|||||||
@@ -277,6 +277,50 @@ async def run(
|
|||||||
token = credentials.api_key.get_secret_value()
|
token = credentials.api_key.get_secret_value()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Handling Files
|
||||||
|
|
||||||
|
When your block works with files (images, videos, documents), use `store_media_file()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.util.file import store_media_file
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# PROCESSING: Need local file path for tools like ffmpeg, MoviePy, PIL
|
||||||
|
local_path = await store_media_file(
|
||||||
|
file=input_data.video,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
# EXTERNAL API: Need base64 content for APIs like Replicate, OpenAI
|
||||||
|
image_b64 = await store_media_file(
|
||||||
|
file=input_data.image,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_external_api",
|
||||||
|
)
|
||||||
|
|
||||||
|
# OUTPUT: Return to user/next block (auto-adapts to context)
|
||||||
|
result = await store_media_file(
|
||||||
|
file=generated_url,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output", # workspace:// in CoPilot, data URI in graphs
|
||||||
|
)
|
||||||
|
yield "image_url", result
|
||||||
|
```
|
||||||
|
|
||||||
|
**Return format options:**
|
||||||
|
- `"for_local_processing"` - Local file path for processing tools
|
||||||
|
- `"for_external_api"` - Data URI for external APIs needing base64
|
||||||
|
- `"for_block_output"` - **Always use for outputs** - automatically picks best format
|
||||||
|
|
||||||
## Testing Your Block
|
## Testing Your Block
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ This document focuses on the **API Integration OAuth flow** used for connecting
|
|||||||
### 2. Backend API Trust Boundary
|
### 2. Backend API Trust Boundary
|
||||||
- **Location**: Server-side FastAPI application
|
- **Location**: Server-side FastAPI application
|
||||||
- **Components**:
|
- **Components**:
|
||||||
- Integration router (`/backend/backend/server/integrations/router.py`)
|
- Integration router (`/backend/backend/api/features/integrations/router.py`)
|
||||||
- OAuth handlers (`/backend/backend/integrations/oauth/`)
|
- OAuth handlers (`/backend/backend/integrations/oauth/`)
|
||||||
- Credentials store (`/backend/backend/integrations/credentials_store.py`)
|
- Credentials store (`/backend/backend/integrations/credentials_store.py`)
|
||||||
- **Trust Level**: Trusted - server-controlled environment
|
- **Trust Level**: Trusted - server-controlled environment
|
||||||
|
|||||||
@@ -111,6 +111,71 @@ Follow these steps to create and test a new block:
|
|||||||
- `graph_exec_id`: The ID of the execution of the agent. This changes every time the agent has a new "run"
|
- `graph_exec_id`: The ID of the execution of the agent. This changes every time the agent has a new "run"
|
||||||
- `node_exec_id`: The ID of the execution of the node. This changes every time the node is executed
|
- `node_exec_id`: The ID of the execution of the node. This changes every time the node is executed
|
||||||
- `node_id`: The ID of the node that is being executed. It changes every version of the graph, but not every time the node is executed.
|
- `node_id`: The ID of the node that is being executed. It changes every version of the graph, but not every time the node is executed.
|
||||||
|
- `execution_context`: An `ExecutionContext` object containing user_id, graph_exec_id, workspace_id, and session_id. Required for file handling.
|
||||||
|
|
||||||
|
### Handling Files in Blocks
|
||||||
|
|
||||||
|
When your block needs to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. This function handles downloading, validation, virus scanning, and storage.
|
||||||
|
|
||||||
|
**Import:**
|
||||||
|
```python
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.util.file import store_media_file
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
```
|
||||||
|
|
||||||
|
**The `return_format` parameter determines what you get back:**
|
||||||
|
|
||||||
|
| Format | Use When | Returns |
|
||||||
|
|--------|----------|---------|
|
||||||
|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||||
|
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||||
|
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# PROCESSING: Need to work with file locally (ffmpeg, MoviePy, PIL)
|
||||||
|
local_path = await store_media_file(
|
||||||
|
file=input_data.video,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
# local_path = "video.mp4" - use with Path, ffmpeg, subprocess, etc.
|
||||||
|
full_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
|
|
||||||
|
# EXTERNAL API: Need to send content to an API like Replicate
|
||||||
|
image_b64 = await store_media_file(
|
||||||
|
file=input_data.image,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_external_api",
|
||||||
|
)
|
||||||
|
# image_b64 = "data:image/png;base64,iVBORw0..." - send to external API
|
||||||
|
|
||||||
|
# OUTPUT: Returning result from block to user/next block
|
||||||
|
result_url = await store_media_file(
|
||||||
|
file=generated_image_url,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", result_url
|
||||||
|
# In CoPilot: result_url = "workspace://abc123" (persistent, context-efficient)
|
||||||
|
# In graphs: result_url = "data:image/png;base64,..." (for next block/display)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key points:**
|
||||||
|
|
||||||
|
- `for_block_output` is the **only** format that auto-adapts to execution context
|
||||||
|
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||||
|
- Never manually check for `workspace_id` - let `for_block_output` handle the logic
|
||||||
|
- The function handles URLs, data URIs, `workspace://` references, and local paths as input
|
||||||
|
|
||||||
### Field Types
|
### Field Types
|
||||||
|
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ If you encounter any issues, verify that:
|
|||||||
```bash
|
```bash
|
||||||
ollama pull llama3.2
|
ollama pull llama3.2
|
||||||
```
|
```
|
||||||
- If using a custom model, ensure it's added to the model list in `backend/server/model.py`
|
- If using a custom model, ensure it's added to the model list in `backend/api/model.py`
|
||||||
|
|
||||||
#### Docker Issues
|
#### Docker Issues
|
||||||
- Ensure Docker daemon is running:
|
- Ensure Docker daemon is running:
|
||||||
|
|||||||
Reference in New Issue
Block a user