mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-05 20:35:10 -05:00
Compare commits
56 Commits
release/v0
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85b6520710 | ||
|
|
bfa942e032 | ||
|
|
11256076d8 | ||
|
|
3ca2387631 | ||
|
|
ed07f02738 | ||
|
|
b121030c94 | ||
|
|
c22c18374d | ||
|
|
e40233a3ac | ||
|
|
3ae5eabf9d | ||
|
|
a077ba9f03 | ||
|
|
5401d54eaa | ||
|
|
5ac89d7c0b | ||
|
|
4f908d5cb3 | ||
|
|
c1aa684743 | ||
|
|
7e5b84cc5c | ||
|
|
09cb313211 | ||
|
|
c026485023 | ||
|
|
1eabc60484 | ||
|
|
f4bf492f24 | ||
|
|
81e48c00a4 | ||
|
|
7dc53071e8 | ||
|
|
4878665c66 | ||
|
|
678ddde751 | ||
|
|
aef6f57cfd | ||
|
|
14cee1670a | ||
|
|
d81d1ce024 | ||
|
|
2dd341c369 | ||
|
|
f7350c797a | ||
|
|
1081590384 | ||
|
|
7e37de8e30 | ||
|
|
2abbb7fbc8 | ||
|
|
7ee94d986c | ||
|
|
05b60db554 | ||
|
|
18a1661fa3 | ||
|
|
b72521daa9 | ||
|
|
cc4839bedb | ||
|
|
dbbff04616 | ||
|
|
350ad3591b | ||
|
|
e6438b9a76 | ||
|
|
de0ec3d388 | ||
|
|
e10ff8d37f | ||
|
|
7cb1e588b0 | ||
|
|
582c6cad36 | ||
|
|
3b822cdaf7 | ||
|
|
b2eb4831bd | ||
|
|
4cd5da678d | ||
|
|
9538992eaf | ||
|
|
b94c83aacc | ||
|
|
7668c17d9c | ||
|
|
27b72062f2 | ||
|
|
e0dfae5732 | ||
|
|
9a79a8d257 | ||
|
|
7df867d645 | ||
|
|
a9bf08748b | ||
|
|
d855f79874 | ||
|
|
dac99694fe |
@@ -29,8 +29,7 @@
|
|||||||
"postCreateCmd": [
|
"postCreateCmd": [
|
||||||
"cd autogpt_platform/autogpt_libs && poetry install",
|
"cd autogpt_platform/autogpt_libs && poetry install",
|
||||||
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
||||||
"cd autogpt_platform/frontend && pnpm install",
|
"cd autogpt_platform/frontend && pnpm install"
|
||||||
"cd docs && pip install -r requirements.txt"
|
|
||||||
],
|
],
|
||||||
"terminalCommand": "code .",
|
"terminalCommand": "code .",
|
||||||
"deleteBranchWithWorktree": false
|
"deleteBranchWithWorktree": false
|
||||||
|
|||||||
6
.github/copilot-instructions.md
vendored
6
.github/copilot-instructions.md
vendored
@@ -160,7 +160,7 @@ pnpm storybook # Start component development server
|
|||||||
|
|
||||||
**Backend Entry Points:**
|
**Backend Entry Points:**
|
||||||
|
|
||||||
- `backend/backend/server/server.py` - FastAPI application setup
|
- `backend/backend/api/rest_api.py` - FastAPI application setup
|
||||||
- `backend/backend/data/` - Database models and user management
|
- `backend/backend/data/` - Database models and user management
|
||||||
- `backend/blocks/` - Agent execution blocks and logic
|
- `backend/blocks/` - Agent execution blocks and logic
|
||||||
|
|
||||||
@@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s
|
|||||||
|
|
||||||
### API Development
|
### API Development
|
||||||
|
|
||||||
1. Update routes in `/backend/backend/server/routers/`
|
1. Update routes in `/backend/backend/api/features/`
|
||||||
2. Add/update Pydantic models in same directory
|
2. Add/update Pydantic models in same directory
|
||||||
3. Write tests alongside route files
|
3. Write tests alongside route files
|
||||||
4. For `data/*.py` changes, validate user ID checks
|
4. For `data/*.py` changes, validate user ID checks
|
||||||
@@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s
|
|||||||
|
|
||||||
### Security Guidelines
|
### Security Guidelines
|
||||||
|
|
||||||
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
|
**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`):
|
||||||
|
|
||||||
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||||
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
|
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -178,4 +178,6 @@ autogpt_platform/backend/settings.py
|
|||||||
*.ign.*
|
*.ign.*
|
||||||
.test-contents
|
.test-contents
|
||||||
.claude/settings.local.json
|
.claude/settings.local.json
|
||||||
|
CLAUDE.local.md
|
||||||
/autogpt_platform/backend/logs
|
/autogpt_platform/backend/logs
|
||||||
|
.next
|
||||||
24
AGENTS.md
24
AGENTS.md
@@ -16,7 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
|
|||||||
- Format Python code with `poetry run format`.
|
- Format Python code with `poetry run format`.
|
||||||
- Format frontend code using `pnpm format`.
|
- Format frontend code using `pnpm format`.
|
||||||
|
|
||||||
|
|
||||||
## Frontend guidelines:
|
## Frontend guidelines:
|
||||||
|
|
||||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||||
@@ -33,14 +32,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
|||||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||||
|
|
||||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||||
- Use function declarations for components, arrow functions only for callbacks
|
- Use function declarations for components, arrow functions only for callbacks
|
||||||
- No barrel files or `index.ts` re-exports
|
- No barrel files or `index.ts` re-exports
|
||||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
|
||||||
- Avoid comments at all times unless the code is very complex
|
- Avoid comments at all times unless the code is very complex
|
||||||
|
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
|
||||||
|
- Do not type hook returns, let Typescript infer as much as possible
|
||||||
|
- Never type with `any`, if not types available use `unknown`
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
@@ -49,22 +51,8 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
|||||||
|
|
||||||
Always run the relevant linters and tests before committing.
|
Always run the relevant linters and tests before committing.
|
||||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||||
Types:
|
Types: - feat - fix - refactor - ci - dx (developer experience)
|
||||||
- feat
|
Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
|
||||||
- fix
|
|
||||||
- refactor
|
|
||||||
- ci
|
|
||||||
- dx (developer experience)
|
|
||||||
Scopes:
|
|
||||||
- platform
|
|
||||||
- platform/library
|
|
||||||
- platform/marketplace
|
|
||||||
- backend
|
|
||||||
- backend/executor
|
|
||||||
- frontend
|
|
||||||
- frontend/library
|
|
||||||
- frontend/marketplace
|
|
||||||
- blocks
|
|
||||||
|
|
||||||
## Pull requests
|
## Pull requests
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
|
|||||||
### Updated Setup Instructions:
|
### Updated Setup Instructions:
|
||||||
We've moved to a fully maintained and regularly updated documentation site.
|
We've moved to a fully maintained and regularly updated documentation site.
|
||||||
|
|
||||||
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
|
👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
|
||||||
|
|
||||||
|
|
||||||
This tutorial assumes you have Docker, VSCode, git and npm installed.
|
This tutorial assumes you have Docker, VSCode, git and npm installed.
|
||||||
|
|||||||
@@ -6,152 +6,30 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|||||||
|
|
||||||
AutoGPT Platform is a monorepo containing:
|
AutoGPT Platform is a monorepo containing:
|
||||||
|
|
||||||
- **Backend** (`/backend`): Python FastAPI server with async support
|
- **Backend** (`backend`): Python FastAPI server with async support
|
||||||
- **Frontend** (`/frontend`): Next.js React application
|
- **Frontend** (`frontend`): Next.js React application
|
||||||
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||||
|
|
||||||
## Essential Commands
|
## Component Documentation
|
||||||
|
|
||||||
### Backend Development
|
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
||||||
|
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
||||||
|
|
||||||
```bash
|
## Key Concepts
|
||||||
# Install dependencies
|
|
||||||
cd backend && poetry install
|
|
||||||
|
|
||||||
# Run database migrations
|
|
||||||
poetry run prisma migrate dev
|
|
||||||
|
|
||||||
# Start all services (database, redis, rabbitmq, clamav)
|
|
||||||
docker compose up -d
|
|
||||||
|
|
||||||
# Run the backend server
|
|
||||||
poetry run serve
|
|
||||||
|
|
||||||
# Run tests
|
|
||||||
poetry run test
|
|
||||||
|
|
||||||
# Run specific test
|
|
||||||
poetry run pytest path/to/test_file.py::test_function_name
|
|
||||||
|
|
||||||
# Run block tests (tests that validate all blocks work correctly)
|
|
||||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
|
||||||
|
|
||||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
|
||||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
|
||||||
|
|
||||||
# Lint and format
|
|
||||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
|
||||||
poetry run format # Black + isort
|
|
||||||
poetry run lint # ruff
|
|
||||||
```
|
|
||||||
|
|
||||||
More details can be found in TESTING.md
|
|
||||||
|
|
||||||
#### Creating/Updating Snapshots
|
|
||||||
|
|
||||||
When you first write a test or when the expected output changes:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
poetry run pytest path/to/test.py --snapshot-update
|
|
||||||
```
|
|
||||||
|
|
||||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
|
||||||
|
|
||||||
### Frontend Development
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Install dependencies
|
|
||||||
cd frontend && pnpm i
|
|
||||||
|
|
||||||
# Generate API client from OpenAPI spec
|
|
||||||
pnpm generate:api
|
|
||||||
|
|
||||||
# Start development server
|
|
||||||
pnpm dev
|
|
||||||
|
|
||||||
# Run E2E tests
|
|
||||||
pnpm test
|
|
||||||
|
|
||||||
# Run Storybook for component development
|
|
||||||
pnpm storybook
|
|
||||||
|
|
||||||
# Build production
|
|
||||||
pnpm build
|
|
||||||
|
|
||||||
# Format and lint
|
|
||||||
pnpm format
|
|
||||||
|
|
||||||
# Type checking
|
|
||||||
pnpm types
|
|
||||||
```
|
|
||||||
|
|
||||||
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
|
|
||||||
|
|
||||||
**Key Frontend Conventions:**
|
|
||||||
|
|
||||||
- Separate render logic from data/behavior in components
|
|
||||||
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
|
||||||
- Use function declarations (not arrow functions) for components/handlers
|
|
||||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
|
||||||
- Only use Phosphor Icons
|
|
||||||
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
|
|
||||||
|
|
||||||
## Architecture Overview
|
|
||||||
|
|
||||||
### Backend Architecture
|
|
||||||
|
|
||||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
|
||||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
|
||||||
- **Queue System**: RabbitMQ for async task processing
|
|
||||||
- **Execution Engine**: Separate executor service processes agent workflows
|
|
||||||
- **Authentication**: JWT-based with Supabase integration
|
|
||||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
|
||||||
|
|
||||||
### Frontend Architecture
|
|
||||||
|
|
||||||
- **Framework**: Next.js 15 App Router (client-first approach)
|
|
||||||
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
|
||||||
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
|
||||||
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
|
||||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
|
||||||
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
|
||||||
- **Icons**: Phosphor Icons only
|
|
||||||
- **Feature Flags**: LaunchDarkly integration
|
|
||||||
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
|
||||||
- **Testing**: Playwright for E2E, Storybook for component development
|
|
||||||
|
|
||||||
### Key Concepts
|
|
||||||
|
|
||||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||||
3. **Integrations**: OAuth and API connections stored per user
|
3. **Integrations**: OAuth and API connections stored per user
|
||||||
4. **Store**: Marketplace for sharing agent templates
|
4. **Store**: Marketplace for sharing agent templates
|
||||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||||
|
|
||||||
### Testing Approach
|
|
||||||
|
|
||||||
- Backend uses pytest with snapshot testing for API responses
|
|
||||||
- Test files are colocated with source files (`*_test.py`)
|
|
||||||
- Frontend uses Playwright for E2E tests
|
|
||||||
- Component testing via Storybook
|
|
||||||
|
|
||||||
### Database Schema
|
|
||||||
|
|
||||||
Key models (defined in `/backend/schema.prisma`):
|
|
||||||
|
|
||||||
- `User`: Authentication and profile data
|
|
||||||
- `AgentGraph`: Workflow definitions with version control
|
|
||||||
- `AgentGraphExecution`: Execution history and results
|
|
||||||
- `AgentNode`: Individual nodes in a workflow
|
|
||||||
- `StoreListing`: Marketplace listings for sharing agents
|
|
||||||
|
|
||||||
### Environment Configuration
|
### Environment Configuration
|
||||||
|
|
||||||
#### Configuration Files
|
#### Configuration Files
|
||||||
|
|
||||||
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||||
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||||
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||||
|
|
||||||
#### Docker Environment Loading Order
|
#### Docker Environment Loading Order
|
||||||
|
|
||||||
@@ -167,83 +45,12 @@ Key models (defined in `/backend/schema.prisma`):
|
|||||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||||
|
|
||||||
### Common Development Tasks
|
|
||||||
|
|
||||||
**Adding a new block:**
|
|
||||||
|
|
||||||
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
|
|
||||||
|
|
||||||
- Provider configuration with `ProviderBuilder`
|
|
||||||
- Block schema definition
|
|
||||||
- Authentication (API keys, OAuth, webhooks)
|
|
||||||
- Testing and validation
|
|
||||||
- File organization
|
|
||||||
|
|
||||||
Quick steps:
|
|
||||||
|
|
||||||
1. Create new file in `/backend/backend/blocks/`
|
|
||||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
|
||||||
3. Inherit from `Block` base class
|
|
||||||
4. Define input/output schemas using `BlockSchema`
|
|
||||||
5. Implement async `run` method
|
|
||||||
6. Generate unique block ID using `uuid.uuid4()`
|
|
||||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
|
||||||
|
|
||||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
|
||||||
ex: do the inputs and outputs tie well together?
|
|
||||||
|
|
||||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
|
||||||
|
|
||||||
**Modifying the API:**
|
|
||||||
|
|
||||||
1. Update route in `/backend/backend/server/routers/`
|
|
||||||
2. Add/update Pydantic models in same directory
|
|
||||||
3. Write tests alongside the route file
|
|
||||||
4. Run `poetry run test` to verify
|
|
||||||
|
|
||||||
### Frontend guidelines:
|
|
||||||
|
|
||||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
|
||||||
|
|
||||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
|
||||||
- Add `usePageName.ts` hook for logic
|
|
||||||
- Put sub-components in local `components/` folder
|
|
||||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
|
||||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
|
||||||
- Never use `src/components/__legacy__/*`
|
|
||||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
|
||||||
- Regenerate with `pnpm generate:api`
|
|
||||||
- Pattern: `use{Method}{Version}{OperationName}`
|
|
||||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
|
||||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
|
||||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
|
||||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
|
||||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
|
||||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
|
||||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
|
||||||
- Use function declarations for components, arrow functions only for callbacks
|
|
||||||
- No barrel files or `index.ts` re-exports
|
|
||||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
|
||||||
- Avoid comments at all times unless the code is very complex
|
|
||||||
|
|
||||||
### Security Implementation
|
|
||||||
|
|
||||||
**Cache Protection Middleware:**
|
|
||||||
|
|
||||||
- Located in `/backend/backend/server/middleware/security.py`
|
|
||||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
|
||||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
|
||||||
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
|
||||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
|
||||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
|
||||||
- Applied to both main API server and external API applications
|
|
||||||
|
|
||||||
### Creating Pull Requests
|
### Creating Pull Requests
|
||||||
|
|
||||||
- Create the PR aginst the `dev` branch of the repository.
|
- Create the PR against the `dev` branch of the repository.
|
||||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
|
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||||
- Use conventional commit messages (see below)/
|
- Use conventional commit messages (see below)
|
||||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
|
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||||
- Run the github pre-commit hooks to ensure code quality.
|
- Run the github pre-commit hooks to ensure code quality.
|
||||||
|
|
||||||
### Reviewing/Revising Pull Requests
|
### Reviewing/Revising Pull Requests
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ REPLICATE_API_KEY=
|
|||||||
REVID_API_KEY=
|
REVID_API_KEY=
|
||||||
SCREENSHOTONE_API_KEY=
|
SCREENSHOTONE_API_KEY=
|
||||||
UNREAL_SPEECH_API_KEY=
|
UNREAL_SPEECH_API_KEY=
|
||||||
|
ELEVENLABS_API_KEY=
|
||||||
|
|
||||||
# Data & Search Services
|
# Data & Search Services
|
||||||
E2B_API_KEY=
|
E2B_API_KEY=
|
||||||
|
|||||||
3
autogpt_platform/backend/.gitignore
vendored
3
autogpt_platform/backend/.gitignore
vendored
@@ -19,3 +19,6 @@ load-tests/*.json
|
|||||||
load-tests/*.log
|
load-tests/*.log
|
||||||
load-tests/node_modules/*
|
load-tests/node_modules/*
|
||||||
migrations/*/rollback*.sql
|
migrations/*/rollback*.sql
|
||||||
|
|
||||||
|
# Workspace files
|
||||||
|
workspaces/
|
||||||
|
|||||||
170
autogpt_platform/backend/CLAUDE.md
Normal file
170
autogpt_platform/backend/CLAUDE.md
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
# CLAUDE.md - Backend
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code when working with the backend.
|
||||||
|
|
||||||
|
## Essential Commands
|
||||||
|
|
||||||
|
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies
|
||||||
|
poetry install
|
||||||
|
|
||||||
|
# Run database migrations
|
||||||
|
poetry run prisma migrate dev
|
||||||
|
|
||||||
|
# Start all services (database, redis, rabbitmq, clamav)
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Run the backend as a whole
|
||||||
|
poetry run app
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
poetry run test
|
||||||
|
|
||||||
|
# Run specific test
|
||||||
|
poetry run pytest path/to/test_file.py::test_function_name
|
||||||
|
|
||||||
|
# Run block tests (tests that validate all blocks work correctly)
|
||||||
|
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||||
|
|
||||||
|
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||||
|
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||||
|
|
||||||
|
# Lint and format
|
||||||
|
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||||
|
poetry run format # Black + isort
|
||||||
|
poetry run lint # ruff
|
||||||
|
```
|
||||||
|
|
||||||
|
More details can be found in @TESTING.md
|
||||||
|
|
||||||
|
### Creating/Updating Snapshots
|
||||||
|
|
||||||
|
When you first write a test or when the expected output changes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry run pytest path/to/test.py --snapshot-update
|
||||||
|
```
|
||||||
|
|
||||||
|
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||||
|
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||||
|
- **Queue System**: RabbitMQ for async task processing
|
||||||
|
- **Execution Engine**: Separate executor service processes agent workflows
|
||||||
|
- **Authentication**: JWT-based with Supabase integration
|
||||||
|
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||||
|
|
||||||
|
## Testing Approach
|
||||||
|
|
||||||
|
- Uses pytest with snapshot testing for API responses
|
||||||
|
- Test files are colocated with source files (`*_test.py`)
|
||||||
|
|
||||||
|
## Database Schema
|
||||||
|
|
||||||
|
Key models (defined in `schema.prisma`):
|
||||||
|
|
||||||
|
- `User`: Authentication and profile data
|
||||||
|
- `AgentGraph`: Workflow definitions with version control
|
||||||
|
- `AgentGraphExecution`: Execution history and results
|
||||||
|
- `AgentNode`: Individual nodes in a workflow
|
||||||
|
- `StoreListing`: Marketplace listings for sharing agents
|
||||||
|
|
||||||
|
## Environment Configuration
|
||||||
|
|
||||||
|
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||||
|
|
||||||
|
## Common Development Tasks
|
||||||
|
|
||||||
|
### Adding a new block
|
||||||
|
|
||||||
|
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||||
|
|
||||||
|
- Provider configuration with `ProviderBuilder`
|
||||||
|
- Block schema definition
|
||||||
|
- Authentication (API keys, OAuth, webhooks)
|
||||||
|
- Testing and validation
|
||||||
|
- File organization
|
||||||
|
|
||||||
|
Quick steps:
|
||||||
|
|
||||||
|
1. Create new file in `backend/blocks/`
|
||||||
|
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||||
|
3. Inherit from `Block` base class
|
||||||
|
4. Define input/output schemas using `BlockSchema`
|
||||||
|
5. Implement async `run` method
|
||||||
|
6. Generate unique block ID using `uuid.uuid4()`
|
||||||
|
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||||
|
|
||||||
|
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||||
|
ex: do the inputs and outputs tie well together?
|
||||||
|
|
||||||
|
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||||
|
|
||||||
|
#### Handling files in blocks with `store_media_file()`
|
||||||
|
|
||||||
|
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||||
|
|
||||||
|
| Format | Use When | Returns |
|
||||||
|
|--------|----------|---------|
|
||||||
|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||||
|
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||||
|
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# INPUT: Need to process file locally with ffmpeg
|
||||||
|
local_path = await store_media_file(
|
||||||
|
file=input_data.video,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||||
|
|
||||||
|
# INPUT: Need to send to external API like Replicate
|
||||||
|
image_b64 = await store_media_file(
|
||||||
|
file=input_data.image,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_external_api",
|
||||||
|
)
|
||||||
|
# image_b64 = "..." - send to API
|
||||||
|
|
||||||
|
# OUTPUT: Returning result from block
|
||||||
|
result_url = await store_media_file(
|
||||||
|
file=generated_image_url,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", result_url
|
||||||
|
# In CoPilot: result_url = "workspace://abc123"
|
||||||
|
# In graphs: result_url = "data:image/png;base64,..."
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key points:**
|
||||||
|
|
||||||
|
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||||
|
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||||
|
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||||
|
|
||||||
|
### Modifying the API
|
||||||
|
|
||||||
|
1. Update route in `backend/api/features/`
|
||||||
|
2. Add/update Pydantic models in same directory
|
||||||
|
3. Write tests alongside the route file
|
||||||
|
4. Run `poetry run test` to verify
|
||||||
|
|
||||||
|
## Security Implementation
|
||||||
|
|
||||||
|
### Cache Protection Middleware
|
||||||
|
|
||||||
|
- Located in `backend/api/middleware/security.py`
|
||||||
|
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||||
|
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||||
|
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||||
|
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||||
|
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||||
|
- Applied to both main API server and external API applications
|
||||||
@@ -62,10 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
|
|||||||
DEBIAN_FRONTEND=noninteractive
|
DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PATH=/opt/poetry/bin:$PATH
|
ENV PATH=/opt/poetry/bin:$PATH
|
||||||
|
|
||||||
# Install Python without upgrading system-managed packages
|
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
python3.13 \
|
python3.13 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
|
ffmpeg \
|
||||||
|
imagemagick \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy only necessary files from builder
|
# Copy only necessary files from builder
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as
|
|||||||
|
|
||||||
#### Using Global Auth Fixtures
|
#### Using Global Auth Fixtures
|
||||||
|
|
||||||
Two global auth fixtures are provided by `backend/server/conftest.py`:
|
Two global auth fixtures are provided by `backend/api/conftest.py`:
|
||||||
|
|
||||||
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
||||||
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ router = fastapi.APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Taken from backend/server/v2/store/db.py
|
# Taken from backend/api/features/store/db.py
|
||||||
def sanitize_query(query: str | None) -> str | None:
|
def sanitize_query(query: str | None) -> str | None:
|
||||||
if query is None:
|
if query is None:
|
||||||
return query
|
return query
|
||||||
|
|||||||
@@ -0,0 +1,368 @@
|
|||||||
|
"""Redis Streams consumer for operation completion messages.
|
||||||
|
|
||||||
|
This module provides a consumer (ChatCompletionConsumer) that listens for
|
||||||
|
completion notifications (OperationCompleteMessage) from external services
|
||||||
|
(like Agent Generator) and triggers the appropriate stream registry and
|
||||||
|
chat service updates via process_operation_success/process_operation_failure.
|
||||||
|
|
||||||
|
Why Redis Streams instead of RabbitMQ?
|
||||||
|
--------------------------------------
|
||||||
|
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
||||||
|
queue), Redis Streams was chosen for chat completion notifications because:
|
||||||
|
|
||||||
|
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
||||||
|
Streams (via stream_registry) for message persistence and replay. Using Redis
|
||||||
|
Streams for completion notifications keeps all chat streaming infrastructure
|
||||||
|
in one system, simplifying operations and reducing cross-system coordination.
|
||||||
|
|
||||||
|
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
||||||
|
allowing consumers to replay missed messages after reconnection. This aligns
|
||||||
|
with the SSE reconnection pattern where clients can resume from last_message_id.
|
||||||
|
|
||||||
|
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
||||||
|
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
||||||
|
recovering from dead consumers - ideal for the completion callback pattern.
|
||||||
|
|
||||||
|
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
||||||
|
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
||||||
|
|
||||||
|
5. **Atomicity with Task State**: Completion processing often needs to update
|
||||||
|
task metadata stored in Redis. Keeping both in Redis enables simpler
|
||||||
|
transactional semantics without distributed coordination.
|
||||||
|
|
||||||
|
The consumer uses Redis Streams with consumer groups for reliable message
|
||||||
|
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
||||||
|
stale pending messages from dead consumers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
from prisma import Prisma
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from redis.exceptions import ResponseError
|
||||||
|
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
|
|
||||||
|
from . import stream_registry
|
||||||
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
|
from .config import ChatConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
config = ChatConfig()
|
||||||
|
|
||||||
|
|
||||||
|
class OperationCompleteMessage(BaseModel):
|
||||||
|
"""Message format for operation completion notifications."""
|
||||||
|
|
||||||
|
operation_id: str
|
||||||
|
task_id: str
|
||||||
|
success: bool
|
||||||
|
result: dict | str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionConsumer:
|
||||||
|
"""Consumer for chat operation completion messages from Redis Streams.
|
||||||
|
|
||||||
|
This consumer initializes its own Prisma client in start() to ensure
|
||||||
|
database operations work correctly within this async context.
|
||||||
|
|
||||||
|
Uses Redis consumer groups to allow multiple platform pods to consume
|
||||||
|
messages reliably with automatic redelivery on failure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._consumer_task: asyncio.Task | None = None
|
||||||
|
self._running = False
|
||||||
|
self._prisma: Prisma | None = None
|
||||||
|
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start the completion consumer."""
|
||||||
|
if self._running:
|
||||||
|
logger.warning("Completion consumer already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create consumer group if it doesn't exist
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
await redis.xgroup_create(
|
||||||
|
config.stream_completion_name,
|
||||||
|
config.stream_consumer_group,
|
||||||
|
id="0",
|
||||||
|
mkstream=True,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Created consumer group '{config.stream_consumer_group}' "
|
||||||
|
f"on stream '{config.stream_completion_name}'"
|
||||||
|
)
|
||||||
|
except ResponseError as e:
|
||||||
|
if "BUSYGROUP" in str(e):
|
||||||
|
logger.debug(
|
||||||
|
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
self._running = True
|
||||||
|
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||||
|
logger.info(
|
||||||
|
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _ensure_prisma(self) -> Prisma:
|
||||||
|
"""Lazily initialize Prisma client on first use."""
|
||||||
|
if self._prisma is None:
|
||||||
|
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||||
|
self._prisma = Prisma(datasource={"url": database_url})
|
||||||
|
await self._prisma.connect()
|
||||||
|
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
||||||
|
return self._prisma
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the completion consumer."""
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
if self._consumer_task:
|
||||||
|
self._consumer_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._consumer_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._consumer_task = None
|
||||||
|
|
||||||
|
if self._prisma:
|
||||||
|
await self._prisma.disconnect()
|
||||||
|
self._prisma = None
|
||||||
|
logger.info("[COMPLETION] Consumer Prisma client disconnected")
|
||||||
|
|
||||||
|
logger.info("Chat completion consumer stopped")
|
||||||
|
|
||||||
|
async def _consume_messages(self) -> None:
|
||||||
|
"""Main message consumption loop with retry logic."""
|
||||||
|
max_retries = 10
|
||||||
|
retry_delay = 5 # seconds
|
||||||
|
retry_count = 0
|
||||||
|
block_timeout = 5000 # milliseconds
|
||||||
|
|
||||||
|
while self._running and retry_count < max_retries:
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
|
||||||
|
# Reset retry count on successful connection
|
||||||
|
retry_count = 0
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
# First, claim any stale pending messages from dead consumers
|
||||||
|
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
||||||
|
# claim them using XAUTOCLAIM
|
||||||
|
try:
|
||||||
|
claimed_result = await redis.xautoclaim(
|
||||||
|
name=config.stream_completion_name,
|
||||||
|
groupname=config.stream_consumer_group,
|
||||||
|
consumername=self._consumer_name,
|
||||||
|
min_idle_time=config.stream_claim_min_idle_ms,
|
||||||
|
start_id="0-0",
|
||||||
|
count=10,
|
||||||
|
)
|
||||||
|
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
||||||
|
if claimed_result and len(claimed_result) >= 2:
|
||||||
|
claimed_entries = claimed_result[1]
|
||||||
|
if claimed_entries:
|
||||||
|
logger.info(
|
||||||
|
f"Claimed {len(claimed_entries)} stale pending messages"
|
||||||
|
)
|
||||||
|
for entry_id, data in claimed_entries:
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
await self._process_entry(redis, entry_id, data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
||||||
|
|
||||||
|
# Read new messages from the stream
|
||||||
|
messages = await redis.xreadgroup(
|
||||||
|
groupname=config.stream_consumer_group,
|
||||||
|
consumername=self._consumer_name,
|
||||||
|
streams={config.stream_completion_name: ">"},
|
||||||
|
block=block_timeout,
|
||||||
|
count=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for stream_name, entries in messages:
|
||||||
|
for entry_id, data in entries:
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
await self._process_entry(redis, entry_id, data)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Consumer cancelled")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
retry_count += 1
|
||||||
|
logger.error(
|
||||||
|
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
if self._running and retry_count < max_retries:
|
||||||
|
await asyncio.sleep(retry_delay)
|
||||||
|
else:
|
||||||
|
logger.error("Max retries reached, stopping consumer")
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _process_entry(
|
||||||
|
self, redis: Any, entry_id: str, data: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Process a single stream entry and acknowledge it on success.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis: Redis client connection
|
||||||
|
entry_id: The stream entry ID
|
||||||
|
data: The entry data dict
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Handle the message
|
||||||
|
message_data = data.get("data")
|
||||||
|
if message_data:
|
||||||
|
await self._handle_message(
|
||||||
|
message_data.encode()
|
||||||
|
if isinstance(message_data, str)
|
||||||
|
else message_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Acknowledge the message after successful processing
|
||||||
|
await redis.xack(
|
||||||
|
config.stream_completion_name,
|
||||||
|
config.stream_consumer_group,
|
||||||
|
entry_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error processing completion message {entry_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# Message remains in pending state and will be claimed by
|
||||||
|
# XAUTOCLAIM after min_idle_time expires
|
||||||
|
|
||||||
|
async def _handle_message(self, body: bytes) -> None:
|
||||||
|
"""Handle a completion message using our own Prisma client."""
|
||||||
|
try:
|
||||||
|
data = orjson.loads(body)
|
||||||
|
message = OperationCompleteMessage(**data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to parse completion message: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
||||||
|
f"(task_id={message.task_id}, success={message.success})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find task in registry
|
||||||
|
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||||
|
if task is None:
|
||||||
|
task = await stream_registry.get_task(message.task_id)
|
||||||
|
|
||||||
|
if task is None:
|
||||||
|
logger.warning(
|
||||||
|
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
||||||
|
f"(task_id={message.task_id})"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
||||||
|
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Guard against empty task fields
|
||||||
|
if not task.task_id or not task.session_id or not task.tool_call_id:
|
||||||
|
logger.error(
|
||||||
|
f"[COMPLETION] Task has empty critical fields! "
|
||||||
|
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
||||||
|
f"tool_call_id={task.tool_call_id!r}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if message.success:
|
||||||
|
await self._handle_success(task, message)
|
||||||
|
else:
|
||||||
|
await self._handle_failure(task, message)
|
||||||
|
|
||||||
|
async def _handle_success(
|
||||||
|
self,
|
||||||
|
task: stream_registry.ActiveTask,
|
||||||
|
message: OperationCompleteMessage,
|
||||||
|
) -> None:
|
||||||
|
"""Handle successful operation completion."""
|
||||||
|
prisma = await self._ensure_prisma()
|
||||||
|
await process_operation_success(task, message.result, prisma)
|
||||||
|
|
||||||
|
async def _handle_failure(
|
||||||
|
self,
|
||||||
|
task: stream_registry.ActiveTask,
|
||||||
|
message: OperationCompleteMessage,
|
||||||
|
) -> None:
|
||||||
|
"""Handle failed operation completion."""
|
||||||
|
prisma = await self._ensure_prisma()
|
||||||
|
await process_operation_failure(task, message.error, prisma)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level consumer instance
|
||||||
|
_consumer: ChatCompletionConsumer | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def start_completion_consumer() -> None:
|
||||||
|
"""Start the global completion consumer."""
|
||||||
|
global _consumer
|
||||||
|
if _consumer is None:
|
||||||
|
_consumer = ChatCompletionConsumer()
|
||||||
|
await _consumer.start()
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_completion_consumer() -> None:
|
||||||
|
"""Stop the global completion consumer."""
|
||||||
|
global _consumer
|
||||||
|
if _consumer:
|
||||||
|
await _consumer.stop()
|
||||||
|
_consumer = None
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_operation_complete(
|
||||||
|
operation_id: str,
|
||||||
|
task_id: str,
|
||||||
|
success: bool,
|
||||||
|
result: dict | str | None = None,
|
||||||
|
error: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Publish an operation completion message to Redis Streams.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operation_id: The operation ID that completed.
|
||||||
|
task_id: The task ID associated with the operation.
|
||||||
|
success: Whether the operation succeeded.
|
||||||
|
result: The result data (for success).
|
||||||
|
error: The error message (for failure).
|
||||||
|
"""
|
||||||
|
message = OperationCompleteMessage(
|
||||||
|
operation_id=operation_id,
|
||||||
|
task_id=task_id,
|
||||||
|
success=success,
|
||||||
|
result=result,
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
|
|
||||||
|
redis = await get_redis_async()
|
||||||
|
await redis.xadd(
|
||||||
|
config.stream_completion_name,
|
||||||
|
{"data": message.model_dump_json()},
|
||||||
|
maxlen=config.stream_max_length,
|
||||||
|
)
|
||||||
|
logger.info(f"Published completion for operation {operation_id}")
|
||||||
@@ -0,0 +1,344 @@
|
|||||||
|
"""Shared completion handling for operation success and failure.
|
||||||
|
|
||||||
|
This module provides common logic for handling operation completion from both:
|
||||||
|
- The Redis Streams consumer (completion_consumer.py)
|
||||||
|
- The HTTP webhook endpoint (routes.py)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
from prisma import Prisma
|
||||||
|
|
||||||
|
from . import service as chat_service
|
||||||
|
from . import stream_registry
|
||||||
|
from .response_model import StreamError, StreamToolOutputAvailable
|
||||||
|
from .tools.models import ErrorResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Tools that produce agent_json that needs to be saved to library
|
||||||
|
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||||
|
|
||||||
|
# Keys that should be stripped from agent_json when returning in error responses
|
||||||
|
SENSITIVE_KEYS = frozenset(
|
||||||
|
{
|
||||||
|
"api_key",
|
||||||
|
"apikey",
|
||||||
|
"api_secret",
|
||||||
|
"password",
|
||||||
|
"secret",
|
||||||
|
"credentials",
|
||||||
|
"credential",
|
||||||
|
"token",
|
||||||
|
"access_token",
|
||||||
|
"refresh_token",
|
||||||
|
"private_key",
|
||||||
|
"privatekey",
|
||||||
|
"auth",
|
||||||
|
"authorization",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_agent_json(obj: Any) -> Any:
|
||||||
|
"""Recursively sanitize agent_json by removing sensitive keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The object to sanitize (dict, list, or primitive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized copy with sensitive keys removed/redacted
|
||||||
|
"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {
|
||||||
|
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
||||||
|
for k, v in obj.items()
|
||||||
|
}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [_sanitize_agent_json(item) for item in obj]
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class ToolMessageUpdateError(Exception):
|
||||||
|
"""Raised when updating a tool message in the database fails."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_tool_message(
|
||||||
|
session_id: str,
|
||||||
|
tool_call_id: str,
|
||||||
|
content: str,
|
||||||
|
prisma_client: Prisma | None,
|
||||||
|
) -> None:
|
||||||
|
"""Update tool message in database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The session ID
|
||||||
|
tool_call_id: The tool call ID to update
|
||||||
|
content: The new content for the message
|
||||||
|
prisma_client: Optional Prisma client. If None, uses chat_service.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ToolMessageUpdateError: If the database update fails. The caller should
|
||||||
|
handle this to avoid marking the task as completed with inconsistent state.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if prisma_client:
|
||||||
|
# Use provided Prisma client (for consumer with its own connection)
|
||||||
|
updated_count = await prisma_client.chatmessage.update_many(
|
||||||
|
where={
|
||||||
|
"sessionId": session_id,
|
||||||
|
"toolCallId": tool_call_id,
|
||||||
|
},
|
||||||
|
data={"content": content},
|
||||||
|
)
|
||||||
|
# Check if any rows were updated - 0 means message not found
|
||||||
|
if updated_count == 0:
|
||||||
|
raise ToolMessageUpdateError(
|
||||||
|
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use service function (for webhook endpoint)
|
||||||
|
await chat_service._update_pending_operation(
|
||||||
|
session_id=session_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
result=content,
|
||||||
|
)
|
||||||
|
except ToolMessageUpdateError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
|
||||||
|
raise ToolMessageUpdateError(
|
||||||
|
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
||||||
|
"""Serialize result to JSON string with sensible defaults.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: The result to serialize. Can be a dict, list, string,
|
||||||
|
number, boolean, or None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON string representation of the result. Returns '{"status": "completed"}'
|
||||||
|
only when result is explicitly None.
|
||||||
|
"""
|
||||||
|
if isinstance(result, str):
|
||||||
|
return result
|
||||||
|
if result is None:
|
||||||
|
return '{"status": "completed"}'
|
||||||
|
return orjson.dumps(result).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
async def _save_agent_from_result(
|
||||||
|
result: dict[str, Any],
|
||||||
|
user_id: str | None,
|
||||||
|
tool_name: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Save agent to library if result contains agent_json.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: The result dict that may contain agent_json
|
||||||
|
user_id: The user ID to save the agent for
|
||||||
|
tool_name: The tool name (create_agent or edit_agent)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated result dict with saved agent details, or original result if no agent_json
|
||||||
|
"""
|
||||||
|
if not user_id:
|
||||||
|
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
||||||
|
return result
|
||||||
|
|
||||||
|
agent_json = result.get("agent_json")
|
||||||
|
if not agent_json:
|
||||||
|
logger.warning(
|
||||||
|
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .tools.agent_generator import save_agent_to_library
|
||||||
|
|
||||||
|
is_update = tool_name == "edit_agent"
|
||||||
|
created_graph, library_agent = await save_agent_to_library(
|
||||||
|
agent_json, user_id, is_update=is_update
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
||||||
|
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return a response similar to AgentSavedResponse
|
||||||
|
return {
|
||||||
|
"type": "agent_saved",
|
||||||
|
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
||||||
|
"agent_id": created_graph.id,
|
||||||
|
"agent_name": created_graph.name,
|
||||||
|
"library_agent_id": library_agent.id,
|
||||||
|
"library_agent_link": f"/library/agents/{library_agent.id}",
|
||||||
|
"agent_page_link": f"/build?flowID={created_graph.id}",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[COMPLETION] Failed to save agent to library: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# Return error but don't fail the whole operation
|
||||||
|
# Sanitize agent_json to remove sensitive keys before returning
|
||||||
|
return {
|
||||||
|
"type": "error",
|
||||||
|
"message": f"Agent was generated but failed to save: {str(e)}",
|
||||||
|
"error": str(e),
|
||||||
|
"agent_json": _sanitize_agent_json(agent_json),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def process_operation_success(
|
||||||
|
task: stream_registry.ActiveTask,
|
||||||
|
result: dict | str | None,
|
||||||
|
prisma_client: Prisma | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle successful operation completion.
|
||||||
|
|
||||||
|
Publishes the result to the stream registry, updates the database,
|
||||||
|
generates LLM continuation, and marks the task as completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The active task that completed
|
||||||
|
result: The result data from the operation
|
||||||
|
prisma_client: Optional Prisma client for database operations.
|
||||||
|
If None, uses chat_service._update_pending_operation instead.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ToolMessageUpdateError: If the database update fails. The task will be
|
||||||
|
marked as failed instead of completed to avoid inconsistent state.
|
||||||
|
"""
|
||||||
|
# For agent generation tools, save the agent to library
|
||||||
|
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
||||||
|
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
||||||
|
|
||||||
|
# Serialize result for output (only substitute default when result is exactly None)
|
||||||
|
result_output = result if result is not None else {"status": "completed"}
|
||||||
|
output_str = (
|
||||||
|
result_output
|
||||||
|
if isinstance(result_output, str)
|
||||||
|
else orjson.dumps(result_output).decode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Publish result to stream registry
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
task.task_id,
|
||||||
|
StreamToolOutputAvailable(
|
||||||
|
toolCallId=task.tool_call_id,
|
||||||
|
toolName=task.tool_name,
|
||||||
|
output=output_str,
|
||||||
|
success=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update pending operation in database
|
||||||
|
# If this fails, we must not continue to mark the task as completed
|
||||||
|
result_str = serialize_result(result)
|
||||||
|
try:
|
||||||
|
await _update_tool_message(
|
||||||
|
session_id=task.session_id,
|
||||||
|
tool_call_id=task.tool_call_id,
|
||||||
|
content=result_str,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
)
|
||||||
|
except ToolMessageUpdateError:
|
||||||
|
# DB update failed - mark task as failed to avoid inconsistent state
|
||||||
|
logger.error(
|
||||||
|
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
||||||
|
"marking as failed instead of completed"
|
||||||
|
)
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
task.task_id,
|
||||||
|
StreamError(errorText="Failed to save operation result to database"),
|
||||||
|
)
|
||||||
|
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Generate LLM continuation with streaming
|
||||||
|
try:
|
||||||
|
await chat_service._generate_llm_continuation_with_streaming(
|
||||||
|
session_id=task.session_id,
|
||||||
|
user_id=task.user_id,
|
||||||
|
task_id=task.task_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark task as completed and release Redis lock
|
||||||
|
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||||
|
try:
|
||||||
|
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def process_operation_failure(
|
||||||
|
task: stream_registry.ActiveTask,
|
||||||
|
error: str | None,
|
||||||
|
prisma_client: Prisma | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle failed operation completion.
|
||||||
|
|
||||||
|
Publishes the error to the stream registry, updates the database with
|
||||||
|
the error response, and marks the task as failed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The active task that failed
|
||||||
|
error: The error message from the operation
|
||||||
|
prisma_client: Optional Prisma client for database operations.
|
||||||
|
If None, uses chat_service._update_pending_operation instead.
|
||||||
|
"""
|
||||||
|
error_msg = error or "Operation failed"
|
||||||
|
|
||||||
|
# Publish error to stream registry
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
task.task_id,
|
||||||
|
StreamError(errorText=error_msg),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update pending operation with error
|
||||||
|
# If this fails, we still continue to mark the task as failed
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message=error_msg,
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await _update_tool_message(
|
||||||
|
session_id=task.session_id,
|
||||||
|
tool_call_id=task.tool_call_id,
|
||||||
|
content=error_response.model_dump_json(),
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
)
|
||||||
|
except ToolMessageUpdateError:
|
||||||
|
# DB update failed - log but continue with cleanup
|
||||||
|
logger.error(
|
||||||
|
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
||||||
|
"continuing with cleanup"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark task as failed and release Redis lock
|
||||||
|
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||||
|
try:
|
||||||
|
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||||
|
|
||||||
|
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||||
@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):
|
|||||||
|
|
||||||
# OpenAI API Configuration
|
# OpenAI API Configuration
|
||||||
model: str = Field(
|
model: str = Field(
|
||||||
default="anthropic/claude-opus-4.5", description="Default model to use"
|
default="anthropic/claude-opus-4.6", description="Default model to use"
|
||||||
)
|
)
|
||||||
title_model: str = Field(
|
title_model: str = Field(
|
||||||
default="openai/gpt-4o-mini",
|
default="openai/gpt-4o-mini",
|
||||||
@@ -44,6 +44,48 @@ class ChatConfig(BaseSettings):
|
|||||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Stream registry configuration for SSE reconnection
|
||||||
|
stream_ttl: int = Field(
|
||||||
|
default=3600,
|
||||||
|
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||||
|
)
|
||||||
|
stream_max_length: int = Field(
|
||||||
|
default=10000,
|
||||||
|
description="Maximum number of messages to store per stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redis Streams configuration for completion consumer
|
||||||
|
stream_completion_name: str = Field(
|
||||||
|
default="chat:completions",
|
||||||
|
description="Redis Stream name for operation completions",
|
||||||
|
)
|
||||||
|
stream_consumer_group: str = Field(
|
||||||
|
default="chat_consumers",
|
||||||
|
description="Consumer group name for completion stream",
|
||||||
|
)
|
||||||
|
stream_claim_min_idle_ms: int = Field(
|
||||||
|
default=60000,
|
||||||
|
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redis key prefixes for stream registry
|
||||||
|
task_meta_prefix: str = Field(
|
||||||
|
default="chat:task:meta:",
|
||||||
|
description="Prefix for task metadata hash keys",
|
||||||
|
)
|
||||||
|
task_stream_prefix: str = Field(
|
||||||
|
default="chat:stream:",
|
||||||
|
description="Prefix for task message stream keys",
|
||||||
|
)
|
||||||
|
task_op_prefix: str = Field(
|
||||||
|
default="chat:task:op:",
|
||||||
|
description="Prefix for operation ID to task ID mapping keys",
|
||||||
|
)
|
||||||
|
internal_api_key: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
# Langfuse Prompt Management Configuration
|
# Langfuse Prompt Management Configuration
|
||||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||||
langfuse_prompt_name: str = Field(
|
langfuse_prompt_name: str = Field(
|
||||||
@@ -82,6 +124,14 @@ class ChatConfig(BaseSettings):
|
|||||||
v = "https://openrouter.ai/api/v1"
|
v = "https://openrouter.ai/api/v1"
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("internal_api_key", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def get_internal_api_key(cls, v):
|
||||||
|
"""Get internal API key from environment if not provided."""
|
||||||
|
if v is None:
|
||||||
|
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||||
|
return v
|
||||||
|
|
||||||
# Prompt paths for different contexts
|
# Prompt paths for different contexts
|
||||||
PROMPT_PATHS: dict[str, str] = {
|
PROMPT_PATHS: dict[str, str] = {
|
||||||
"default": "prompts/chat_system.md",
|
"default": "prompts/chat_system.md",
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse):
|
|||||||
|
|
||||||
type: ResponseType = ResponseType.START
|
type: ResponseType = ResponseType.START
|
||||||
messageId: str = Field(..., description="Unique message ID")
|
messageId: str = Field(..., description="Unique message ID")
|
||||||
|
taskId: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StreamFinish(StreamBaseResponse):
|
class StreamFinish(StreamBaseResponse):
|
||||||
|
|||||||
@@ -1,19 +1,23 @@
|
|||||||
"""Chat API routes for chat session management and streaming via SSE."""
|
"""Chat API routes for chat session management and streaming via SSE."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import uuid as uuid_module
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from autogpt_libs import auth
|
from autogpt_libs import auth
|
||||||
from fastapi import APIRouter, Depends, Query, Security
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
from . import service as chat_service
|
from . import service as chat_service
|
||||||
|
from . import stream_registry
|
||||||
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||||
|
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -55,6 +59,15 @@ class CreateSessionResponse(BaseModel):
|
|||||||
user_id: str | None
|
user_id: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveStreamInfo(BaseModel):
|
||||||
|
"""Information about an active stream for reconnection."""
|
||||||
|
|
||||||
|
task_id: str
|
||||||
|
last_message_id: str # Redis Stream message ID for resumption
|
||||||
|
operation_id: str # Operation ID for completion tracking
|
||||||
|
tool_name: str # Name of the tool being executed
|
||||||
|
|
||||||
|
|
||||||
class SessionDetailResponse(BaseModel):
|
class SessionDetailResponse(BaseModel):
|
||||||
"""Response model providing complete details for a chat session, including messages."""
|
"""Response model providing complete details for a chat session, including messages."""
|
||||||
|
|
||||||
@@ -63,6 +76,7 @@ class SessionDetailResponse(BaseModel):
|
|||||||
updated_at: str
|
updated_at: str
|
||||||
user_id: str | None
|
user_id: str | None
|
||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
|
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||||
|
|
||||||
|
|
||||||
class SessionSummaryResponse(BaseModel):
|
class SessionSummaryResponse(BaseModel):
|
||||||
@@ -81,6 +95,14 @@ class ListSessionsResponse(BaseModel):
|
|||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class OperationCompleteRequest(BaseModel):
|
||||||
|
"""Request model for external completion webhook."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
result: dict | str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
# ========== Routes ==========
|
# ========== Routes ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -166,13 +188,14 @@ async def get_session(
|
|||||||
Retrieve the details of a specific chat session.
|
Retrieve the details of a specific chat session.
|
||||||
|
|
||||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||||
|
If there's an active stream for this session, returns the task_id for reconnection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The unique identifier for the desired chat session.
|
session_id: The unique identifier for the desired chat session.
|
||||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
@@ -180,10 +203,27 @@ async def get_session(
|
|||||||
raise NotFoundError(f"Session {session_id} not found.")
|
raise NotFoundError(f"Session {session_id} not found.")
|
||||||
|
|
||||||
messages = [message.model_dump() for message in session.messages]
|
messages = [message.model_dump() for message in session.messages]
|
||||||
logger.info(
|
|
||||||
f"Returning session {session_id}: "
|
# Check if there's an active stream for this session
|
||||||
f"message_count={len(messages)}, "
|
active_stream_info = None
|
||||||
f"roles={[m.get('role') for m in messages]}"
|
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||||
|
session_id, user_id
|
||||||
|
)
|
||||||
|
if active_task:
|
||||||
|
# Filter out the in-progress assistant message from the session response.
|
||||||
|
# The client will receive the complete assistant response through the SSE
|
||||||
|
# stream replay instead, preventing duplicate content.
|
||||||
|
if messages and messages[-1].get("role") == "assistant":
|
||||||
|
messages = messages[:-1]
|
||||||
|
|
||||||
|
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||||
|
# Since we filtered out the cached assistant message, the client needs
|
||||||
|
# the full stream to reconstruct the response.
|
||||||
|
active_stream_info = ActiveStreamInfo(
|
||||||
|
task_id=active_task.task_id,
|
||||||
|
last_message_id="0-0",
|
||||||
|
operation_id=active_task.operation_id,
|
||||||
|
tool_name=active_task.tool_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return SessionDetailResponse(
|
return SessionDetailResponse(
|
||||||
@@ -192,6 +232,7 @@ async def get_session(
|
|||||||
updated_at=session.updated_at.isoformat(),
|
updated_at=session.updated_at.isoformat(),
|
||||||
user_id=session.user_id or None,
|
user_id=session.user_id or None,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
active_stream=active_stream_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -211,19 +252,42 @@ async def stream_chat_post(
|
|||||||
- Tool call UI elements (if invoked)
|
- Tool call UI elements (if invoked)
|
||||||
- Tool execution results
|
- Tool execution results
|
||||||
|
|
||||||
|
The AI generation runs in a background task that continues even if the client disconnects.
|
||||||
|
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||||
|
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The chat session identifier to associate with the streamed messages.
|
session_id: The chat session identifier to associate with the streamed messages.
|
||||||
request: Request body containing message, is_user_message, and optional context.
|
request: Request body containing message, is_user_message, and optional context.
|
||||||
user_id: Optional authenticated user ID.
|
user_id: Optional authenticated user ID.
|
||||||
Returns:
|
Returns:
|
||||||
StreamingResponse: SSE-formatted response chunks.
|
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||||
|
containing the task_id for reconnection.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
# Create a task in the stream registry for reconnection support
|
||||||
chunk_count = 0
|
task_id = str(uuid_module.uuid4())
|
||||||
first_chunk_type: str | None = None
|
operation_id = str(uuid_module.uuid4())
|
||||||
|
await stream_registry.create_task(
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||||
|
tool_name="chat",
|
||||||
|
operation_id=operation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Background task that runs the AI generation independently of SSE connection
|
||||||
|
async def run_ai_generation():
|
||||||
|
try:
|
||||||
|
# Emit a start event with task_id for reconnection
|
||||||
|
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||||
|
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||||
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
session_id,
|
session_id,
|
||||||
request.message,
|
request.message,
|
||||||
@@ -232,27 +296,67 @@ async def stream_chat_post(
|
|||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
context=request.context,
|
context=request.context,
|
||||||
):
|
):
|
||||||
if chunk_count < 3:
|
# Write to Redis (subscribers will receive via XREAD)
|
||||||
logger.info(
|
await stream_registry.publish_chunk(task_id, chunk)
|
||||||
"Chat stream chunk",
|
|
||||||
extra={
|
# Mark task as completed
|
||||||
"session_id": session_id,
|
await stream_registry.mark_task_completed(task_id, "completed")
|
||||||
"chunk_type": str(chunk.type),
|
except Exception as e:
|
||||||
},
|
logger.error(
|
||||||
|
f"Error in background AI generation for session {session_id}: {e}"
|
||||||
)
|
)
|
||||||
if not first_chunk_type:
|
await stream_registry.mark_task_completed(task_id, "failed")
|
||||||
first_chunk_type = str(chunk.type)
|
|
||||||
chunk_count += 1
|
# Start the AI generation in a background task
|
||||||
|
bg_task = asyncio.create_task(run_ai_generation())
|
||||||
|
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||||
|
|
||||||
|
# SSE endpoint that subscribes to the task's stream
|
||||||
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
subscriber_queue = None
|
||||||
|
try:
|
||||||
|
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||||
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
|
task_id=task_id,
|
||||||
|
user_id=user_id,
|
||||||
|
last_message_id="0-0", # Get all messages from the beginning
|
||||||
|
)
|
||||||
|
|
||||||
|
if subscriber_queue is None:
|
||||||
|
yield StreamFinish().to_sse()
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
# Read from the subscriber queue and yield to SSE
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
logger.info(
|
|
||||||
"Chat stream completed",
|
# Check for finish signal
|
||||||
extra={
|
if isinstance(chunk, StreamFinish):
|
||||||
"session_id": session_id,
|
break
|
||||||
"chunk_count": chunk_count,
|
except asyncio.TimeoutError:
|
||||||
"first_chunk_type": first_chunk_type,
|
# Send heartbeat to keep connection alive
|
||||||
},
|
yield StreamHeartbeat().to_sse()
|
||||||
|
|
||||||
|
except GeneratorExit:
|
||||||
|
pass # Client disconnected - background task continues
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||||
|
finally:
|
||||||
|
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||||
|
if subscriber_queue is not None:
|
||||||
|
try:
|
||||||
|
await stream_registry.unsubscribe_from_task(
|
||||||
|
task_id, subscriber_queue
|
||||||
)
|
)
|
||||||
# AI SDK protocol termination
|
except Exception as unsub_err:
|
||||||
|
logger.error(
|
||||||
|
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@@ -366,6 +470,251 @@ async def session_assign_user(
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Task Streaming (SSE Reconnection) ==========
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/tasks/{task_id}/stream",
|
||||||
|
)
|
||||||
|
async def stream_task(
|
||||||
|
task_id: str,
|
||||||
|
user_id: str | None = Depends(auth.get_user_id),
|
||||||
|
last_message_id: str = Query(
|
||||||
|
default="0-0",
|
||||||
|
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Reconnect to a long-running task's SSE stream.
|
||||||
|
|
||||||
|
When a long-running operation (like agent generation) starts, the client
|
||||||
|
receives a task_id. If the connection drops, the client can reconnect
|
||||||
|
using this endpoint to resume receiving updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The task ID from the operation_started response.
|
||||||
|
user_id: Authenticated user ID for ownership validation.
|
||||||
|
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
||||||
|
"""
|
||||||
|
# Check task existence and expiry before subscribing
|
||||||
|
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
||||||
|
|
||||||
|
if error_code == "TASK_EXPIRED":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=410,
|
||||||
|
detail={
|
||||||
|
"code": "TASK_EXPIRED",
|
||||||
|
"message": "This operation has expired. Please try again.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if error_code == "TASK_NOT_FOUND":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={
|
||||||
|
"code": "TASK_NOT_FOUND",
|
||||||
|
"message": f"Task {task_id} not found.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate ownership if task has an owner
|
||||||
|
if task and task.user_id and user_id != task.user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={
|
||||||
|
"code": "ACCESS_DENIED",
|
||||||
|
"message": "You do not have access to this task.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get subscriber queue from stream registry
|
||||||
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
|
task_id=task_id,
|
||||||
|
user_id=user_id,
|
||||||
|
last_message_id=last_message_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if subscriber_queue is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={
|
||||||
|
"code": "TASK_NOT_FOUND",
|
||||||
|
"message": f"Task {task_id} not found or access denied.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Wait for next chunk with timeout for heartbeats
|
||||||
|
chunk = await asyncio.wait_for(
|
||||||
|
subscriber_queue.get(), timeout=heartbeat_interval
|
||||||
|
)
|
||||||
|
yield chunk.to_sse()
|
||||||
|
|
||||||
|
# Check for finish signal
|
||||||
|
if isinstance(chunk, StreamFinish):
|
||||||
|
break
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Send heartbeat to keep connection alive
|
||||||
|
yield StreamHeartbeat().to_sse()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
# Unsubscribe when client disconnects or stream ends
|
||||||
|
try:
|
||||||
|
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||||
|
except Exception as unsub_err:
|
||||||
|
logger.error(
|
||||||
|
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"x-vercel-ai-ui-message-stream": "v1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/tasks/{task_id}",
|
||||||
|
)
|
||||||
|
async def get_task_status(
|
||||||
|
task_id: str,
|
||||||
|
user_id: str | None = Depends(auth.get_user_id),
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Get the status of a long-running task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The task ID to check.
|
||||||
|
user_id: Authenticated user ID for ownership validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If task_id is not found or user doesn't have access.
|
||||||
|
"""
|
||||||
|
task = await stream_registry.get_task(task_id)
|
||||||
|
|
||||||
|
if task is None:
|
||||||
|
raise NotFoundError(f"Task {task_id} not found.")
|
||||||
|
|
||||||
|
# Validate ownership - if task has an owner, requester must match
|
||||||
|
if task.user_id and user_id != task.user_id:
|
||||||
|
raise NotFoundError(f"Task {task_id} not found.")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"task_id": task.task_id,
|
||||||
|
"session_id": task.session_id,
|
||||||
|
"status": task.status,
|
||||||
|
"tool_name": task.tool_name,
|
||||||
|
"operation_id": task.operation_id,
|
||||||
|
"created_at": task.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ========== External Completion Webhook ==========
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/operations/{operation_id}/complete",
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
async def complete_operation(
|
||||||
|
operation_id: str,
|
||||||
|
request: OperationCompleteRequest,
|
||||||
|
x_api_key: str | None = Header(default=None),
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
External completion webhook for long-running operations.
|
||||||
|
|
||||||
|
Called by Agent Generator (or other services) when an operation completes.
|
||||||
|
This triggers the stream registry to publish completion and continue LLM generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operation_id: The operation ID to complete.
|
||||||
|
request: Completion payload with success status and result/error.
|
||||||
|
x_api_key: Internal API key for authentication.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Status of the completion.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If API key is invalid or operation not found.
|
||||||
|
"""
|
||||||
|
# Validate internal API key - reject if not configured or invalid
|
||||||
|
if not config.internal_api_key:
|
||||||
|
logger.error(
|
||||||
|
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail="Webhook not available: internal API key not configured",
|
||||||
|
)
|
||||||
|
if x_api_key != config.internal_api_key:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||||
|
|
||||||
|
# Find task by operation_id
|
||||||
|
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||||
|
if task is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Operation {operation_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Received completion webhook for operation {operation_id} "
|
||||||
|
f"(task_id={task.task_id}, success={request.success})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if request.success:
|
||||||
|
await process_operation_success(task, request.result)
|
||||||
|
else:
|
||||||
|
await process_operation_failure(task, request.error)
|
||||||
|
|
||||||
|
return {"status": "ok", "task_id": task.task_id}
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Configuration ==========
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config/ttl", status_code=200)
|
||||||
|
async def get_ttl_config() -> dict:
|
||||||
|
"""
|
||||||
|
Get the stream TTL configuration.
|
||||||
|
|
||||||
|
Returns the Time-To-Live settings for chat streams, which determines
|
||||||
|
how long clients can reconnect to an active stream.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: TTL configuration with seconds and milliseconds values.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"stream_ttl_seconds": config.stream_ttl,
|
||||||
|
"stream_ttl_ms": config.stream_ttl * 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ========== Health Check ==========
|
# ========== Health Check ==========
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,704 @@
|
|||||||
|
"""Stream registry for managing reconnectable SSE streams.
|
||||||
|
|
||||||
|
This module provides a registry for tracking active streaming tasks and their
|
||||||
|
messages. It uses Redis for all state management (no in-memory state), making
|
||||||
|
pods stateless and horizontally scalable.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
- Redis Stream: Persists all messages for replay and real-time delivery
|
||||||
|
- Redis Hash: Task metadata (status, session_id, etc.)
|
||||||
|
|
||||||
|
Subscribers:
|
||||||
|
1. Replay missed messages from Redis Stream (XREAD)
|
||||||
|
2. Listen for live updates via blocking XREAD
|
||||||
|
3. No in-memory state required on the subscribing pod
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
|
|
||||||
|
from .config import ChatConfig
|
||||||
|
from .response_model import StreamBaseResponse, StreamError, StreamFinish
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
config = ChatConfig()
|
||||||
|
|
||||||
|
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||||
|
_local_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
# Track listener tasks per subscriber queue for cleanup
|
||||||
|
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
|
||||||
|
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
|
||||||
|
|
||||||
|
# Timeout for putting chunks into subscriber queues (seconds)
|
||||||
|
# If the queue is full and doesn't drain within this time, send an overflow error
|
||||||
|
QUEUE_PUT_TIMEOUT = 5.0
|
||||||
|
|
||||||
|
# Lua script for atomic compare-and-swap status update (idempotent completion)
|
||||||
|
# Returns 1 if status was updated, 0 if already completed/failed
|
||||||
|
COMPLETE_TASK_SCRIPT = """
|
||||||
|
local current = redis.call("HGET", KEYS[1], "status")
|
||||||
|
if current == "running" then
|
||||||
|
redis.call("HSET", KEYS[1], "status", ARGV[1])
|
||||||
|
return 1
|
||||||
|
end
|
||||||
|
return 0
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ActiveTask:
|
||||||
|
"""Represents an active streaming task (metadata only, no in-memory queues)."""
|
||||||
|
|
||||||
|
task_id: str
|
||||||
|
session_id: str
|
||||||
|
user_id: str | None
|
||||||
|
tool_call_id: str
|
||||||
|
tool_name: str
|
||||||
|
operation_id: str
|
||||||
|
status: Literal["running", "completed", "failed"] = "running"
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
asyncio_task: asyncio.Task | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_task_meta_key(task_id: str) -> str:
|
||||||
|
"""Get Redis key for task metadata."""
|
||||||
|
return f"{config.task_meta_prefix}{task_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_task_stream_key(task_id: str) -> str:
|
||||||
|
"""Get Redis key for task message stream."""
|
||||||
|
return f"{config.task_stream_prefix}{task_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_operation_mapping_key(operation_id: str) -> str:
|
||||||
|
"""Get Redis key for operation_id to task_id mapping."""
|
||||||
|
return f"{config.task_op_prefix}{operation_id}"
|
||||||
|
|
||||||
|
|
||||||
|
async def create_task(
|
||||||
|
task_id: str,
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
tool_call_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
operation_id: str,
|
||||||
|
) -> ActiveTask:
|
||||||
|
"""Create a new streaming task in Redis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Unique identifier for the task
|
||||||
|
session_id: Chat session ID
|
||||||
|
user_id: User ID (may be None for anonymous)
|
||||||
|
tool_call_id: Tool call ID from the LLM
|
||||||
|
tool_name: Name of the tool being executed
|
||||||
|
operation_id: Operation ID for webhook callbacks
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created ActiveTask instance (metadata only)
|
||||||
|
"""
|
||||||
|
task = ActiveTask(
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
operation_id=operation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store metadata in Redis
|
||||||
|
redis = await get_redis_async()
|
||||||
|
meta_key = _get_task_meta_key(task_id)
|
||||||
|
op_key = _get_operation_mapping_key(operation_id)
|
||||||
|
|
||||||
|
await redis.hset( # type: ignore[misc]
|
||||||
|
meta_key,
|
||||||
|
mapping={
|
||||||
|
"task_id": task_id,
|
||||||
|
"session_id": session_id,
|
||||||
|
"user_id": user_id or "",
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"operation_id": operation_id,
|
||||||
|
"status": task.status,
|
||||||
|
"created_at": task.created_at.isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await redis.expire(meta_key, config.stream_ttl)
|
||||||
|
|
||||||
|
# Create operation_id -> task_id mapping for webhook lookups
|
||||||
|
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||||
|
|
||||||
|
logger.debug(f"Created task {task_id} for session {session_id}")
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_chunk(
|
||||||
|
task_id: str,
|
||||||
|
chunk: StreamBaseResponse,
|
||||||
|
) -> str:
|
||||||
|
"""Publish a chunk to Redis Stream.
|
||||||
|
|
||||||
|
All delivery is via Redis Streams - no in-memory state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID to publish to
|
||||||
|
chunk: The stream response chunk to publish
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The Redis Stream message ID
|
||||||
|
"""
|
||||||
|
chunk_json = chunk.model_dump_json()
|
||||||
|
message_id = "0-0"
|
||||||
|
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
|
# Write to Redis Stream for persistence and real-time delivery
|
||||||
|
raw_id = await redis.xadd(
|
||||||
|
stream_key,
|
||||||
|
{"data": chunk_json},
|
||||||
|
maxlen=config.stream_max_length,
|
||||||
|
)
|
||||||
|
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||||
|
|
||||||
|
# Set TTL on stream to match task metadata TTL
|
||||||
|
await redis.expire(stream_key, config.stream_ttl)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to publish chunk for task {task_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return message_id
|
||||||
|
|
||||||
|
|
||||||
|
async def subscribe_to_task(
|
||||||
|
task_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
last_message_id: str = "0-0",
|
||||||
|
) -> asyncio.Queue[StreamBaseResponse] | None:
|
||||||
|
"""Subscribe to a task's stream with replay of missed messages.
|
||||||
|
|
||||||
|
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID to subscribe to
|
||||||
|
user_id: User ID for ownership validation
|
||||||
|
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||||
|
or user doesn't have access
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
meta_key = _get_task_meta_key(task_id)
|
||||||
|
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||||
|
|
||||||
|
if not meta:
|
||||||
|
logger.debug(f"Task {task_id} not found in Redis")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||||
|
task_status = meta.get("status", "")
|
||||||
|
task_user_id = meta.get("user_id", "") or None
|
||||||
|
|
||||||
|
# Validate ownership - if task has an owner, requester must match
|
||||||
|
if task_user_id:
|
||||||
|
if user_id != task_user_id:
|
||||||
|
logger.warning(
|
||||||
|
f"User {user_id} denied access to task {task_id} "
|
||||||
|
f"owned by {task_user_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
|
||||||
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
|
# Step 1: Replay messages from Redis Stream
|
||||||
|
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||||
|
|
||||||
|
replayed_count = 0
|
||||||
|
replay_last_id = last_message_id
|
||||||
|
if messages:
|
||||||
|
for _stream_name, stream_messages in messages:
|
||||||
|
for msg_id, msg_data in stream_messages:
|
||||||
|
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||||
|
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||||
|
if "data" in msg_data:
|
||||||
|
try:
|
||||||
|
chunk_data = orjson.loads(msg_data["data"])
|
||||||
|
chunk = _reconstruct_chunk(chunk_data)
|
||||||
|
if chunk:
|
||||||
|
await subscriber_queue.put(chunk)
|
||||||
|
replayed_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to replay message: {e}")
|
||||||
|
|
||||||
|
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
||||||
|
|
||||||
|
# Step 2: If task is still running, start stream listener for live updates
|
||||||
|
if task_status == "running":
|
||||||
|
listener_task = asyncio.create_task(
|
||||||
|
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
||||||
|
)
|
||||||
|
# Track listener task for cleanup on unsubscribe
|
||||||
|
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||||
|
else:
|
||||||
|
# Task is completed/failed - add finish marker
|
||||||
|
await subscriber_queue.put(StreamFinish())
|
||||||
|
|
||||||
|
return subscriber_queue
|
||||||
|
|
||||||
|
|
||||||
|
async def _stream_listener(
|
||||||
|
task_id: str,
|
||||||
|
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||||
|
last_replayed_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||||
|
|
||||||
|
This approach avoids the duplicate message issue that can occur with pub/sub
|
||||||
|
when messages are published during the gap between replay and subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID to listen for
|
||||||
|
subscriber_queue: Queue to deliver messages to
|
||||||
|
last_replayed_id: Last message ID from replay (continue from here)
|
||||||
|
"""
|
||||||
|
queue_id = id(subscriber_queue)
|
||||||
|
# Track the last successfully delivered message ID for recovery hints
|
||||||
|
last_delivered_id = last_replayed_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
redis = await get_redis_async()
|
||||||
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
current_id = last_replayed_id
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Block for up to 30 seconds waiting for new messages
|
||||||
|
# This allows periodic checking if task is still running
|
||||||
|
messages = await redis.xread(
|
||||||
|
{stream_key: current_id}, block=30000, count=100
|
||||||
|
)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
# Timeout - check if task is still running
|
||||||
|
meta_key = _get_task_meta_key(task_id)
|
||||||
|
status = await redis.hget(meta_key, "status") # type: ignore[misc]
|
||||||
|
if status and status != "running":
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
subscriber_queue.put(StreamFinish()),
|
||||||
|
timeout=QUEUE_PUT_TIMEOUT,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
f"Timeout delivering finish event for task {task_id}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
for _stream_name, stream_messages in messages:
|
||||||
|
for msg_id, msg_data in stream_messages:
|
||||||
|
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||||
|
|
||||||
|
if "data" not in msg_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk_data = orjson.loads(msg_data["data"])
|
||||||
|
chunk = _reconstruct_chunk(chunk_data)
|
||||||
|
if chunk:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
subscriber_queue.put(chunk),
|
||||||
|
timeout=QUEUE_PUT_TIMEOUT,
|
||||||
|
)
|
||||||
|
# Update last delivered ID on successful delivery
|
||||||
|
last_delivered_id = current_id
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
f"Subscriber queue full for task {task_id}, "
|
||||||
|
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
||||||
|
)
|
||||||
|
# Send overflow error with recovery info
|
||||||
|
try:
|
||||||
|
overflow_error = StreamError(
|
||||||
|
errorText="Message delivery timeout - some messages may have been missed",
|
||||||
|
code="QUEUE_OVERFLOW",
|
||||||
|
details={
|
||||||
|
"last_delivered_id": last_delivered_id,
|
||||||
|
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
subscriber_queue.put_nowait(overflow_error)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
# Queue is completely stuck, nothing more we can do
|
||||||
|
logger.error(
|
||||||
|
f"Cannot deliver overflow error for task {task_id}, "
|
||||||
|
"queue completely blocked"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop listening on finish
|
||||||
|
if isinstance(chunk, StreamFinish):
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error processing stream message: {e}")
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug(f"Stream listener cancelled for task {task_id}")
|
||||||
|
raise # Re-raise to propagate cancellation
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Stream listener error for task {task_id}: {e}")
|
||||||
|
# On error, send finish to unblock subscriber
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
subscriber_queue.put(StreamFinish()),
|
||||||
|
timeout=QUEUE_PUT_TIMEOUT,
|
||||||
|
)
|
||||||
|
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||||
|
logger.warning(
|
||||||
|
f"Could not deliver finish event for task {task_id} after error"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# Clean up listener task mapping on exit
|
||||||
|
_listener_tasks.pop(queue_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def mark_task_completed(
|
||||||
|
task_id: str,
|
||||||
|
status: Literal["completed", "failed"] = "completed",
|
||||||
|
) -> bool:
|
||||||
|
"""Mark a task as completed and publish finish event.
|
||||||
|
|
||||||
|
This is idempotent - calling multiple times with the same task_id is safe.
|
||||||
|
Uses atomic compare-and-swap via Lua script to prevent race conditions.
|
||||||
|
Status is updated first (source of truth), then finish event is published (best-effort).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID to mark as completed
|
||||||
|
status: Final status ("completed" or "failed")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if task was newly marked completed, False if already completed/failed
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
meta_key = _get_task_meta_key(task_id)
|
||||||
|
|
||||||
|
# Atomic compare-and-swap: only update if status is "running"
|
||||||
|
# This prevents race conditions when multiple callers try to complete simultaneously
|
||||||
|
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
|
||||||
|
|
||||||
|
if result == 0:
|
||||||
|
logger.debug(f"Task {task_id} already completed/failed, skipping")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# THEN publish finish event (best-effort - listeners can detect via status polling)
|
||||||
|
try:
|
||||||
|
await publish_chunk(task_id, StreamFinish())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to publish finish event for task {task_id}: {e}. "
|
||||||
|
"Listeners will detect completion via status polling."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up local task reference if exists
|
||||||
|
_local_tasks.pop(task_id, None)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
|
||||||
|
"""Find a task by its operation ID.
|
||||||
|
|
||||||
|
Used by webhook callbacks to locate the task to update.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operation_id: Operation ID to search for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ActiveTask if found, None otherwise
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
op_key = _get_operation_mapping_key(operation_id)
|
||||||
|
task_id = await redis.get(op_key)
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
|
||||||
|
return await get_task(task_id_str)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_task(task_id: str) -> ActiveTask | None:
|
||||||
|
"""Get a task by its ID from Redis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID to look up
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ActiveTask if found, None otherwise
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
meta_key = _get_task_meta_key(task_id)
|
||||||
|
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||||
|
|
||||||
|
if not meta:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||||
|
return ActiveTask(
|
||||||
|
task_id=meta.get("task_id", ""),
|
||||||
|
session_id=meta.get("session_id", ""),
|
||||||
|
user_id=meta.get("user_id", "") or None,
|
||||||
|
tool_call_id=meta.get("tool_call_id", ""),
|
||||||
|
tool_name=meta.get("tool_name", ""),
|
||||||
|
operation_id=meta.get("operation_id", ""),
|
||||||
|
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_task_with_expiry_info(
|
||||||
|
task_id: str,
|
||||||
|
) -> tuple[ActiveTask | None, str | None]:
|
||||||
|
"""Get a task by its ID with expiration detection.
|
||||||
|
|
||||||
|
Returns (task, error_code) where error_code is:
|
||||||
|
- None if task found
|
||||||
|
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
|
||||||
|
- "TASK_NOT_FOUND" if neither exists
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID to look up
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (ActiveTask or None, error_code or None)
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
meta_key = _get_task_meta_key(task_id)
|
||||||
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
|
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||||
|
|
||||||
|
if not meta:
|
||||||
|
# Check if stream still has data (metadata expired but stream hasn't)
|
||||||
|
stream_len = await redis.xlen(stream_key)
|
||||||
|
if stream_len > 0:
|
||||||
|
return None, "TASK_EXPIRED"
|
||||||
|
return None, "TASK_NOT_FOUND"
|
||||||
|
|
||||||
|
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||||
|
return (
|
||||||
|
ActiveTask(
|
||||||
|
task_id=meta.get("task_id", ""),
|
||||||
|
session_id=meta.get("session_id", ""),
|
||||||
|
user_id=meta.get("user_id", "") or None,
|
||||||
|
tool_call_id=meta.get("tool_call_id", ""),
|
||||||
|
tool_name=meta.get("tool_name", ""),
|
||||||
|
operation_id=meta.get("operation_id", ""),
|
||||||
|
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_active_task_for_session(
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> tuple[ActiveTask | None, str]:
|
||||||
|
"""Get the active (running) task for a session, if any.
|
||||||
|
|
||||||
|
Scans Redis for tasks matching the session_id with status="running".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Session ID to look up
|
||||||
|
user_id: User ID for ownership validation (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
|
||||||
|
"""
|
||||||
|
|
||||||
|
redis = await get_redis_async()
|
||||||
|
|
||||||
|
# Scan Redis for task metadata keys
|
||||||
|
cursor = 0
|
||||||
|
tasks_checked = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(
|
||||||
|
cursor, match=f"{config.task_meta_prefix}*", count=100
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
tasks_checked += 1
|
||||||
|
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
|
||||||
|
if not meta:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||||
|
task_session_id = meta.get("session_id", "")
|
||||||
|
task_status = meta.get("status", "")
|
||||||
|
task_user_id = meta.get("user_id", "") or None
|
||||||
|
task_id = meta.get("task_id", "")
|
||||||
|
|
||||||
|
if task_session_id == session_id and task_status == "running":
|
||||||
|
# Validate ownership - if task has an owner, requester must match
|
||||||
|
if task_user_id and user_id != task_user_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get the last message ID from Redis Stream
|
||||||
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
last_id = "0-0"
|
||||||
|
try:
|
||||||
|
messages = await redis.xrevrange(stream_key, count=1)
|
||||||
|
if messages:
|
||||||
|
msg_id = messages[0][0]
|
||||||
|
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get last message ID: {e}")
|
||||||
|
|
||||||
|
return (
|
||||||
|
ActiveTask(
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=task_session_id,
|
||||||
|
user_id=task_user_id,
|
||||||
|
tool_call_id=meta.get("tool_call_id", ""),
|
||||||
|
tool_name=meta.get("tool_name", ""),
|
||||||
|
operation_id=meta.get("operation_id", ""),
|
||||||
|
status="running",
|
||||||
|
),
|
||||||
|
last_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
return None, "0-0"
|
||||||
|
|
||||||
|
|
||||||
|
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||||
|
"""Reconstruct a StreamBaseResponse from JSON data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_data: Parsed JSON data from Redis
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reconstructed response object, or None if unknown type
|
||||||
|
"""
|
||||||
|
from .response_model import (
|
||||||
|
ResponseType,
|
||||||
|
StreamError,
|
||||||
|
StreamFinish,
|
||||||
|
StreamHeartbeat,
|
||||||
|
StreamStart,
|
||||||
|
StreamTextDelta,
|
||||||
|
StreamTextEnd,
|
||||||
|
StreamTextStart,
|
||||||
|
StreamToolInputAvailable,
|
||||||
|
StreamToolInputStart,
|
||||||
|
StreamToolOutputAvailable,
|
||||||
|
StreamUsage,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map response types to their corresponding classes
|
||||||
|
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||||
|
ResponseType.START.value: StreamStart,
|
||||||
|
ResponseType.FINISH.value: StreamFinish,
|
||||||
|
ResponseType.TEXT_START.value: StreamTextStart,
|
||||||
|
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
||||||
|
ResponseType.TEXT_END.value: StreamTextEnd,
|
||||||
|
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
|
||||||
|
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
|
||||||
|
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
|
||||||
|
ResponseType.ERROR.value: StreamError,
|
||||||
|
ResponseType.USAGE.value: StreamUsage,
|
||||||
|
ResponseType.HEARTBEAT.value: StreamHeartbeat,
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk_type = chunk_data.get("type")
|
||||||
|
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
if chunk_class is None:
|
||||||
|
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return chunk_class(**chunk_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
|
||||||
|
"""Track the asyncio.Task for a task (local reference only).
|
||||||
|
|
||||||
|
This is just for cleanup purposes - the task state is in Redis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID
|
||||||
|
asyncio_task: The asyncio Task to track
|
||||||
|
"""
|
||||||
|
_local_tasks[task_id] = asyncio_task
|
||||||
|
|
||||||
|
|
||||||
|
async def unsubscribe_from_task(
|
||||||
|
task_id: str,
|
||||||
|
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||||
|
) -> None:
|
||||||
|
"""Clean up when a subscriber disconnects.
|
||||||
|
|
||||||
|
Cancels the XREAD-based listener task associated with this subscriber queue
|
||||||
|
to prevent resource leaks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task ID
|
||||||
|
subscriber_queue: The subscriber's queue used to look up the listener task
|
||||||
|
"""
|
||||||
|
queue_id = id(subscriber_queue)
|
||||||
|
listener_entry = _listener_tasks.pop(queue_id, None)
|
||||||
|
|
||||||
|
if listener_entry is None:
|
||||||
|
logger.debug(
|
||||||
|
f"No listener task found for task {task_id} queue {queue_id} "
|
||||||
|
"(may have already completed)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
stored_task_id, listener_task = listener_entry
|
||||||
|
|
||||||
|
if stored_task_id != task_id:
|
||||||
|
logger.warning(
|
||||||
|
f"Task ID mismatch in unsubscribe: expected {task_id}, "
|
||||||
|
f"found {stored_task_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if listener_task.done():
|
||||||
|
logger.debug(f"Listener task for task {task_id} already completed")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Cancel the listener task
|
||||||
|
listener_task.cancel()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wait for the task to be cancelled with a timeout
|
||||||
|
await asyncio.wait_for(listener_task, timeout=5.0)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Expected - the task was successfully cancelled
|
||||||
|
pass
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
f"Timeout waiting for listener task cancellation for task {task_id}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
|
||||||
|
|
||||||
|
logger.debug(f"Successfully unsubscribed from task {task_id}")
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
# CoPilot Tools - Future Ideas
|
||||||
|
|
||||||
|
## Multimodal Image Support for CoPilot
|
||||||
|
|
||||||
|
**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).
|
||||||
|
|
||||||
|
**Backend Solution:**
|
||||||
|
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Before sending to LLM, scan for workspace image references
|
||||||
|
# and inject them as image content parts
|
||||||
|
|
||||||
|
# Example message transformation:
|
||||||
|
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
|
||||||
|
# TO: {"role": "assistant", "content": [
|
||||||
|
# {"type": "text", "text": "Generated image: workspace://abc123"},
|
||||||
|
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
|
||||||
|
# ]}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Where to implement:**
|
||||||
|
- In the chat stream handler before calling the LLM
|
||||||
|
- Or in a message preprocessing step
|
||||||
|
- Need to fetch image from workspace, convert to base64, add as image content
|
||||||
|
|
||||||
|
**Considerations:**
|
||||||
|
- Only do this for image MIME types (image/png, image/jpeg, etc.)
|
||||||
|
- May want a size limit (don't pass 10MB images)
|
||||||
|
- Track which images were "shown" to the AI for frontend indicator
|
||||||
|
- Cost implications - vision API calls are more expensive
|
||||||
|
|
||||||
|
**Frontend Solution:**
|
||||||
|
Show visual indicator on workspace files in chat:
|
||||||
|
- If AI saw the image: normal display
|
||||||
|
- If AI didn't see it: overlay icon saying "AI can't see this image"
|
||||||
|
|
||||||
|
Requires response metadata indicating which `workspace://` refs were passed to the model.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Output Post-Processing Layer for run_block
|
||||||
|
|
||||||
|
**Problem:** Many blocks produce large outputs that:
|
||||||
|
- Consume massive context (100KB base64 image = ~133KB tokens)
|
||||||
|
- Can't fit in conversation
|
||||||
|
- Break things and cause high LLM costs
|
||||||
|
|
||||||
|
**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.
|
||||||
|
|
||||||
|
**Benefits:**
|
||||||
|
1. **Centralized** - one place to handle all output processing
|
||||||
|
2. **Future-proof** - new blocks automatically get output processing
|
||||||
|
3. **Keeps blocks pure** - they don't need to know about context constraints
|
||||||
|
4. **Handles all large outputs** - not just images
|
||||||
|
|
||||||
|
**Processing Rules:**
|
||||||
|
- Detect base64 data URIs → save to workspace, return `workspace://` reference
|
||||||
|
- Truncate very long strings (>N chars) with truncation note
|
||||||
|
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
|
||||||
|
- Handle nested large outputs in dicts recursively
|
||||||
|
- Cap total output size
|
||||||
|
|
||||||
|
**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```python
|
||||||
|
def _process_outputs_for_context(
|
||||||
|
outputs: dict[str, list[Any]],
|
||||||
|
workspace_manager: WorkspaceManager,
|
||||||
|
max_string_length: int = 10000,
|
||||||
|
max_array_preview: int = 5,
|
||||||
|
) -> dict[str, list[Any]]:
|
||||||
|
"""Process block outputs to prevent context bloat."""
|
||||||
|
processed = {}
|
||||||
|
for name, values in outputs.items():
|
||||||
|
processed[name] = [_process_value(v, workspace_manager) for v in values]
|
||||||
|
return processed
|
||||||
|
```
|
||||||
@@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool
|
|||||||
from .agent_output import AgentOutputTool
|
from .agent_output import AgentOutputTool
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .create_agent import CreateAgentTool
|
from .create_agent import CreateAgentTool
|
||||||
|
from .customize_agent import CustomizeAgentTool
|
||||||
from .edit_agent import EditAgentTool
|
from .edit_agent import EditAgentTool
|
||||||
from .find_agent import FindAgentTool
|
from .find_agent import FindAgentTool
|
||||||
from .find_block import FindBlockTool
|
from .find_block import FindBlockTool
|
||||||
@@ -18,6 +19,12 @@ from .get_doc_page import GetDocPageTool
|
|||||||
from .run_agent import RunAgentTool
|
from .run_agent import RunAgentTool
|
||||||
from .run_block import RunBlockTool
|
from .run_block import RunBlockTool
|
||||||
from .search_docs import SearchDocsTool
|
from .search_docs import SearchDocsTool
|
||||||
|
from .workspace_files import (
|
||||||
|
DeleteWorkspaceFileTool,
|
||||||
|
ListWorkspaceFilesTool,
|
||||||
|
ReadWorkspaceFileTool,
|
||||||
|
WriteWorkspaceFileTool,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||||
@@ -28,6 +35,7 @@ logger = logging.getLogger(__name__)
|
|||||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||||
"add_understanding": AddUnderstandingTool(),
|
"add_understanding": AddUnderstandingTool(),
|
||||||
"create_agent": CreateAgentTool(),
|
"create_agent": CreateAgentTool(),
|
||||||
|
"customize_agent": CustomizeAgentTool(),
|
||||||
"edit_agent": EditAgentTool(),
|
"edit_agent": EditAgentTool(),
|
||||||
"find_agent": FindAgentTool(),
|
"find_agent": FindAgentTool(),
|
||||||
"find_block": FindBlockTool(),
|
"find_block": FindBlockTool(),
|
||||||
@@ -37,6 +45,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
|||||||
"view_agent_output": AgentOutputTool(),
|
"view_agent_output": AgentOutputTool(),
|
||||||
"search_docs": SearchDocsTool(),
|
"search_docs": SearchDocsTool(),
|
||||||
"get_doc_page": GetDocPageTool(),
|
"get_doc_page": GetDocPageTool(),
|
||||||
|
# Workspace tools for CoPilot file operations
|
||||||
|
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||||
|
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||||
|
"write_workspace_file": WriteWorkspaceFileTool(),
|
||||||
|
"delete_workspace_file": DeleteWorkspaceFileTool(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Export individual tool instances for backwards compatibility
|
# Export individual tool instances for backwards compatibility
|
||||||
|
|||||||
@@ -2,27 +2,58 @@
|
|||||||
|
|
||||||
from .core import (
|
from .core import (
|
||||||
AgentGeneratorNotConfiguredError,
|
AgentGeneratorNotConfiguredError,
|
||||||
|
AgentJsonValidationError,
|
||||||
|
AgentSummary,
|
||||||
|
DecompositionResult,
|
||||||
|
DecompositionStep,
|
||||||
|
LibraryAgentSummary,
|
||||||
|
MarketplaceAgentSummary,
|
||||||
|
customize_template,
|
||||||
decompose_goal,
|
decompose_goal,
|
||||||
|
enrich_library_agents_from_steps,
|
||||||
|
extract_search_terms_from_steps,
|
||||||
|
extract_uuids_from_text,
|
||||||
generate_agent,
|
generate_agent,
|
||||||
generate_agent_patch,
|
generate_agent_patch,
|
||||||
get_agent_as_json,
|
get_agent_as_json,
|
||||||
|
get_all_relevant_agents_for_generation,
|
||||||
|
get_library_agent_by_graph_id,
|
||||||
|
get_library_agent_by_id,
|
||||||
|
get_library_agents_for_generation,
|
||||||
|
graph_to_json,
|
||||||
json_to_graph,
|
json_to_graph,
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
|
search_marketplace_agents_for_generation,
|
||||||
)
|
)
|
||||||
|
from .errors import get_user_message_for_error
|
||||||
from .service import health_check as check_external_service_health
|
from .service import health_check as check_external_service_health
|
||||||
from .service import is_external_service_configured
|
from .service import is_external_service_configured
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Core functions
|
"AgentGeneratorNotConfiguredError",
|
||||||
|
"AgentJsonValidationError",
|
||||||
|
"AgentSummary",
|
||||||
|
"DecompositionResult",
|
||||||
|
"DecompositionStep",
|
||||||
|
"LibraryAgentSummary",
|
||||||
|
"MarketplaceAgentSummary",
|
||||||
|
"check_external_service_health",
|
||||||
|
"customize_template",
|
||||||
"decompose_goal",
|
"decompose_goal",
|
||||||
|
"enrich_library_agents_from_steps",
|
||||||
|
"extract_search_terms_from_steps",
|
||||||
|
"extract_uuids_from_text",
|
||||||
"generate_agent",
|
"generate_agent",
|
||||||
"generate_agent_patch",
|
"generate_agent_patch",
|
||||||
"save_agent_to_library",
|
|
||||||
"get_agent_as_json",
|
"get_agent_as_json",
|
||||||
"json_to_graph",
|
"get_all_relevant_agents_for_generation",
|
||||||
# Exceptions
|
"get_library_agent_by_graph_id",
|
||||||
"AgentGeneratorNotConfiguredError",
|
"get_library_agent_by_id",
|
||||||
# Service
|
"get_library_agents_for_generation",
|
||||||
|
"get_user_message_for_error",
|
||||||
|
"graph_to_json",
|
||||||
"is_external_service_configured",
|
"is_external_service_configured",
|
||||||
"check_external_service_health",
|
"json_to_graph",
|
||||||
|
"save_agent_to_library",
|
||||||
|
"search_marketplace_agents_for_generation",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
"""Core agent generation functions."""
|
"""Core agent generation functions."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any, NotRequired, TypedDict
|
||||||
|
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.data.graph import Graph, Link, Node, create_graph
|
from backend.api.features.store import db as store_db
|
||||||
|
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||||
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .service import (
|
||||||
|
customize_template_external,
|
||||||
decompose_goal_external,
|
decompose_goal_external,
|
||||||
generate_agent_external,
|
generate_agent_external,
|
||||||
generate_agent_patch_external,
|
generate_agent_patch_external,
|
||||||
@@ -17,6 +21,72 @@ from .service import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionSummary(TypedDict):
|
||||||
|
"""Summary of a single execution for quality assessment."""
|
||||||
|
|
||||||
|
status: str
|
||||||
|
correctness_score: NotRequired[float]
|
||||||
|
activity_summary: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class LibraryAgentSummary(TypedDict):
|
||||||
|
"""Summary of a library agent for sub-agent composition.
|
||||||
|
|
||||||
|
Includes recent executions to help the LLM decide whether to use this agent.
|
||||||
|
Each execution shows status, correctness_score (0-1), and activity_summary.
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph_id: str
|
||||||
|
graph_version: int
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
input_schema: dict[str, Any]
|
||||||
|
output_schema: dict[str, Any]
|
||||||
|
recent_executions: NotRequired[list[ExecutionSummary]]
|
||||||
|
|
||||||
|
|
||||||
|
class MarketplaceAgentSummary(TypedDict):
|
||||||
|
"""Summary of a marketplace agent for sub-agent composition."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
sub_heading: str
|
||||||
|
creator: str
|
||||||
|
is_marketplace_agent: bool
|
||||||
|
|
||||||
|
|
||||||
|
class DecompositionStep(TypedDict, total=False):
|
||||||
|
"""A single step in decomposed instructions."""
|
||||||
|
|
||||||
|
description: str
|
||||||
|
action: str
|
||||||
|
block_name: str
|
||||||
|
tool: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class DecompositionResult(TypedDict, total=False):
|
||||||
|
"""Result from decompose_goal - can be instructions, questions, or error."""
|
||||||
|
|
||||||
|
type: str
|
||||||
|
steps: list[DecompositionStep]
|
||||||
|
questions: list[dict[str, Any]]
|
||||||
|
error: str
|
||||||
|
error_type: str
|
||||||
|
|
||||||
|
|
||||||
|
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def _to_dict_list(
|
||||||
|
agents: list[AgentSummary] | list[dict[str, Any]] | None,
|
||||||
|
) -> list[dict[str, Any]] | None:
|
||||||
|
"""Convert typed agent summaries to plain dicts for external service calls."""
|
||||||
|
if agents is None:
|
||||||
|
return None
|
||||||
|
return [dict(a) for a in agents]
|
||||||
|
|
||||||
|
|
||||||
class AgentGeneratorNotConfiguredError(Exception):
|
class AgentGeneratorNotConfiguredError(Exception):
|
||||||
"""Raised when the external Agent Generator service is not configured."""
|
"""Raised when the external Agent Generator service is not configured."""
|
||||||
|
|
||||||
@@ -36,15 +106,422 @@ def _check_service_configured() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
_UUID_PATTERN = re.compile(
|
||||||
|
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_uuids_from_text(text: str) -> list[str]:
|
||||||
|
"""Extract all UUID v4 strings from text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text that may contain UUIDs (e.g., user's goal description)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of unique UUIDs found in the text (lowercase)
|
||||||
|
"""
|
||||||
|
matches = _UUID_PATTERN.findall(text)
|
||||||
|
return list({m.lower() for m in matches})
|
||||||
|
|
||||||
|
|
||||||
|
async def get_library_agent_by_id(
|
||||||
|
user_id: str, agent_id: str
|
||||||
|
) -> LibraryAgentSummary | None:
|
||||||
|
"""Fetch a specific library agent by its ID (library agent ID or graph_id).
|
||||||
|
|
||||||
|
This function tries multiple lookup strategies:
|
||||||
|
1. First tries to find by graph_id (AgentGraph primary key)
|
||||||
|
2. If not found, tries to find by library agent ID (LibraryAgent primary key)
|
||||||
|
|
||||||
|
This handles both cases:
|
||||||
|
- User provides graph_id (e.g., from AgentExecutorBlock)
|
||||||
|
- User provides library agent ID (e.g., from library URL)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LibraryAgentSummary if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||||
|
return LibraryAgentSummary(
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
graph_version=agent.graph_version,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=agent.input_schema,
|
||||||
|
output_schema=agent.output_schema,
|
||||||
|
)
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||||
|
return LibraryAgentSummary(
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
graph_version=agent.graph_version,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=agent.input_schema,
|
||||||
|
output_schema=agent.output_schema,
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
get_library_agent_by_graph_id = get_library_agent_by_id
|
||||||
|
|
||||||
|
|
||||||
|
async def get_library_agents_for_generation(
|
||||||
|
user_id: str,
|
||||||
|
search_query: str | None = None,
|
||||||
|
exclude_graph_id: str | None = None,
|
||||||
|
max_results: int = 15,
|
||||||
|
) -> list[LibraryAgentSummary]:
|
||||||
|
"""Fetch user's library agents formatted for Agent Generator.
|
||||||
|
|
||||||
|
Uses search-based fetching to return relevant agents instead of all agents.
|
||||||
|
This is more scalable for users with large libraries.
|
||||||
|
|
||||||
|
Includes recent_executions list to help the LLM assess agent quality:
|
||||||
|
- Each execution has status, correctness_score (0-1), and activity_summary
|
||||||
|
- This gives the LLM concrete examples of recent performance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
search_query: Optional search term to find relevant agents (user's goal/description)
|
||||||
|
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||||
|
max_results: Maximum number of agents to return (default 15)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await library_db.list_library_agents(
|
||||||
|
user_id=user_id,
|
||||||
|
search_term=search_query,
|
||||||
|
page=1,
|
||||||
|
page_size=max_results,
|
||||||
|
include_executions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[LibraryAgentSummary] = []
|
||||||
|
for agent in response.agents:
|
||||||
|
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
summary = LibraryAgentSummary(
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
graph_version=agent.graph_version,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=agent.input_schema,
|
||||||
|
output_schema=agent.output_schema,
|
||||||
|
)
|
||||||
|
if agent.recent_executions:
|
||||||
|
exec_summaries: list[ExecutionSummary] = []
|
||||||
|
for ex in agent.recent_executions:
|
||||||
|
exec_sum = ExecutionSummary(status=ex.status)
|
||||||
|
if ex.correctness_score is not None:
|
||||||
|
exec_sum["correctness_score"] = ex.correctness_score
|
||||||
|
if ex.activity_summary:
|
||||||
|
exec_sum["activity_summary"] = ex.activity_summary
|
||||||
|
exec_summaries.append(exec_sum)
|
||||||
|
summary["recent_executions"] = exec_summaries
|
||||||
|
results.append(summary)
|
||||||
|
return results
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch library agents: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def search_marketplace_agents_for_generation(
|
||||||
|
search_query: str,
|
||||||
|
max_results: int = 10,
|
||||||
|
) -> list[LibraryAgentSummary]:
|
||||||
|
"""Search marketplace agents formatted for Agent Generator.
|
||||||
|
|
||||||
|
Fetches marketplace agents and their full schemas so they can be used
|
||||||
|
as sub-agents in generated workflows.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_query: Search term to find relevant public agents
|
||||||
|
max_results: Maximum number of agents to return (default 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LibraryAgentSummary with full input/output schemas
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await store_db.get_store_agents(
|
||||||
|
search_query=search_query,
|
||||||
|
page=1,
|
||||||
|
page_size=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
agents_with_graphs = [
|
||||||
|
agent for agent in response.agents if agent.agent_graph_id
|
||||||
|
]
|
||||||
|
|
||||||
|
if not agents_with_graphs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
|
||||||
|
graphs = await get_store_listed_graphs(*graph_ids)
|
||||||
|
|
||||||
|
results: list[LibraryAgentSummary] = []
|
||||||
|
for agent in agents_with_graphs:
|
||||||
|
graph_id = agent.agent_graph_id
|
||||||
|
if graph_id and graph_id in graphs:
|
||||||
|
graph = graphs[graph_id]
|
||||||
|
results.append(
|
||||||
|
LibraryAgentSummary(
|
||||||
|
graph_id=graph.id,
|
||||||
|
graph_version=graph.version,
|
||||||
|
name=agent.agent_name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=graph.input_schema,
|
||||||
|
output_schema=graph.output_schema,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to search marketplace agents: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all_relevant_agents_for_generation(
|
||||||
|
user_id: str,
|
||||||
|
search_query: str | None = None,
|
||||||
|
exclude_graph_id: str | None = None,
|
||||||
|
include_library: bool = True,
|
||||||
|
include_marketplace: bool = True,
|
||||||
|
max_library_results: int = 15,
|
||||||
|
max_marketplace_results: int = 10,
|
||||||
|
) -> list[AgentSummary]:
|
||||||
|
"""Fetch relevant agents from library and/or marketplace.
|
||||||
|
|
||||||
|
Searches both user's library and marketplace by default.
|
||||||
|
Explicitly mentioned UUIDs in the search query are always looked up.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
search_query: Search term to find relevant agents (user's goal/description)
|
||||||
|
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||||
|
include_library: Whether to search user's library (default True)
|
||||||
|
include_marketplace: Whether to also search marketplace (default True)
|
||||||
|
max_library_results: Max library agents to return (default 15)
|
||||||
|
max_marketplace_results: Max marketplace agents to return (default 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of AgentSummary with full schemas (both library and marketplace agents)
|
||||||
|
"""
|
||||||
|
agents: list[AgentSummary] = []
|
||||||
|
seen_graph_ids: set[str] = set()
|
||||||
|
|
||||||
|
if search_query:
|
||||||
|
mentioned_uuids = extract_uuids_from_text(search_query)
|
||||||
|
for graph_id in mentioned_uuids:
|
||||||
|
if graph_id == exclude_graph_id:
|
||||||
|
continue
|
||||||
|
agent = await get_library_agent_by_graph_id(user_id, graph_id)
|
||||||
|
agent_graph_id = agent.get("graph_id") if agent else None
|
||||||
|
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
|
||||||
|
agents.append(agent)
|
||||||
|
seen_graph_ids.add(agent_graph_id)
|
||||||
|
logger.debug(
|
||||||
|
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_library:
|
||||||
|
library_agents = await get_library_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=search_query,
|
||||||
|
exclude_graph_id=exclude_graph_id,
|
||||||
|
max_results=max_library_results,
|
||||||
|
)
|
||||||
|
for agent in library_agents:
|
||||||
|
graph_id = agent.get("graph_id")
|
||||||
|
if graph_id and graph_id not in seen_graph_ids:
|
||||||
|
agents.append(agent)
|
||||||
|
seen_graph_ids.add(graph_id)
|
||||||
|
|
||||||
|
if include_marketplace and search_query:
|
||||||
|
marketplace_agents = await search_marketplace_agents_for_generation(
|
||||||
|
search_query=search_query,
|
||||||
|
max_results=max_marketplace_results,
|
||||||
|
)
|
||||||
|
for agent in marketplace_agents:
|
||||||
|
graph_id = agent.get("graph_id")
|
||||||
|
if graph_id and graph_id not in seen_graph_ids:
|
||||||
|
agents.append(agent)
|
||||||
|
seen_graph_ids.add(graph_id)
|
||||||
|
|
||||||
|
return agents
|
||||||
|
|
||||||
|
|
||||||
|
def extract_search_terms_from_steps(
|
||||||
|
decomposition_result: DecompositionResult | dict[str, Any],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Extract search terms from decomposed instruction steps.
|
||||||
|
|
||||||
|
Analyzes the decomposition result to extract relevant keywords
|
||||||
|
for additional library agent searches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decomposition_result: Result from decompose_goal containing steps
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of unique search terms extracted from steps
|
||||||
|
"""
|
||||||
|
search_terms: list[str] = []
|
||||||
|
|
||||||
|
if decomposition_result.get("type") != "instructions":
|
||||||
|
return search_terms
|
||||||
|
|
||||||
|
steps = decomposition_result.get("steps", [])
|
||||||
|
if not steps:
|
||||||
|
return search_terms
|
||||||
|
|
||||||
|
step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
|
||||||
|
|
||||||
|
for step in steps:
|
||||||
|
for key in step_keys:
|
||||||
|
value = step.get(key) # type: ignore[union-attr]
|
||||||
|
if isinstance(value, str) and len(value) > 3:
|
||||||
|
search_terms.append(value)
|
||||||
|
|
||||||
|
seen: set[str] = set()
|
||||||
|
unique_terms: list[str] = []
|
||||||
|
for term in search_terms:
|
||||||
|
term_lower = term.lower()
|
||||||
|
if term_lower not in seen:
|
||||||
|
seen.add(term_lower)
|
||||||
|
unique_terms.append(term)
|
||||||
|
|
||||||
|
return unique_terms
|
||||||
|
|
||||||
|
|
||||||
|
async def enrich_library_agents_from_steps(
|
||||||
|
user_id: str,
|
||||||
|
decomposition_result: DecompositionResult | dict[str, Any],
|
||||||
|
existing_agents: list[AgentSummary] | list[dict[str, Any]],
|
||||||
|
exclude_graph_id: str | None = None,
|
||||||
|
include_marketplace: bool = True,
|
||||||
|
max_additional_results: int = 10,
|
||||||
|
) -> list[AgentSummary] | list[dict[str, Any]]:
|
||||||
|
"""Enrich library agents list with additional searches based on decomposed steps.
|
||||||
|
|
||||||
|
This implements two-phase search: after decomposition, we search for additional
|
||||||
|
relevant agents based on the specific steps identified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
decomposition_result: Result from decompose_goal containing steps
|
||||||
|
existing_agents: Already fetched library agents from initial search
|
||||||
|
exclude_graph_id: Optional graph ID to exclude
|
||||||
|
include_marketplace: Whether to also search marketplace
|
||||||
|
max_additional_results: Max additional agents per search term (default 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined list of library agents (existing + newly discovered)
|
||||||
|
"""
|
||||||
|
search_terms = extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
if not search_terms:
|
||||||
|
return existing_agents
|
||||||
|
|
||||||
|
existing_ids: set[str] = set()
|
||||||
|
existing_names: set[str] = set()
|
||||||
|
|
||||||
|
for agent in existing_agents:
|
||||||
|
agent_name = agent.get("name")
|
||||||
|
if agent_name and isinstance(agent_name, str):
|
||||||
|
existing_names.add(agent_name.lower())
|
||||||
|
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||||
|
if graph_id and isinstance(graph_id, str):
|
||||||
|
existing_ids.add(graph_id)
|
||||||
|
|
||||||
|
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
|
||||||
|
|
||||||
|
for term in search_terms[:3]:
|
||||||
|
try:
|
||||||
|
additional_agents = await get_all_relevant_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=term,
|
||||||
|
exclude_graph_id=exclude_graph_id,
|
||||||
|
include_marketplace=include_marketplace,
|
||||||
|
max_library_results=max_additional_results,
|
||||||
|
max_marketplace_results=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
for agent in additional_agents:
|
||||||
|
agent_name = agent.get("name")
|
||||||
|
if not agent_name or not isinstance(agent_name, str):
|
||||||
|
continue
|
||||||
|
agent_name_lower = agent_name.lower()
|
||||||
|
|
||||||
|
if agent_name_lower in existing_names:
|
||||||
|
continue
|
||||||
|
|
||||||
|
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||||
|
if graph_id and graph_id in existing_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
all_agents.append(agent)
|
||||||
|
existing_names.add(agent_name_lower)
|
||||||
|
if graph_id and isinstance(graph_id, str):
|
||||||
|
existing_ids.add(graph_id)
|
||||||
|
|
||||||
|
except DatabaseError:
|
||||||
|
logger.error(f"Database error searching for agents with term '{term}'")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to search for additional agents with term '{term}': {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Enriched library agents: {len(existing_agents)} initial + "
|
||||||
|
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_agents
|
||||||
|
|
||||||
|
|
||||||
|
async def decompose_goal(
|
||||||
|
description: str,
|
||||||
|
context: str = "",
|
||||||
|
library_agents: list[AgentSummary] | None = None,
|
||||||
|
) -> DecompositionResult | None:
|
||||||
"""Break down a goal into steps or return clarifying questions.
|
"""Break down a goal into steps or return clarifying questions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Natural language goal description
|
description: Natural language goal description
|
||||||
context: Additional context (e.g., answers to previous questions)
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with either:
|
DecompositionResult with either:
|
||||||
- {"type": "clarifying_questions", "questions": [...]}
|
- {"type": "clarifying_questions", "questions": [...]}
|
||||||
- {"type": "instructions", "steps": [...]}
|
- {"type": "instructions", "steps": [...]}
|
||||||
Or None on error
|
Or None on error
|
||||||
@@ -54,26 +531,47 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
|
|||||||
"""
|
"""
|
||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||||
return await decompose_goal_external(description, context)
|
result = await decompose_goal_external(
|
||||||
|
description, context, _to_dict_list(library_agents)
|
||||||
|
)
|
||||||
|
return result # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
async def generate_agent(
|
||||||
|
instructions: DecompositionResult | dict[str, Any],
|
||||||
|
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
"""Generate agent JSON from instructions.
|
"""Generate agent JSON from instructions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: Structured instructions from decompose_goal
|
instructions: Structured instructions from decompose_goal
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
operation_id: Operation ID for async processing (enables Redis Streams
|
||||||
|
completion notification)
|
||||||
|
task_id: Task ID for async processing (enables Redis Streams persistence
|
||||||
|
and SSE delivery)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict or None on error
|
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for generate_agent")
|
logger.info("Calling external Agent Generator service for generate_agent")
|
||||||
result = await generate_agent_external(instructions)
|
result = await generate_agent_external(
|
||||||
|
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Don't modify async response
|
||||||
|
if result and result.get("status") == "accepted":
|
||||||
|
return result
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
# Ensure required fields
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
|
return result
|
||||||
if "id" not in result:
|
if "id" not in result:
|
||||||
result["id"] = str(uuid.uuid4())
|
result["id"] = str(uuid.uuid4())
|
||||||
if "version" not in result:
|
if "version" not in result:
|
||||||
@@ -83,6 +581,12 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class AgentJsonValidationError(Exception):
|
||||||
|
"""Raised when agent JSON is invalid or missing required fields."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||||
"""Convert agent JSON dict to Graph model.
|
"""Convert agent JSON dict to Graph model.
|
||||||
|
|
||||||
@@ -91,25 +595,55 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Graph ready for saving
|
Graph ready for saving
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentJsonValidationError: If required fields are missing from nodes or links
|
||||||
"""
|
"""
|
||||||
nodes = []
|
nodes = []
|
||||||
for n in agent_json.get("nodes", []):
|
for idx, n in enumerate(agent_json.get("nodes", [])):
|
||||||
|
block_id = n.get("block_id")
|
||||||
|
if not block_id:
|
||||||
|
node_id = n.get("id", f"index_{idx}")
|
||||||
|
raise AgentJsonValidationError(
|
||||||
|
f"Node '{node_id}' is missing required field 'block_id'"
|
||||||
|
)
|
||||||
node = Node(
|
node = Node(
|
||||||
id=n.get("id", str(uuid.uuid4())),
|
id=n.get("id", str(uuid.uuid4())),
|
||||||
block_id=n["block_id"],
|
block_id=block_id,
|
||||||
input_default=n.get("input_default", {}),
|
input_default=n.get("input_default", {}),
|
||||||
metadata=n.get("metadata", {}),
|
metadata=n.get("metadata", {}),
|
||||||
)
|
)
|
||||||
nodes.append(node)
|
nodes.append(node)
|
||||||
|
|
||||||
links = []
|
links = []
|
||||||
for link_data in agent_json.get("links", []):
|
for idx, link_data in enumerate(agent_json.get("links", [])):
|
||||||
|
source_id = link_data.get("source_id")
|
||||||
|
sink_id = link_data.get("sink_id")
|
||||||
|
source_name = link_data.get("source_name")
|
||||||
|
sink_name = link_data.get("sink_name")
|
||||||
|
|
||||||
|
missing_fields = []
|
||||||
|
if not source_id:
|
||||||
|
missing_fields.append("source_id")
|
||||||
|
if not sink_id:
|
||||||
|
missing_fields.append("sink_id")
|
||||||
|
if not source_name:
|
||||||
|
missing_fields.append("source_name")
|
||||||
|
if not sink_name:
|
||||||
|
missing_fields.append("sink_name")
|
||||||
|
|
||||||
|
if missing_fields:
|
||||||
|
link_id = link_data.get("id", f"index_{idx}")
|
||||||
|
raise AgentJsonValidationError(
|
||||||
|
f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
|
||||||
|
)
|
||||||
|
|
||||||
link = Link(
|
link = Link(
|
||||||
id=link_data.get("id", str(uuid.uuid4())),
|
id=link_data.get("id", str(uuid.uuid4())),
|
||||||
source_id=link_data["source_id"],
|
source_id=source_id,
|
||||||
sink_id=link_data["sink_id"],
|
sink_id=sink_id,
|
||||||
source_name=link_data["source_name"],
|
source_name=source_name,
|
||||||
sink_name=link_data["sink_name"],
|
sink_name=sink_name,
|
||||||
is_static=link_data.get("is_static", False),
|
is_static=link_data.get("is_static", False),
|
||||||
)
|
)
|
||||||
links.append(link)
|
links.append(link)
|
||||||
@@ -125,27 +659,6 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _reassign_node_ids(graph: Graph) -> None:
|
|
||||||
"""Reassign all node and link IDs to new UUIDs.
|
|
||||||
|
|
||||||
This is needed when creating a new version to avoid unique constraint violations.
|
|
||||||
"""
|
|
||||||
# Create mapping from old node IDs to new UUIDs
|
|
||||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
|
||||||
|
|
||||||
# Reassign node IDs
|
|
||||||
for node in graph.nodes:
|
|
||||||
node.id = id_map[node.id]
|
|
||||||
|
|
||||||
# Update link references to use new node IDs
|
|
||||||
for link in graph.links:
|
|
||||||
link.id = str(uuid.uuid4()) # Also give links new IDs
|
|
||||||
if link.source_id in id_map:
|
|
||||||
link.source_id = id_map[link.source_id]
|
|
||||||
if link.sink_id in id_map:
|
|
||||||
link.sink_id = id_map[link.sink_id]
|
|
||||||
|
|
||||||
|
|
||||||
async def save_agent_to_library(
|
async def save_agent_to_library(
|
||||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||||
) -> tuple[Graph, Any]:
|
) -> tuple[Graph, Any]:
|
||||||
@@ -159,63 +672,21 @@ async def save_agent_to_library(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (created Graph, LibraryAgent)
|
Tuple of (created Graph, LibraryAgent)
|
||||||
"""
|
"""
|
||||||
from backend.data.graph import get_graph_all_versions
|
|
||||||
|
|
||||||
graph = json_to_graph(agent_json)
|
graph = json_to_graph(agent_json)
|
||||||
|
|
||||||
if is_update:
|
if is_update:
|
||||||
# For updates, keep the same graph ID but increment version
|
return await library_db.update_graph_in_library(graph, user_id)
|
||||||
# and reassign node/link IDs to avoid conflicts
|
return await library_db.create_graph_in_library(graph, user_id)
|
||||||
if graph.id:
|
|
||||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
|
||||||
if existing_versions:
|
|
||||||
latest_version = max(v.version for v in existing_versions)
|
|
||||||
graph.version = latest_version + 1
|
|
||||||
# Reassign node IDs (but keep graph ID the same)
|
|
||||||
_reassign_node_ids(graph)
|
|
||||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
|
||||||
else:
|
|
||||||
# For new agents, always generate a fresh UUID to avoid collisions
|
|
||||||
graph.id = str(uuid.uuid4())
|
|
||||||
graph.version = 1
|
|
||||||
# Reassign all node IDs as well
|
|
||||||
_reassign_node_ids(graph)
|
|
||||||
logger.info(f"Creating new agent with ID {graph.id}")
|
|
||||||
|
|
||||||
# Save to database
|
|
||||||
created_graph = await create_graph(graph, user_id)
|
|
||||||
|
|
||||||
# Add to user's library (or update existing library agent)
|
|
||||||
library_agents = await library_db.create_library_agent(
|
|
||||||
graph=created_graph,
|
|
||||||
user_id=user_id,
|
|
||||||
sensitive_action_safe_mode=True,
|
|
||||||
create_library_agents_for_sub_graphs=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return created_graph, library_agents[0]
|
|
||||||
|
|
||||||
|
|
||||||
async def get_agent_as_json(
|
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||||
graph_id: str, user_id: str | None
|
"""Convert a Graph object to JSON format for the agent generator.
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Fetch an agent and convert to JSON format for editing.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: Graph ID or library agent ID
|
graph: Graph object to convert
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent as JSON dict or None if not found
|
Agent as JSON dict
|
||||||
"""
|
"""
|
||||||
from backend.data.graph import get_graph
|
|
||||||
|
|
||||||
# Try to get the graph (version=None gets the active version)
|
|
||||||
graph = await get_graph(graph_id, version=None, user_id=user_id)
|
|
||||||
if not graph:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Convert to JSON format
|
|
||||||
nodes = []
|
nodes = []
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
nodes.append(
|
nodes.append(
|
||||||
@@ -252,8 +723,41 @@ async def get_agent_as_json(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_agent_as_json(
|
||||||
|
agent_id: str, user_id: str | None
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Fetch an agent and convert to JSON format for editing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Graph ID or library agent ID
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent as JSON dict or None if not found
|
||||||
|
"""
|
||||||
|
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||||
|
|
||||||
|
if not graph and user_id:
|
||||||
|
try:
|
||||||
|
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
|
graph = await get_graph(
|
||||||
|
library_agent.graph_id, version=None, user_id=user_id
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not graph:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return graph_to_json(graph)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_patch(
|
async def generate_agent_patch(
|
||||||
update_request: str, current_agent: dict[str, Any]
|
update_request: str,
|
||||||
|
current_agent: dict[str, Any],
|
||||||
|
library_agents: list[AgentSummary] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Update an existing agent using natural language.
|
"""Update an existing agent using natural language.
|
||||||
|
|
||||||
@@ -265,13 +769,57 @@ async def generate_agent_patch(
|
|||||||
Args:
|
Args:
|
||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||||
|
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated agent JSON, clarifying questions dict, or None on error
|
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||||
|
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
"""
|
"""
|
||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||||
return await generate_agent_patch_external(update_request, current_agent)
|
return await generate_agent_patch_external(
|
||||||
|
update_request,
|
||||||
|
current_agent,
|
||||||
|
_to_dict_list(library_agents),
|
||||||
|
operation_id,
|
||||||
|
task_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def customize_template(
|
||||||
|
template_agent: dict[str, Any],
|
||||||
|
modification_request: str,
|
||||||
|
context: str = "",
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Customize a template/marketplace agent using natural language.
|
||||||
|
|
||||||
|
This is used when users want to modify a template or marketplace agent
|
||||||
|
to fit their specific needs before adding it to their library.
|
||||||
|
|
||||||
|
The external Agent Generator service handles:
|
||||||
|
- Understanding the modification request
|
||||||
|
- Applying changes to the template
|
||||||
|
- Fixing and validating the result
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_agent: The template agent JSON to customize
|
||||||
|
modification_request: Natural language description of customizations
|
||||||
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||||
|
error dict {"type": "error", ...}, or None on unexpected error
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
|
"""
|
||||||
|
_check_service_configured()
|
||||||
|
logger.info("Calling external Agent Generator service for customize_template")
|
||||||
|
return await customize_template_external(
|
||||||
|
template_agent, modification_request, context
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,95 @@
|
|||||||
|
"""Error handling utilities for agent generator."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_error_details(details: str) -> str:
|
||||||
|
"""Sanitize error details to remove sensitive information.
|
||||||
|
|
||||||
|
Strips common patterns that could expose internal system info:
|
||||||
|
- File paths (Unix and Windows)
|
||||||
|
- Database connection strings
|
||||||
|
- URLs with credentials
|
||||||
|
- Stack trace internals
|
||||||
|
|
||||||
|
Args:
|
||||||
|
details: Raw error details string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized error details safe for user display
|
||||||
|
"""
|
||||||
|
sanitized = re.sub(
|
||||||
|
r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
|
||||||
|
)
|
||||||
|
sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
|
||||||
|
sanitized = re.sub(
|
||||||
|
r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
|
||||||
|
)
|
||||||
|
sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
|
||||||
|
sanitized = re.sub(r", line \d+", "", sanitized)
|
||||||
|
sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
|
||||||
|
|
||||||
|
return sanitized.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_message_for_error(
|
||||||
|
error_type: str,
|
||||||
|
operation: str = "process the request",
|
||||||
|
llm_parse_message: str | None = None,
|
||||||
|
validation_message: str | None = None,
|
||||||
|
error_details: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Get a user-friendly error message based on error type.
|
||||||
|
|
||||||
|
This function maps internal error types to user-friendly messages,
|
||||||
|
providing a consistent experience across different agent operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_type: The error type from the external service
|
||||||
|
(e.g., "llm_parse_error", "timeout", "rate_limit")
|
||||||
|
operation: Description of what operation failed, used in the default
|
||||||
|
message (e.g., "analyze the goal", "generate the agent")
|
||||||
|
llm_parse_message: Custom message for llm_parse_error type
|
||||||
|
validation_message: Custom message for validation_error type
|
||||||
|
error_details: Optional additional details about the error
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User-friendly error message suitable for display to the user
|
||||||
|
"""
|
||||||
|
base_message = ""
|
||||||
|
|
||||||
|
if error_type == "llm_parse_error":
|
||||||
|
base_message = (
|
||||||
|
llm_parse_message
|
||||||
|
or "The AI had trouble processing this request. Please try again."
|
||||||
|
)
|
||||||
|
elif error_type == "validation_error":
|
||||||
|
base_message = (
|
||||||
|
validation_message
|
||||||
|
or "The generated agent failed validation. "
|
||||||
|
"This usually happens when the agent structure doesn't match "
|
||||||
|
"what the platform expects. Please try simplifying your goal "
|
||||||
|
"or breaking it into smaller parts."
|
||||||
|
)
|
||||||
|
elif error_type == "patch_error":
|
||||||
|
base_message = (
|
||||||
|
"Failed to apply the changes. The modification couldn't be "
|
||||||
|
"validated. Please try a different approach or simplify the change."
|
||||||
|
)
|
||||||
|
elif error_type in ("timeout", "llm_timeout"):
|
||||||
|
base_message = (
|
||||||
|
"The request took too long to process. This can happen with "
|
||||||
|
"complex agents. Please try again or simplify your goal."
|
||||||
|
)
|
||||||
|
elif error_type in ("rate_limit", "llm_rate_limit"):
|
||||||
|
base_message = "The service is currently busy. Please try again in a moment."
|
||||||
|
else:
|
||||||
|
base_message = f"Failed to {operation}. Please try again."
|
||||||
|
|
||||||
|
if error_details:
|
||||||
|
details = _sanitize_error_details(error_details)
|
||||||
|
if len(details) > 200:
|
||||||
|
details = details[:200] + "..."
|
||||||
|
base_message += f"\n\nTechnical details: {details}"
|
||||||
|
|
||||||
|
return base_message
|
||||||
@@ -14,6 +14,70 @@ from backend.util.settings import Settings
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_error_response(
|
||||||
|
error_message: str,
|
||||||
|
error_type: str = "unknown",
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create a standardized error response dict.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_message: Human-readable error message
|
||||||
|
error_type: Machine-readable error type
|
||||||
|
details: Optional additional error details
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Error dict with type="error" and error details
|
||||||
|
"""
|
||||||
|
response: dict[str, Any] = {
|
||||||
|
"type": "error",
|
||||||
|
"error": error_message,
|
||||||
|
"error_type": error_type,
|
||||||
|
}
|
||||||
|
if details:
|
||||||
|
response["details"] = details
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||||
|
"""Classify an HTTP error into error_type and message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
e: The HTTP status error
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (error_type, error_message)
|
||||||
|
"""
|
||||||
|
status = e.response.status_code
|
||||||
|
if status == 429:
|
||||||
|
return "rate_limit", f"Agent Generator rate limited: {e}"
|
||||||
|
elif status == 503:
|
||||||
|
return "service_unavailable", f"Agent Generator unavailable: {e}"
|
||||||
|
elif status == 504 or status == 408:
|
||||||
|
return "timeout", f"Agent Generator timed out: {e}"
|
||||||
|
else:
|
||||||
|
return "http_error", f"HTTP error calling Agent Generator: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||||
|
"""Classify a request error into error_type and message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
e: The request error
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (error_type, error_message)
|
||||||
|
"""
|
||||||
|
error_str = str(e).lower()
|
||||||
|
if "timeout" in error_str or "timed out" in error_str:
|
||||||
|
return "timeout", f"Agent Generator request timed out: {e}"
|
||||||
|
elif "connect" in error_str:
|
||||||
|
return "connection_error", f"Could not connect to Agent Generator: {e}"
|
||||||
|
else:
|
||||||
|
return "request_error", f"Request error calling Agent Generator: {e}"
|
||||||
|
|
||||||
|
|
||||||
_client: httpx.AsyncClient | None = None
|
_client: httpx.AsyncClient | None = None
|
||||||
_settings: Settings | None = None
|
_settings: Settings | None = None
|
||||||
|
|
||||||
@@ -53,13 +117,16 @@ def _get_client() -> httpx.AsyncClient:
|
|||||||
|
|
||||||
|
|
||||||
async def decompose_goal_external(
|
async def decompose_goal_external(
|
||||||
description: str, context: str = ""
|
description: str,
|
||||||
|
context: str = "",
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to decompose a goal.
|
"""Call the external service to decompose a goal.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Natural language goal description
|
description: Natural language goal description
|
||||||
context: Additional context (e.g., answers to previous questions)
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with either:
|
Dict with either:
|
||||||
@@ -67,15 +134,17 @@ async def decompose_goal_external(
|
|||||||
- {"type": "instructions", "steps": [...]}
|
- {"type": "instructions", "steps": [...]}
|
||||||
- {"type": "unachievable_goal", ...}
|
- {"type": "unachievable_goal", ...}
|
||||||
- {"type": "vague_goal", ...}
|
- {"type": "vague_goal", ...}
|
||||||
Or None on error
|
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||||
|
Or None on unexpected error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
# Build the request payload
|
|
||||||
payload: dict[str, Any] = {"description": description}
|
|
||||||
if context:
|
if context:
|
||||||
# The external service uses user_instruction for additional context
|
description = f"{description}\n\nAdditional context from user:\n{context}"
|
||||||
payload["user_instruction"] = context
|
|
||||||
|
payload: dict[str, Any] = {"description": description}
|
||||||
|
if library_agents:
|
||||||
|
payload["library_agents"] = library_agents
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post("/api/decompose-description", json=payload)
|
response = await client.post("/api/decompose-description", json=payload)
|
||||||
@@ -83,8 +152,13 @@ async def decompose_goal_external(
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
if not data.get("success"):
|
if not data.get("success"):
|
||||||
logger.error(f"External service returned error: {data.get('error')}")
|
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||||
return None
|
error_type = data.get("error_type", "unknown")
|
||||||
|
logger.error(
|
||||||
|
f"Agent Generator decomposition failed: {error_msg} "
|
||||||
|
f"(type: {error_type})"
|
||||||
|
)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
|
||||||
# Map the response to the expected format
|
# Map the response to the expected format
|
||||||
response_type = data.get("type")
|
response_type = data.get("type")
|
||||||
@@ -106,88 +180,162 @@ async def decompose_goal_external(
|
|||||||
"type": "vague_goal",
|
"type": "vague_goal",
|
||||||
"suggested_goal": data.get("suggested_goal"),
|
"suggested_goal": data.get("suggested_goal"),
|
||||||
}
|
}
|
||||||
|
elif response_type == "error":
|
||||||
|
# Pass through error from the service
|
||||||
|
return _create_error_response(
|
||||||
|
data.get("error", "Unknown error"),
|
||||||
|
data.get("error_type", "unknown"),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Unknown response type from external service: {response_type}"
|
f"Unknown response type from external service: {response_type}"
|
||||||
)
|
)
|
||||||
return None
|
return _create_error_response(
|
||||||
|
f"Unknown response type from Agent Generator: {response_type}",
|
||||||
|
"invalid_response",
|
||||||
|
)
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
error_type, error_msg = _classify_http_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"Request error calling external agent generator: {e}")
|
error_type, error_msg = _classify_request_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_external(
|
async def generate_agent_external(
|
||||||
instructions: dict[str, Any]
|
instructions: dict[str, Any],
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to generate an agent from instructions.
|
"""Call the external service to generate an agent from instructions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: Structured instructions from decompose_goal
|
instructions: Structured instructions from decompose_goal
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||||
|
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict or None on error
|
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
|
# Build request payload
|
||||||
|
payload: dict[str, Any] = {"instructions": instructions}
|
||||||
|
if library_agents:
|
||||||
|
payload["library_agents"] = library_agents
|
||||||
|
if operation_id and task_id:
|
||||||
|
payload["operation_id"] = operation_id
|
||||||
|
payload["task_id"] = task_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post(
|
response = await client.post("/api/generate-agent", json=payload)
|
||||||
"/api/generate-agent", json={"instructions": instructions}
|
|
||||||
|
# Handle 202 Accepted for async processing
|
||||||
|
if response.status_code == 202:
|
||||||
|
logger.info(
|
||||||
|
f"Agent Generator accepted async request "
|
||||||
|
f"(operation_id={operation_id}, task_id={task_id})"
|
||||||
)
|
)
|
||||||
|
return {
|
||||||
|
"status": "accepted",
|
||||||
|
"operation_id": operation_id,
|
||||||
|
"task_id": task_id,
|
||||||
|
}
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
if not data.get("success"):
|
if not data.get("success"):
|
||||||
logger.error(f"External service returned error: {data.get('error')}")
|
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||||
return None
|
error_type = data.get("error_type", "unknown")
|
||||||
|
logger.error(
|
||||||
|
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
||||||
|
)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
|
||||||
return data.get("agent_json")
|
return data.get("agent_json")
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
error_type, error_msg = _classify_http_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"Request error calling external agent generator: {e}")
|
error_type, error_msg = _classify_request_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_patch_external(
|
async def generate_agent_patch_external(
|
||||||
update_request: str, current_agent: dict[str, Any]
|
update_request: str,
|
||||||
|
current_agent: dict[str, Any],
|
||||||
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
|
operation_id: str | None = None,
|
||||||
|
task_id: str | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to generate a patch for an existing agent.
|
"""Call the external service to generate a patch for an existing agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||||
|
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated agent JSON, clarifying questions dict, or None on error
|
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
try:
|
# Build request payload
|
||||||
response = await client.post(
|
payload: dict[str, Any] = {
|
||||||
"/api/update-agent",
|
|
||||||
json={
|
|
||||||
"update_request": update_request,
|
"update_request": update_request,
|
||||||
"current_agent_json": current_agent,
|
"current_agent_json": current_agent,
|
||||||
},
|
}
|
||||||
|
if library_agents:
|
||||||
|
payload["library_agents"] = library_agents
|
||||||
|
if operation_id and task_id:
|
||||||
|
payload["operation_id"] = operation_id
|
||||||
|
payload["task_id"] = task_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post("/api/update-agent", json=payload)
|
||||||
|
|
||||||
|
# Handle 202 Accepted for async processing
|
||||||
|
if response.status_code == 202:
|
||||||
|
logger.info(
|
||||||
|
f"Agent Generator accepted async update request "
|
||||||
|
f"(operation_id={operation_id}, task_id={task_id})"
|
||||||
)
|
)
|
||||||
|
return {
|
||||||
|
"status": "accepted",
|
||||||
|
"operation_id": operation_id,
|
||||||
|
"task_id": task_id,
|
||||||
|
}
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
if not data.get("success"):
|
if not data.get("success"):
|
||||||
logger.error(f"External service returned error: {data.get('error')}")
|
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||||
return None
|
error_type = data.get("error_type", "unknown")
|
||||||
|
logger.error(
|
||||||
|
f"Agent Generator patch generation failed: {error_msg} "
|
||||||
|
f"(type: {error_type})"
|
||||||
|
)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
|
||||||
# Check if it's clarifying questions
|
# Check if it's clarifying questions
|
||||||
if data.get("type") == "clarifying_questions":
|
if data.get("type") == "clarifying_questions":
|
||||||
@@ -196,18 +344,99 @@ async def generate_agent_patch_external(
|
|||||||
"questions": data.get("questions", []),
|
"questions": data.get("questions", []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Check if it's an error passed through
|
||||||
|
if data.get("type") == "error":
|
||||||
|
return _create_error_response(
|
||||||
|
data.get("error", "Unknown error"),
|
||||||
|
data.get("error_type", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
# Otherwise return the updated agent JSON
|
# Otherwise return the updated agent JSON
|
||||||
return data.get("agent_json")
|
return data.get("agent_json")
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
error_type, error_msg = _classify_http_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"Request error calling external agent generator: {e}")
|
error_type, error_msg = _classify_request_error(e)
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||||
return None
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
|
async def customize_template_external(
|
||||||
|
template_agent: dict[str, Any],
|
||||||
|
modification_request: str,
|
||||||
|
context: str = "",
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Call the external service to customize a template/marketplace agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_agent: The template agent JSON to customize
|
||||||
|
modification_request: Natural language description of customizations
|
||||||
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||||
|
"""
|
||||||
|
client = _get_client()
|
||||||
|
|
||||||
|
request = modification_request
|
||||||
|
if context:
|
||||||
|
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"template_agent_json": template_agent,
|
||||||
|
"modification_request": request,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post("/api/template-modification", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if not data.get("success"):
|
||||||
|
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||||
|
error_type = data.get("error_type", "unknown")
|
||||||
|
logger.error(
|
||||||
|
f"Agent Generator template customization failed: {error_msg} "
|
||||||
|
f"(type: {error_type})"
|
||||||
|
)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
|
||||||
|
# Check if it's clarifying questions
|
||||||
|
if data.get("type") == "clarifying_questions":
|
||||||
|
return {
|
||||||
|
"type": "clarifying_questions",
|
||||||
|
"questions": data.get("questions", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if it's an error passed through
|
||||||
|
if data.get("type") == "error":
|
||||||
|
return _create_error_response(
|
||||||
|
data.get("error", "Unknown error"),
|
||||||
|
data.get("error_type", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Otherwise return the customized agent JSON
|
||||||
|
return data.get("agent_json")
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
error_type, error_msg = _classify_http_error(e)
|
||||||
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
error_type, error_msg = _classify_request_error(e)
|
||||||
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, error_type)
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
@@ -19,6 +20,85 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
SearchSource = Literal["marketplace", "library"]
|
SearchSource = Literal["marketplace", "library"]
|
||||||
|
|
||||||
|
_UUID_PATTERN = re.compile(
|
||||||
|
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(text: str) -> bool:
|
||||||
|
"""Check if text is a valid UUID v4."""
|
||||||
|
return bool(_UUID_PATTERN.match(text.strip()))
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
|
||||||
|
"""Fetch a library agent by ID (library agent ID or graph_id).
|
||||||
|
|
||||||
|
Tries multiple lookup strategies:
|
||||||
|
1. First by graph_id (AgentGraph primary key)
|
||||||
|
2. Then by library agent ID (LibraryAgent primary key)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInfo if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||||
|
return AgentInfo(
|
||||||
|
id=agent.id,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description or "",
|
||||||
|
source="library",
|
||||||
|
in_library=True,
|
||||||
|
creator=agent.creator_name,
|
||||||
|
status=agent.status.value,
|
||||||
|
can_access_graph=agent.can_access_graph,
|
||||||
|
has_external_trigger=agent.has_external_trigger,
|
||||||
|
new_output=agent.new_output,
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
)
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not fetch library agent by graph_id {agent_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||||
|
return AgentInfo(
|
||||||
|
id=agent.id,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description or "",
|
||||||
|
source="library",
|
||||||
|
in_library=True,
|
||||||
|
creator=agent.creator_name,
|
||||||
|
status=agent.status.value,
|
||||||
|
can_access_graph=agent.can_access_graph,
|
||||||
|
has_external_trigger=agent.has_external_trigger,
|
||||||
|
new_output=agent.new_output,
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def search_agents(
|
async def search_agents(
|
||||||
query: str,
|
query: str,
|
||||||
@@ -69,7 +149,15 @@ async def search_agents(
|
|||||||
is_featured=False,
|
is_featured=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else: # library
|
else:
|
||||||
|
if _is_uuid(query):
|
||||||
|
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||||
|
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
|
||||||
|
if agent:
|
||||||
|
agents.append(agent)
|
||||||
|
logger.info(f"Found agent by direct ID lookup: {agent.name}")
|
||||||
|
|
||||||
|
if not agents:
|
||||||
logger.info(f"Searching user library for: {query}")
|
logger.info(f"Searching user library for: {query}")
|
||||||
results = await library_db.list_library_agents(
|
results = await library_db.list_library_agents(
|
||||||
user_id=user_id, # type: ignore[arg-type]
|
user_id=user_id, # type: ignore[arg-type]
|
||||||
@@ -118,9 +206,9 @@ async def search_agents(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
no_results_msg = (
|
no_results_msg = (
|
||||||
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else f"No agents matching '{query}' found in your library."
|
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
|
||||||
)
|
)
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||||
@@ -136,10 +224,10 @@ async def search_agents(
|
|||||||
message = (
|
message = (
|
||||||
"Now you have found some options for the user to choose from. "
|
"Now you have found some options for the user to choose from. "
|
||||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||||
"Please ask the user if they would like to use any of these agents."
|
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
|
||||||
)
|
)
|
||||||
|
|
||||||
return AgentsFoundResponse(
|
return AgentsFoundResponse(
|
||||||
|
|||||||
@@ -8,13 +8,17 @@ from backend.api.features.chat.model import ChatSession
|
|||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
AgentGeneratorNotConfiguredError,
|
AgentGeneratorNotConfiguredError,
|
||||||
decompose_goal,
|
decompose_goal,
|
||||||
|
enrich_library_agents_from_steps,
|
||||||
generate_agent,
|
generate_agent,
|
||||||
|
get_all_relevant_agents_for_generation,
|
||||||
|
get_user_message_for_error,
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
)
|
)
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .models import (
|
from .models import (
|
||||||
AgentPreviewResponse,
|
AgentPreviewResponse,
|
||||||
AgentSavedResponse,
|
AgentSavedResponse,
|
||||||
|
AsyncProcessingResponse,
|
||||||
ClarificationNeededResponse,
|
ClarificationNeededResponse,
|
||||||
ClarifyingQuestion,
|
ClarifyingQuestion,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -95,6 +99,10 @@ class CreateAgentTool(BaseTool):
|
|||||||
save = kwargs.get("save", True)
|
save = kwargs.get("save", True)
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
|
# Extract async processing params (passed by long-running tool handler)
|
||||||
|
operation_id = kwargs.get("_operation_id")
|
||||||
|
task_id = kwargs.get("_task_id")
|
||||||
|
|
||||||
if not description:
|
if not description:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide a description of what the agent should do.",
|
message="Please provide a description of what the agent should do.",
|
||||||
@@ -102,9 +110,24 @@ class CreateAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 1: Decompose goal into steps
|
library_agents = None
|
||||||
|
if user_id:
|
||||||
try:
|
try:
|
||||||
decomposition_result = await decompose_goal(description, context)
|
library_agents = await get_all_relevant_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=description,
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch library agents: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
decomposition_result = await decompose_goal(
|
||||||
|
description, context, library_agents
|
||||||
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -117,15 +140,31 @@ class CreateAgentTool(BaseTool):
|
|||||||
|
|
||||||
if decomposition_result is None:
|
if decomposition_result is None:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to analyze the goal. The agent generation service may be unavailable or timed out. Please try again.",
|
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
||||||
error="decomposition_failed",
|
error="decomposition_failed",
|
||||||
details={
|
details={"description": description[:100]},
|
||||||
"description": description[:100]
|
session_id=session_id,
|
||||||
}, # Include context for debugging
|
)
|
||||||
|
|
||||||
|
if decomposition_result.get("type") == "error":
|
||||||
|
error_msg = decomposition_result.get("error", "Unknown error")
|
||||||
|
error_type = decomposition_result.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="analyze the goal",
|
||||||
|
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"decomposition_failed:{error_type}",
|
||||||
|
details={
|
||||||
|
"description": description[:100],
|
||||||
|
"service_error": error_msg,
|
||||||
|
"error_type": error_type,
|
||||||
|
},
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if LLM returned clarifying questions
|
|
||||||
if decomposition_result.get("type") == "clarifying_questions":
|
if decomposition_result.get("type") == "clarifying_questions":
|
||||||
questions = decomposition_result.get("questions", [])
|
questions = decomposition_result.get("questions", [])
|
||||||
return ClarificationNeededResponse(
|
return ClarificationNeededResponse(
|
||||||
@@ -144,7 +183,6 @@ class CreateAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for unachievable/vague goals
|
|
||||||
if decomposition_result.get("type") == "unachievable_goal":
|
if decomposition_result.get("type") == "unachievable_goal":
|
||||||
suggested = decomposition_result.get("suggested_goal", "")
|
suggested = decomposition_result.get("suggested_goal", "")
|
||||||
reason = decomposition_result.get("reason", "")
|
reason = decomposition_result.get("reason", "")
|
||||||
@@ -171,9 +209,27 @@ class CreateAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
if user_id and library_agents is not None:
|
||||||
try:
|
try:
|
||||||
agent_json = await generate_agent(decomposition_result)
|
library_agents = await enrich_library_agents_from_steps(
|
||||||
|
user_id=user_id,
|
||||||
|
decomposition_result=decomposition_result,
|
||||||
|
existing_agents=library_agents,
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent_json = await generate_agent(
|
||||||
|
decomposition_result,
|
||||||
|
library_agents,
|
||||||
|
operation_id=operation_id,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -186,11 +242,47 @@ class CreateAgentTool(BaseTool):
|
|||||||
|
|
||||||
if agent_json is None:
|
if agent_json is None:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to generate the agent. The agent generation service may be unavailable or timed out. Please try again.",
|
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
||||||
error="generation_failed",
|
error="generation_failed",
|
||||||
|
details={"description": description[:100]},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
|
||||||
|
error_msg = agent_json.get("error", "Unknown error")
|
||||||
|
error_type = agent_json.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="generate the agent",
|
||||||
|
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
|
||||||
|
validation_message=(
|
||||||
|
"I wasn't able to create a valid agent for this request. "
|
||||||
|
"The generated workflow had some structural issues. "
|
||||||
|
"Please try simplifying your goal or breaking it into smaller steps."
|
||||||
|
),
|
||||||
|
error_details=error_msg,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"generation_failed:{error_type}",
|
||||||
details={
|
details={
|
||||||
"description": description[:100]
|
"description": description[:100],
|
||||||
}, # Include context for debugging
|
"service_error": error_msg,
|
||||||
|
"error_type": error_type,
|
||||||
|
},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if Agent Generator accepted for async processing
|
||||||
|
if agent_json.get("status") == "accepted":
|
||||||
|
logger.info(
|
||||||
|
f"Agent generation delegated to async processing "
|
||||||
|
f"(operation_id={operation_id}, task_id={task_id})"
|
||||||
|
)
|
||||||
|
return AsyncProcessingResponse(
|
||||||
|
message="Agent generation started. You'll be notified when it's complete.",
|
||||||
|
operation_id=operation_id,
|
||||||
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -199,7 +291,6 @@ class CreateAgentTool(BaseTool):
|
|||||||
node_count = len(agent_json.get("nodes", []))
|
node_count = len(agent_json.get("nodes", []))
|
||||||
link_count = len(agent_json.get("links", []))
|
link_count = len(agent_json.get("links", []))
|
||||||
|
|
||||||
# Step 3: Preview or save
|
|
||||||
if not save:
|
if not save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -214,7 +305,6 @@ class CreateAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save to library
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="You must be logged in to save agents.",
|
message="You must be logged in to save agents.",
|
||||||
@@ -232,7 +322,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
agent_id=created_graph.id,
|
agent_id=created_graph.id,
|
||||||
agent_name=created_graph.name,
|
agent_name=created_graph.name,
|
||||||
library_agent_id=library_agent.id,
|
library_agent_id=library_agent.id,
|
||||||
library_agent_link=f"/library/{library_agent.id}",
|
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,337 @@
|
|||||||
|
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.api.features.store import db as store_db
|
||||||
|
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||||
|
|
||||||
|
from .agent_generator import (
|
||||||
|
AgentGeneratorNotConfiguredError,
|
||||||
|
customize_template,
|
||||||
|
get_user_message_for_error,
|
||||||
|
graph_to_json,
|
||||||
|
save_agent_to_library,
|
||||||
|
)
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import (
|
||||||
|
AgentPreviewResponse,
|
||||||
|
AgentSavedResponse,
|
||||||
|
ClarificationNeededResponse,
|
||||||
|
ClarifyingQuestion,
|
||||||
|
ErrorResponse,
|
||||||
|
ToolResponseBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomizeAgentTool(BaseTool):
|
||||||
|
"""Tool for customizing marketplace/template agents using natural language."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "customize_agent"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Customize a marketplace or template agent using natural language. "
|
||||||
|
"Takes an existing agent from the marketplace and modifies it based on "
|
||||||
|
"the user's requirements before adding to their library."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_long_running(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The marketplace agent ID in format 'creator/slug' "
|
||||||
|
"(e.g., 'autogpt/newsletter-writer'). "
|
||||||
|
"Get this from find_agent results."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"modifications": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Natural language description of how to customize the agent. "
|
||||||
|
"Be specific about what changes you want to make."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"context": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Additional context or answers to previous clarifying questions."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"save": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"Whether to save the customized agent to the user's library. "
|
||||||
|
"Default is true. Set to false for preview only."
|
||||||
|
),
|
||||||
|
"default": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["agent_id", "modifications"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Execute the customize_agent tool.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Parse the agent ID to get creator/slug
|
||||||
|
2. Fetch the template agent from the marketplace
|
||||||
|
3. Call customize_template with the modification request
|
||||||
|
4. Preview or save based on the save parameter
|
||||||
|
"""
|
||||||
|
agent_id = kwargs.get("agent_id", "").strip()
|
||||||
|
modifications = kwargs.get("modifications", "").strip()
|
||||||
|
context = kwargs.get("context", "")
|
||||||
|
save = kwargs.get("save", True)
|
||||||
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
|
if not agent_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||||
|
error="missing_agent_id",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not modifications:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please describe how you want to customize this agent.",
|
||||||
|
error="missing_modifications",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse agent_id in format "creator/slug"
|
||||||
|
parts = [p.strip() for p in agent_id.split("/")]
|
||||||
|
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Invalid agent ID format: '{agent_id}'. "
|
||||||
|
"Expected format is 'creator/agent-name' "
|
||||||
|
"(e.g., 'autogpt/newsletter-writer')."
|
||||||
|
),
|
||||||
|
error="invalid_agent_id_format",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
creator_username, agent_slug = parts
|
||||||
|
|
||||||
|
# Fetch the marketplace agent details
|
||||||
|
try:
|
||||||
|
agent_details = await store_db.get_store_agent_details(
|
||||||
|
username=creator_username, agent_name=agent_slug
|
||||||
|
)
|
||||||
|
except AgentNotFoundError:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Could not find marketplace agent '{agent_id}'. "
|
||||||
|
"Please check the agent ID and try again."
|
||||||
|
),
|
||||||
|
error="agent_not_found",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to fetch the marketplace agent. Please try again.",
|
||||||
|
error="fetch_error",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not agent_details.store_listing_version_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"The agent '{agent_id}' does not have an available version. "
|
||||||
|
"Please try a different agent."
|
||||||
|
),
|
||||||
|
error="no_version_available",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the full agent graph
|
||||||
|
try:
|
||||||
|
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||||
|
template_agent = graph_to_json(graph)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to fetch the agent configuration. Please try again.",
|
||||||
|
error="graph_fetch_error",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call customize_template
|
||||||
|
try:
|
||||||
|
result = await customize_template(
|
||||||
|
template_agent=template_agent,
|
||||||
|
modification_request=modifications,
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
except AgentGeneratorNotConfiguredError:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
"Agent customization is not available. "
|
||||||
|
"The Agent Generator service is not configured."
|
||||||
|
),
|
||||||
|
error="service_not_configured",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
"Failed to customize the agent due to a service error. "
|
||||||
|
"Please try again."
|
||||||
|
),
|
||||||
|
error="customization_service_error",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
"Failed to customize the agent. "
|
||||||
|
"The agent generation service may be unavailable or timed out. "
|
||||||
|
"Please try again."
|
||||||
|
),
|
||||||
|
error="customization_failed",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle error response
|
||||||
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
|
error_msg = result.get("error", "Unknown error")
|
||||||
|
error_type = result.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="customize the agent",
|
||||||
|
llm_parse_message=(
|
||||||
|
"The AI had trouble customizing the agent. "
|
||||||
|
"Please try again or simplify your request."
|
||||||
|
),
|
||||||
|
validation_message=(
|
||||||
|
"The customized agent failed validation. "
|
||||||
|
"Please try rephrasing your request."
|
||||||
|
),
|
||||||
|
error_details=error_msg,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"customization_failed:{error_type}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle clarifying questions
|
||||||
|
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
||||||
|
questions = result.get("questions") or []
|
||||||
|
if not isinstance(questions, list):
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected clarifying questions format: {type(questions)}"
|
||||||
|
)
|
||||||
|
questions = []
|
||||||
|
return ClarificationNeededResponse(
|
||||||
|
message=(
|
||||||
|
"I need some more information to customize this agent. "
|
||||||
|
"Please answer the following questions:"
|
||||||
|
),
|
||||||
|
questions=[
|
||||||
|
ClarifyingQuestion(
|
||||||
|
question=q.get("question", ""),
|
||||||
|
keyword=q.get("keyword", ""),
|
||||||
|
example=q.get("example"),
|
||||||
|
)
|
||||||
|
for q in questions
|
||||||
|
if isinstance(q, dict)
|
||||||
|
],
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Result should be the customized agent JSON
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to customize the agent due to an unexpected response.",
|
||||||
|
error="unexpected_response_type",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
customized_agent = result
|
||||||
|
|
||||||
|
agent_name = customized_agent.get(
|
||||||
|
"name", f"Customized {agent_details.agent_name}"
|
||||||
|
)
|
||||||
|
agent_description = customized_agent.get("description", "")
|
||||||
|
nodes = customized_agent.get("nodes")
|
||||||
|
links = customized_agent.get("links")
|
||||||
|
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||||
|
link_count = len(links) if isinstance(links, list) else 0
|
||||||
|
|
||||||
|
if not save:
|
||||||
|
return AgentPreviewResponse(
|
||||||
|
message=(
|
||||||
|
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||||
|
f"The customized agent has {node_count} blocks. "
|
||||||
|
f"Review it and call customize_agent with save=true to save it."
|
||||||
|
),
|
||||||
|
agent_json=customized_agent,
|
||||||
|
agent_name=agent_name,
|
||||||
|
description=agent_description,
|
||||||
|
node_count=node_count,
|
||||||
|
link_count=link_count,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="You must be logged in to save agents.",
|
||||||
|
error="auth_required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to user's library
|
||||||
|
try:
|
||||||
|
created_graph, library_agent = await save_agent_to_library(
|
||||||
|
customized_agent, user_id, is_update=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentSavedResponse(
|
||||||
|
message=(
|
||||||
|
f"Customized agent '{created_graph.name}' "
|
||||||
|
f"(based on '{agent_details.agent_name}') "
|
||||||
|
f"has been saved to your library!"
|
||||||
|
),
|
||||||
|
agent_id=created_graph.id,
|
||||||
|
agent_name=created_graph.name,
|
||||||
|
library_agent_id=library_agent.id,
|
||||||
|
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||||
|
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving customized agent: {e}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to save the customized agent. Please try again.",
|
||||||
|
error="save_failed",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
@@ -9,12 +9,15 @@ from .agent_generator import (
|
|||||||
AgentGeneratorNotConfiguredError,
|
AgentGeneratorNotConfiguredError,
|
||||||
generate_agent_patch,
|
generate_agent_patch,
|
||||||
get_agent_as_json,
|
get_agent_as_json,
|
||||||
|
get_all_relevant_agents_for_generation,
|
||||||
|
get_user_message_for_error,
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
)
|
)
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .models import (
|
from .models import (
|
||||||
AgentPreviewResponse,
|
AgentPreviewResponse,
|
||||||
AgentSavedResponse,
|
AgentSavedResponse,
|
||||||
|
AsyncProcessingResponse,
|
||||||
ClarificationNeededResponse,
|
ClarificationNeededResponse,
|
||||||
ClarifyingQuestion,
|
ClarifyingQuestion,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -102,6 +105,10 @@ class EditAgentTool(BaseTool):
|
|||||||
save = kwargs.get("save", True)
|
save = kwargs.get("save", True)
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
|
# Extract async processing params (passed by long-running tool handler)
|
||||||
|
operation_id = kwargs.get("_operation_id")
|
||||||
|
task_id = kwargs.get("_task_id")
|
||||||
|
|
||||||
if not agent_id:
|
if not agent_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide the agent ID to edit.",
|
message="Please provide the agent ID to edit.",
|
||||||
@@ -116,7 +123,6 @@ class EditAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 1: Fetch current agent
|
|
||||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||||
|
|
||||||
if current_agent is None:
|
if current_agent is None:
|
||||||
@@ -126,14 +132,34 @@ class EditAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the update request with context
|
library_agents = None
|
||||||
|
if user_id:
|
||||||
|
try:
|
||||||
|
graph_id = current_agent.get("id")
|
||||||
|
library_agents = await get_all_relevant_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=changes,
|
||||||
|
exclude_graph_id=graph_id,
|
||||||
|
include_marketplace=True,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch library agents: {e}")
|
||||||
|
|
||||||
update_request = changes
|
update_request = changes
|
||||||
if context:
|
if context:
|
||||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||||
|
|
||||||
# Step 2: Generate updated agent (external service handles fixing and validation)
|
|
||||||
try:
|
try:
|
||||||
result = await generate_agent_patch(update_request, current_agent)
|
result = await generate_agent_patch(
|
||||||
|
update_request,
|
||||||
|
current_agent,
|
||||||
|
library_agents,
|
||||||
|
operation_id=operation_id,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -152,7 +178,42 @@ class EditAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if LLM returned clarifying questions
|
# Check if Agent Generator accepted for async processing
|
||||||
|
if result.get("status") == "accepted":
|
||||||
|
logger.info(
|
||||||
|
f"Agent edit delegated to async processing "
|
||||||
|
f"(operation_id={operation_id}, task_id={task_id})"
|
||||||
|
)
|
||||||
|
return AsyncProcessingResponse(
|
||||||
|
message="Agent edit started. You'll be notified when it's complete.",
|
||||||
|
operation_id=operation_id,
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the result is an error from the external service
|
||||||
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
|
error_msg = result.get("error", "Unknown error")
|
||||||
|
error_type = result.get("error_type", "unknown")
|
||||||
|
user_message = get_user_message_for_error(
|
||||||
|
error_type,
|
||||||
|
operation="generate the changes",
|
||||||
|
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
||||||
|
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
||||||
|
error_details=error_msg,
|
||||||
|
)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=user_message,
|
||||||
|
error=f"update_generation_failed:{error_type}",
|
||||||
|
details={
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"changes": changes[:100],
|
||||||
|
"service_error": error_msg,
|
||||||
|
"error_type": error_type,
|
||||||
|
},
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
if result.get("type") == "clarifying_questions":
|
if result.get("type") == "clarifying_questions":
|
||||||
questions = result.get("questions", [])
|
questions = result.get("questions", [])
|
||||||
return ClarificationNeededResponse(
|
return ClarificationNeededResponse(
|
||||||
@@ -171,7 +232,6 @@ class EditAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Result is the updated agent JSON
|
|
||||||
updated_agent = result
|
updated_agent = result
|
||||||
|
|
||||||
agent_name = updated_agent.get("name", "Updated Agent")
|
agent_name = updated_agent.get("name", "Updated Agent")
|
||||||
@@ -179,7 +239,6 @@ class EditAgentTool(BaseTool):
|
|||||||
node_count = len(updated_agent.get("nodes", []))
|
node_count = len(updated_agent.get("nodes", []))
|
||||||
link_count = len(updated_agent.get("links", []))
|
link_count = len(updated_agent.get("links", []))
|
||||||
|
|
||||||
# Step 3: Preview or save
|
|
||||||
if not save:
|
if not save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -195,7 +254,6 @@ class EditAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save to library (creates a new version)
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="You must be logged in to save agents.",
|
message="You must be logged in to save agents.",
|
||||||
@@ -213,7 +271,7 @@ class EditAgentTool(BaseTool):
|
|||||||
agent_id=created_graph.id,
|
agent_id=created_graph.id,
|
||||||
agent_name=created_graph.name,
|
agent_name=created_graph.name,
|
||||||
library_agent_id=library_agent.id,
|
library_agent_id=library_agent.id,
|
||||||
library_agent_link=f"/library/{library_agent.id}",
|
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,10 +28,18 @@ class ResponseType(str, Enum):
|
|||||||
BLOCK_OUTPUT = "block_output"
|
BLOCK_OUTPUT = "block_output"
|
||||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||||
DOC_PAGE = "doc_page"
|
DOC_PAGE = "doc_page"
|
||||||
|
# Workspace response types
|
||||||
|
WORKSPACE_FILE_LIST = "workspace_file_list"
|
||||||
|
WORKSPACE_FILE_CONTENT = "workspace_file_content"
|
||||||
|
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
|
||||||
|
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
||||||
|
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
||||||
# Long-running operation types
|
# Long-running operation types
|
||||||
OPERATION_STARTED = "operation_started"
|
OPERATION_STARTED = "operation_started"
|
||||||
OPERATION_PENDING = "operation_pending"
|
OPERATION_PENDING = "operation_pending"
|
||||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||||
|
# Input validation
|
||||||
|
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||||
|
|
||||||
|
|
||||||
# Base response model
|
# Base response model
|
||||||
@@ -62,6 +70,10 @@ class AgentInfo(BaseModel):
|
|||||||
has_external_trigger: bool | None = None
|
has_external_trigger: bool | None = None
|
||||||
new_output: bool | None = None
|
new_output: bool | None = None
|
||||||
graph_id: str | None = None
|
graph_id: str | None = None
|
||||||
|
inputs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Input schema for the agent, including field names, types, and defaults",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AgentsFoundResponse(ToolResponseBase):
|
class AgentsFoundResponse(ToolResponseBase):
|
||||||
@@ -188,6 +200,20 @@ class ErrorResponse(ToolResponseBase):
|
|||||||
details: dict[str, Any] | None = None
|
details: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InputValidationErrorResponse(ToolResponseBase):
|
||||||
|
"""Response when run_agent receives unknown input fields."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
|
||||||
|
unrecognized_fields: list[str] = Field(
|
||||||
|
description="List of input field names that were not recognized"
|
||||||
|
)
|
||||||
|
inputs: dict[str, Any] = Field(
|
||||||
|
description="The agent's valid input schema for reference"
|
||||||
|
)
|
||||||
|
graph_id: str | None = None
|
||||||
|
graph_version: int | None = None
|
||||||
|
|
||||||
|
|
||||||
# Agent output models
|
# Agent output models
|
||||||
class ExecutionOutputInfo(BaseModel):
|
class ExecutionOutputInfo(BaseModel):
|
||||||
"""Summary of a single execution's outputs."""
|
"""Summary of a single execution's outputs."""
|
||||||
@@ -346,11 +372,15 @@ class OperationStartedResponse(ToolResponseBase):
|
|||||||
|
|
||||||
This is returned immediately to the client while the operation continues
|
This is returned immediately to the client while the operation continues
|
||||||
to execute. The user can close the tab and check back later.
|
to execute. The user can close the tab and check back later.
|
||||||
|
|
||||||
|
The task_id can be used to reconnect to the SSE stream via
|
||||||
|
GET /chat/tasks/{task_id}/stream?last_idx=0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||||
operation_id: str
|
operation_id: str
|
||||||
tool_name: str
|
tool_name: str
|
||||||
|
task_id: str | None = None # For SSE reconnection
|
||||||
|
|
||||||
|
|
||||||
class OperationPendingResponse(ToolResponseBase):
|
class OperationPendingResponse(ToolResponseBase):
|
||||||
@@ -374,3 +404,20 @@ class OperationInProgressResponse(ToolResponseBase):
|
|||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||||
tool_call_id: str
|
tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncProcessingResponse(ToolResponseBase):
|
||||||
|
"""Response when an operation has been delegated to async processing.
|
||||||
|
|
||||||
|
This is returned by tools when the external service accepts the request
|
||||||
|
for async processing (HTTP 202 Accepted). The Redis Streams completion
|
||||||
|
consumer will handle the result when the external service completes.
|
||||||
|
|
||||||
|
The status field is specifically "accepted" to allow the long-running tool
|
||||||
|
handler to detect this response and skip LLM continuation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||||
|
status: str = "accepted" # Must be "accepted" for detection
|
||||||
|
operation_id: str | None = None
|
||||||
|
task_id: str | None = None
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from .models import (
|
|||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
ExecutionOptions,
|
ExecutionOptions,
|
||||||
ExecutionStartedResponse,
|
ExecutionStartedResponse,
|
||||||
|
InputValidationErrorResponse,
|
||||||
SetupInfo,
|
SetupInfo,
|
||||||
SetupRequirementsResponse,
|
SetupRequirementsResponse,
|
||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
@@ -273,6 +274,22 @@ class RunAgentTool(BaseTool):
|
|||||||
input_properties = graph.input_schema.get("properties", {})
|
input_properties = graph.input_schema.get("properties", {})
|
||||||
required_fields = set(graph.input_schema.get("required", []))
|
required_fields = set(graph.input_schema.get("required", []))
|
||||||
provided_inputs = set(params.inputs.keys())
|
provided_inputs = set(params.inputs.keys())
|
||||||
|
valid_fields = set(input_properties.keys())
|
||||||
|
|
||||||
|
# Check for unknown input fields
|
||||||
|
unrecognized_fields = provided_inputs - valid_fields
|
||||||
|
if unrecognized_fields:
|
||||||
|
return InputValidationErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
|
||||||
|
f"Agent was not executed. Please use the correct field names from the schema."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
unrecognized_fields=sorted(unrecognized_fields),
|
||||||
|
inputs=graph.input_schema,
|
||||||
|
graph_id=graph.id,
|
||||||
|
graph_version=graph.version,
|
||||||
|
)
|
||||||
|
|
||||||
# If agent has inputs but none were provided AND use_defaults is not set,
|
# If agent has inputs but none were provided AND use_defaults is not set,
|
||||||
# always show what's available first so user can decide
|
# always show what's available first so user can decide
|
||||||
|
|||||||
@@ -402,3 +402,42 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
|||||||
# Should return error about missing schedule_name
|
# Should return error about missing schedule_name
|
||||||
assert result_data.get("type") == "error"
|
assert result_data.get("type") == "error"
|
||||||
assert "schedule_name" in result_data["message"].lower()
|
assert "schedule_name" in result_data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
|
||||||
|
"""Test that run_agent returns input_validation_error for unknown input fields."""
|
||||||
|
user = setup_test_data["user"]
|
||||||
|
store_submission = setup_test_data["store_submission"]
|
||||||
|
|
||||||
|
tool = RunAgentTool()
|
||||||
|
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||||
|
session = make_session(user_id=user.id)
|
||||||
|
|
||||||
|
# Execute with unknown input field names
|
||||||
|
response = await tool.execute(
|
||||||
|
user_id=user.id,
|
||||||
|
session_id=str(uuid.uuid4()),
|
||||||
|
tool_call_id=str(uuid.uuid4()),
|
||||||
|
username_agent_slug=agent_marketplace_id,
|
||||||
|
inputs={
|
||||||
|
"unknown_field": "some value",
|
||||||
|
"another_unknown": "another value",
|
||||||
|
},
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert hasattr(response, "output")
|
||||||
|
assert isinstance(response.output, str)
|
||||||
|
result_data = orjson.loads(response.output)
|
||||||
|
|
||||||
|
# Should return input_validation_error type with unrecognized fields
|
||||||
|
assert result_data.get("type") == "input_validation_error"
|
||||||
|
assert "unrecognized_fields" in result_data
|
||||||
|
assert set(result_data["unrecognized_fields"]) == {
|
||||||
|
"another_unknown",
|
||||||
|
"unknown_field",
|
||||||
|
}
|
||||||
|
assert "inputs" in result_data # Contains the valid schema
|
||||||
|
assert "Agent was not executed" in result_data["message"]
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
"""Tool for executing blocks directly."""
|
"""Tool for executing blocks directly."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.block import get_block
|
from backend.data.block import get_block
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
@@ -73,15 +77,22 @@ class RunBlockTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
block: Any,
|
block: Any,
|
||||||
|
input_data: dict[str, Any] | None = None,
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
"""
|
"""
|
||||||
Check if user has required credentials for a block.
|
Check if user has required credentials for a block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
block: Block to check credentials for
|
||||||
|
input_data: Input data for the block (used to determine provider via discriminator)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[matched_credentials, missing_credentials]
|
tuple[matched_credentials, missing_credentials]
|
||||||
"""
|
"""
|
||||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||||
missing_credentials: list[CredentialsMetaInput] = []
|
missing_credentials: list[CredentialsMetaInput] = []
|
||||||
|
input_data = input_data or {}
|
||||||
|
|
||||||
# Get credential field info from block's input schema
|
# Get credential field info from block's input schema
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||||
@@ -94,14 +105,33 @@ class RunBlockTool(BaseTool):
|
|||||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
for field_name, field_info in credentials_fields_info.items():
|
||||||
# field_info.provider is a frozenset of acceptable providers
|
effective_field_info = field_info
|
||||||
# field_info.supported_types is a frozenset of acceptable types
|
if field_info.discriminator and field_info.discriminator_mapping:
|
||||||
|
# Get discriminator from input, falling back to schema default
|
||||||
|
discriminator_value = input_data.get(field_info.discriminator)
|
||||||
|
if discriminator_value is None:
|
||||||
|
field = block.input_schema.model_fields.get(
|
||||||
|
field_info.discriminator
|
||||||
|
)
|
||||||
|
if field and field.default is not PydanticUndefined:
|
||||||
|
discriminator_value = field.default
|
||||||
|
|
||||||
|
if (
|
||||||
|
discriminator_value
|
||||||
|
and discriminator_value in field_info.discriminator_mapping
|
||||||
|
):
|
||||||
|
effective_field_info = field_info.discriminate(discriminator_value)
|
||||||
|
logger.debug(
|
||||||
|
f"Discriminated provider for {field_name}: "
|
||||||
|
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
(
|
(
|
||||||
cred
|
cred
|
||||||
for cred in available_creds
|
for cred in available_creds
|
||||||
if cred.provider in field_info.provider
|
if cred.provider in effective_field_info.provider
|
||||||
and cred.type in field_info.supported_types
|
and cred.type in effective_field_info.supported_types
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -115,8 +145,8 @@ class RunBlockTool(BaseTool):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Create a placeholder for the missing credential
|
# Create a placeholder for the missing credential
|
||||||
provider = next(iter(field_info.provider), "unknown")
|
provider = next(iter(effective_field_info.provider), "unknown")
|
||||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
||||||
missing_credentials.append(
|
missing_credentials.append(
|
||||||
CredentialsMetaInput(
|
CredentialsMetaInput(
|
||||||
id=field_name,
|
id=field_name,
|
||||||
@@ -184,10 +214,9 @@ class RunBlockTool(BaseTool):
|
|||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
# Check credentials
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
creds_manager = IntegrationCredentialsManager()
|
||||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||||
user_id, block
|
user_id, block, input_data
|
||||||
)
|
)
|
||||||
|
|
||||||
if missing_credentials:
|
if missing_credentials:
|
||||||
@@ -223,11 +252,48 @@ class RunBlockTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch actual credentials and prepare kwargs for block execution
|
# Get or create user's workspace for CoPilot file operations
|
||||||
# Create execution context with defaults (blocks may require it)
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
|
||||||
|
# Generate synthetic IDs for CoPilot context
|
||||||
|
# Each chat session is treated as its own agent with one continuous run
|
||||||
|
# This means:
|
||||||
|
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
||||||
|
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
||||||
|
# - node_exec_id = unique per block execution
|
||||||
|
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||||
|
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||||
|
synthetic_node_id = f"copilot-node-{block_id}"
|
||||||
|
synthetic_node_exec_id = (
|
||||||
|
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create unified execution context with all required fields
|
||||||
|
execution_context = ExecutionContext(
|
||||||
|
# Execution identity
|
||||||
|
user_id=user_id,
|
||||||
|
graph_id=synthetic_graph_id,
|
||||||
|
graph_exec_id=synthetic_graph_exec_id,
|
||||||
|
graph_version=1, # Versions are 1-indexed
|
||||||
|
node_id=synthetic_node_id,
|
||||||
|
node_exec_id=synthetic_node_exec_id,
|
||||||
|
# Workspace with session scoping
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
session_id=session.session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare kwargs for block execution
|
||||||
|
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||||
exec_kwargs: dict[str, Any] = {
|
exec_kwargs: dict[str, Any] = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"execution_context": ExecutionContext(),
|
"execution_context": execution_context,
|
||||||
|
# Legacy: individual kwargs for blocks not yet using execution_context
|
||||||
|
"workspace_id": workspace.id,
|
||||||
|
"graph_exec_id": synthetic_graph_exec_id,
|
||||||
|
"node_exec_id": synthetic_node_exec_id,
|
||||||
|
"node_id": synthetic_node_id,
|
||||||
|
"graph_version": 1, # Versions are 1-indexed
|
||||||
|
"graph_id": synthetic_graph_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
for field_name, cred_meta in matched_credentials.items():
|
for field_name, cred_meta in matched_credentials.items():
|
||||||
|
|||||||
@@ -8,7 +8,12 @@ from backend.api.features.library import model as library_model
|
|||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
from backend.data.model import (
|
||||||
|
CredentialsFieldInfo,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
HostScopedCredentials,
|
||||||
|
OAuth2Credentials,
|
||||||
|
)
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
@@ -266,13 +271,21 @@ async def match_user_credentials_to_graph(
|
|||||||
credential_requirements,
|
credential_requirements,
|
||||||
_node_fields,
|
_node_fields,
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider and type
|
# Find first matching credential by provider, type, and scopes
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
(
|
(
|
||||||
cred
|
cred
|
||||||
for cred in available_creds
|
for cred in available_creds
|
||||||
if cred.provider in credential_requirements.provider
|
if cred.provider in credential_requirements.provider
|
||||||
and cred.type in credential_requirements.supported_types
|
and cred.type in credential_requirements.supported_types
|
||||||
|
and (
|
||||||
|
cred.type != "oauth2"
|
||||||
|
or _credential_has_required_scopes(cred, credential_requirements)
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
cred.type != "host_scoped"
|
||||||
|
or _credential_is_for_host(cred, credential_requirements)
|
||||||
|
)
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -296,10 +309,17 @@ async def match_user_credentials_to_graph(
|
|||||||
f"{credential_field_name} (validation failed: {e})"
|
f"{credential_field_name} (validation failed: {e})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Build a helpful error message including scope requirements
|
||||||
|
error_parts = [
|
||||||
|
f"provider in {list(credential_requirements.provider)}",
|
||||||
|
f"type in {list(credential_requirements.supported_types)}",
|
||||||
|
]
|
||||||
|
if credential_requirements.required_scopes:
|
||||||
|
error_parts.append(
|
||||||
|
f"scopes including {list(credential_requirements.required_scopes)}"
|
||||||
|
)
|
||||||
missing_creds.append(
|
missing_creds.append(
|
||||||
f"{credential_field_name} "
|
f"{credential_field_name} (requires {', '.join(error_parts)})"
|
||||||
f"(requires provider in {list(credential_requirements.provider)}, "
|
|
||||||
f"type in {list(credential_requirements.supported_types)})"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -309,6 +329,35 @@ async def match_user_credentials_to_graph(
|
|||||||
return graph_credentials_inputs, missing_creds
|
return graph_credentials_inputs, missing_creds
|
||||||
|
|
||||||
|
|
||||||
|
def _credential_has_required_scopes(
|
||||||
|
credential: OAuth2Credentials,
|
||||||
|
requirements: CredentialsFieldInfo,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if an OAuth2 credential has all the scopes required by the input."""
|
||||||
|
# If no scopes are required, any credential matches
|
||||||
|
if not requirements.required_scopes:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check that credential scopes are a superset of required scopes
|
||||||
|
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
def _credential_is_for_host(
|
||||||
|
credential: HostScopedCredentials,
|
||||||
|
requirements: CredentialsFieldInfo,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if a host-scoped credential matches the host required by the input."""
|
||||||
|
# We need to know the host to match host-scoped credentials to.
|
||||||
|
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
|
||||||
|
# to discriminator_values. No discriminator_values -> no host to match against.
|
||||||
|
if not requirements.discriminator_values:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check that credential host matches required host.
|
||||||
|
# Host-scoped credential inputs are grouped by host, so any item from the set works.
|
||||||
|
return credential.matches_url(list(requirements.discriminator_values)[0])
|
||||||
|
|
||||||
|
|
||||||
async def check_user_has_required_credentials(
|
async def check_user_has_required_credentials(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
required_credentials: list[CredentialsMetaInput],
|
required_credentials: list[CredentialsMetaInput],
|
||||||
|
|||||||
@@ -0,0 +1,620 @@
|
|||||||
|
"""CoPilot tools for workspace file operations."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
from backend.data.workspace import get_or_create_workspace
|
||||||
|
from backend.util.settings import Config
|
||||||
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
from backend.util.workspace import WorkspaceManager
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileInfoData(BaseModel):
|
||||||
|
"""Data model for workspace file information (not a response itself)."""
|
||||||
|
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileListResponse(ToolResponseBase):
|
||||||
|
"""Response containing list of workspace files."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
||||||
|
files: list[WorkspaceFileInfoData]
|
||||||
|
total_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileContentResponse(ToolResponseBase):
|
||||||
|
"""Response containing workspace file content (legacy, for small text files)."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
content_base64: str
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||||
|
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
mime_type: str
|
||||||
|
size_bytes: int
|
||||||
|
download_url: str
|
||||||
|
preview: str | None = None # First 500 chars for text files
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceWriteResponse(ToolResponseBase):
|
||||||
|
"""Response after writing a file to workspace."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
||||||
|
file_id: str
|
||||||
|
name: str
|
||||||
|
path: str
|
||||||
|
size_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||||
|
"""Response after deleting a file from workspace."""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
||||||
|
file_id: str
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
class ListWorkspaceFilesTool(BaseTool):
|
||||||
|
"""Tool for listing files in user's workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "list_workspace_files"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"List files in the user's workspace. "
|
||||||
|
"Returns file names, paths, sizes, and metadata. "
|
||||||
|
"Optionally filter by path prefix."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path_prefix": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional path prefix to filter files "
|
||||||
|
"(e.g., '/documents/' to list only files in documents folder). "
|
||||||
|
"By default, only files from the current session are listed."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of files to return (default 50, max 100)",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 100,
|
||||||
|
},
|
||||||
|
"include_all_sessions": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"If true, list files from all sessions. "
|
||||||
|
"Default is false (only current session's files)."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||||
|
limit = min(kwargs.get("limit", 50), 100)
|
||||||
|
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
files = await manager.list_files(
|
||||||
|
path=path_prefix,
|
||||||
|
limit=limit,
|
||||||
|
include_all_sessions=include_all_sessions,
|
||||||
|
)
|
||||||
|
total = await manager.get_file_count(
|
||||||
|
path=path_prefix,
|
||||||
|
include_all_sessions=include_all_sessions,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_infos = [
|
||||||
|
WorkspaceFileInfoData(
|
||||||
|
file_id=f.id,
|
||||||
|
name=f.name,
|
||||||
|
path=f.path,
|
||||||
|
mime_type=f.mimeType,
|
||||||
|
size_bytes=f.sizeBytes,
|
||||||
|
)
|
||||||
|
for f in files
|
||||||
|
]
|
||||||
|
|
||||||
|
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||||
|
return WorkspaceFileListResponse(
|
||||||
|
files=file_infos,
|
||||||
|
total_count=total,
|
||||||
|
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to list workspace files: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadWorkspaceFileTool(BaseTool):
|
||||||
|
"""Tool for reading file content from workspace."""
|
||||||
|
|
||||||
|
# Size threshold for returning full content vs metadata+URL
|
||||||
|
# Files larger than this return metadata with download URL to prevent context bloat
|
||||||
|
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||||
|
# Preview size for text files
|
||||||
|
PREVIEW_SIZE = 500
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "read_workspace_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Read a file from the user's workspace. "
|
||||||
|
"Specify either file_id or path to identify the file. "
|
||||||
|
"For small text files, returns content directly. "
|
||||||
|
"For large or binary files, returns metadata and a download URL. "
|
||||||
|
"Paths are scoped to the current session by default. "
|
||||||
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file's unique ID (from list_workspace_files)",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||||
|
"Scoped to current session by default."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"force_download_url": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": (
|
||||||
|
"If true, always return metadata+URL instead of inline content. "
|
||||||
|
"Default is false (auto-selects based on file size/type)."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [], # At least one must be provided
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _is_text_mime_type(self, mime_type: str) -> bool:
|
||||||
|
"""Check if the MIME type is a text-based type."""
|
||||||
|
text_types = [
|
||||||
|
"text/",
|
||||||
|
"application/json",
|
||||||
|
"application/xml",
|
||||||
|
"application/javascript",
|
||||||
|
"application/x-python",
|
||||||
|
"application/x-sh",
|
||||||
|
]
|
||||||
|
return any(mime_type.startswith(t) for t in text_types)
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id: Optional[str] = kwargs.get("file_id")
|
||||||
|
path: Optional[str] = kwargs.get("path")
|
||||||
|
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||||
|
|
||||||
|
if not file_id and not path:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide either file_id or path",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
# Get file info
|
||||||
|
if file_id:
|
||||||
|
file_info = await manager.get_file_info(file_id)
|
||||||
|
if file_info is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found: {file_id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
target_file_id = file_id
|
||||||
|
else:
|
||||||
|
# path is guaranteed to be non-None here due to the check above
|
||||||
|
assert path is not None
|
||||||
|
file_info = await manager.get_file_info_by_path(path)
|
||||||
|
if file_info is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found at path: {path}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
target_file_id = file_info.id
|
||||||
|
|
||||||
|
# Decide whether to return inline content or metadata+URL
|
||||||
|
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
|
||||||
|
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
||||||
|
|
||||||
|
# Return inline content for small text files (unless force_download_url)
|
||||||
|
if is_small_file and is_text_file and not force_download_url:
|
||||||
|
content = await manager.read_file_by_id(target_file_id)
|
||||||
|
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||||
|
|
||||||
|
return WorkspaceFileContentResponse(
|
||||||
|
file_id=file_info.id,
|
||||||
|
name=file_info.name,
|
||||||
|
path=file_info.path,
|
||||||
|
mime_type=file_info.mimeType,
|
||||||
|
content_base64=content_b64,
|
||||||
|
message=f"Successfully read file: {file_info.name}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return metadata + workspace:// reference for large or binary files
|
||||||
|
# This prevents context bloat (100KB file = ~133KB as base64)
|
||||||
|
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
||||||
|
download_url = f"workspace://{target_file_id}"
|
||||||
|
|
||||||
|
# Generate preview for text files
|
||||||
|
preview: str | None = None
|
||||||
|
if is_text_file:
|
||||||
|
try:
|
||||||
|
content = await manager.read_file_by_id(target_file_id)
|
||||||
|
preview_text = content[: self.PREVIEW_SIZE].decode(
|
||||||
|
"utf-8", errors="replace"
|
||||||
|
)
|
||||||
|
if len(content) > self.PREVIEW_SIZE:
|
||||||
|
preview_text += "..."
|
||||||
|
preview = preview_text
|
||||||
|
except Exception:
|
||||||
|
pass # Preview is optional
|
||||||
|
|
||||||
|
return WorkspaceFileMetadataResponse(
|
||||||
|
file_id=file_info.id,
|
||||||
|
name=file_info.name,
|
||||||
|
path=file_info.path,
|
||||||
|
mime_type=file_info.mimeType,
|
||||||
|
size_bytes=file_info.sizeBytes,
|
||||||
|
download_url=download_url,
|
||||||
|
preview=preview,
|
||||||
|
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to read workspace file: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteWorkspaceFileTool(BaseTool):
|
||||||
|
"""Tool for writing files to workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "write_workspace_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Write or create a file in the user's workspace. "
|
||||||
|
"Provide the content as a base64-encoded string. "
|
||||||
|
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||||
|
"Files are saved to the current session's folder by default. "
|
||||||
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"filename": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Name for the file (e.g., 'report.pdf')",
|
||||||
|
},
|
||||||
|
"content_base64": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Base64-encoded file content",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional virtual path where to save the file "
|
||||||
|
"(e.g., '/documents/report.pdf'). "
|
||||||
|
"Defaults to '/{filename}'. Scoped to current session."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"mime_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional MIME type of the file. "
|
||||||
|
"Auto-detected from filename if not provided."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"overwrite": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["filename", "content_base64"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
filename: str = kwargs.get("filename", "")
|
||||||
|
content_b64: str = kwargs.get("content_base64", "")
|
||||||
|
path: Optional[str] = kwargs.get("path")
|
||||||
|
mime_type: Optional[str] = kwargs.get("mime_type")
|
||||||
|
overwrite: bool = kwargs.get("overwrite", False)
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide a filename",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not content_b64:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide content_base64",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode content
|
||||||
|
try:
|
||||||
|
content = base64.b64decode(content_b64)
|
||||||
|
except Exception:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Invalid base64-encoded content",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check size
|
||||||
|
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||||
|
if len(content) > max_file_size:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Virus scan
|
||||||
|
await scan_content_safe(content, filename=filename)
|
||||||
|
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
file_record = await manager.write_file(
|
||||||
|
content=content,
|
||||||
|
filename=filename,
|
||||||
|
path=path,
|
||||||
|
mime_type=mime_type,
|
||||||
|
overwrite=overwrite,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WorkspaceWriteResponse(
|
||||||
|
file_id=file_record.id,
|
||||||
|
name=file_record.name,
|
||||||
|
path=file_record.path,
|
||||||
|
size_bytes=file_record.sizeBytes,
|
||||||
|
message=f"Successfully wrote file: {file_record.name}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to write workspace file: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteWorkspaceFileTool(BaseTool):
|
||||||
|
"""Tool for deleting files from workspace."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "delete_workspace_file"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Delete a file from the user's workspace. "
|
||||||
|
"Specify either file_id or path to identify the file. "
|
||||||
|
"Paths are scoped to the current session by default. "
|
||||||
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The file's unique ID (from list_workspace_files)",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||||
|
"Scoped to current session by default."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [], # At least one must be provided
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_auth(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _execute(
|
||||||
|
self,
|
||||||
|
user_id: str | None,
|
||||||
|
session: ChatSession,
|
||||||
|
**kwargs,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
session_id = session.session_id
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Authentication required",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id: Optional[str] = kwargs.get("file_id")
|
||||||
|
path: Optional[str] = kwargs.get("path")
|
||||||
|
|
||||||
|
if not file_id and not path:
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Please provide either file_id or path",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
workspace = await get_or_create_workspace(user_id)
|
||||||
|
# Pass session_id for session-scoped file access
|
||||||
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
|
# Determine the file_id to delete
|
||||||
|
target_file_id: str
|
||||||
|
if file_id:
|
||||||
|
target_file_id = file_id
|
||||||
|
else:
|
||||||
|
# path is guaranteed to be non-None here due to the check above
|
||||||
|
assert path is not None
|
||||||
|
file_info = await manager.get_file_info_by_path(path)
|
||||||
|
if file_info is None:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found at path: {path}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
target_file_id = file_info.id
|
||||||
|
|
||||||
|
success = await manager.delete_file(target_file_id)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"File not found: {target_file_id}",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WorkspaceDeleteResponse(
|
||||||
|
file_id=target_file_id,
|
||||||
|
success=True,
|
||||||
|
message="File deleted successfully",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
||||||
|
return ErrorResponse(
|
||||||
|
message=f"Failed to delete workspace file: {str(e)}",
|
||||||
|
error=str(e),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
@@ -19,9 +19,12 @@ from backend.data.graph import GraphSettings
|
|||||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||||
|
on_graph_activate,
|
||||||
|
on_graph_deactivate,
|
||||||
|
)
|
||||||
from backend.util.clients import get_scheduler_client
|
from backend.util.clients import get_scheduler_client
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
@@ -39,6 +42,7 @@ async def list_library_agents(
|
|||||||
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
|
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 50,
|
page_size: int = 50,
|
||||||
|
include_executions: bool = False,
|
||||||
) -> library_model.LibraryAgentResponse:
|
) -> library_model.LibraryAgentResponse:
|
||||||
"""
|
"""
|
||||||
Retrieves a paginated list of LibraryAgent records for a given user.
|
Retrieves a paginated list of LibraryAgent records for a given user.
|
||||||
@@ -49,6 +53,9 @@ async def list_library_agents(
|
|||||||
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
|
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
|
||||||
page: Current page (1-indexed).
|
page: Current page (1-indexed).
|
||||||
page_size: Number of items per page.
|
page_size: Number of items per page.
|
||||||
|
include_executions: Whether to include execution data for status calculation.
|
||||||
|
Defaults to False for performance (UI fetches status separately).
|
||||||
|
Set to True when accurate status/metrics are needed (e.g., agent generator).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A LibraryAgentResponse containing the list of agents and pagination details.
|
A LibraryAgentResponse containing the list of agents and pagination details.
|
||||||
@@ -64,11 +71,11 @@ async def list_library_agents(
|
|||||||
|
|
||||||
if page < 1 or page_size < 1:
|
if page < 1 or page_size < 1:
|
||||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||||
raise DatabaseError("Invalid pagination input")
|
raise InvalidInputError("Invalid pagination input")
|
||||||
|
|
||||||
if search_term and len(search_term.strip()) > 100:
|
if search_term and len(search_term.strip()) > 100:
|
||||||
logger.warning(f"Search term too long: {repr(search_term)}")
|
logger.warning(f"Search term too long: {repr(search_term)}")
|
||||||
raise DatabaseError("Search term is too long")
|
raise InvalidInputError("Search term is too long")
|
||||||
|
|
||||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
@@ -76,7 +83,6 @@ async def list_library_agents(
|
|||||||
"isArchived": False,
|
"isArchived": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Build search filter if applicable
|
|
||||||
if search_term:
|
if search_term:
|
||||||
where_clause["OR"] = [
|
where_clause["OR"] = [
|
||||||
{
|
{
|
||||||
@@ -93,7 +99,6 @@ async def list_library_agents(
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
# Determine sorting
|
|
||||||
order_by: prisma.types.LibraryAgentOrderByInput | None = None
|
order_by: prisma.types.LibraryAgentOrderByInput | None = None
|
||||||
|
|
||||||
if sort_by == library_model.LibraryAgentSort.CREATED_AT:
|
if sort_by == library_model.LibraryAgentSort.CREATED_AT:
|
||||||
@@ -105,7 +110,7 @@ async def list_library_agents(
|
|||||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||||
where=where_clause,
|
where=where_clause,
|
||||||
include=library_agent_include(
|
include=library_agent_include(
|
||||||
user_id, include_nodes=False, include_executions=False
|
user_id, include_nodes=False, include_executions=include_executions
|
||||||
),
|
),
|
||||||
order=order_by,
|
order=order_by,
|
||||||
skip=(page - 1) * page_size,
|
skip=(page - 1) * page_size,
|
||||||
@@ -175,7 +180,7 @@ async def list_favorite_library_agents(
|
|||||||
|
|
||||||
if page < 1 or page_size < 1:
|
if page < 1 or page_size < 1:
|
||||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||||
raise DatabaseError("Invalid pagination input")
|
raise InvalidInputError("Invalid pagination input")
|
||||||
|
|
||||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
@@ -535,6 +540,92 @@ async def update_agent_version_in_library(
|
|||||||
return library_model.LibraryAgent.from_db(lib)
|
return library_model.LibraryAgent.from_db(lib)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_graph_in_library(
|
||||||
|
graph: graph_db.Graph,
|
||||||
|
user_id: str,
|
||||||
|
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||||
|
"""Create a new graph and add it to the user's library."""
|
||||||
|
graph.version = 1
|
||||||
|
graph_model = graph_db.make_graph_model(graph, user_id)
|
||||||
|
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||||
|
|
||||||
|
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||||
|
|
||||||
|
library_agents = await create_library_agent(
|
||||||
|
graph=created_graph,
|
||||||
|
user_id=user_id,
|
||||||
|
sensitive_action_safe_mode=True,
|
||||||
|
create_library_agents_for_sub_graphs=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if created_graph.is_active:
|
||||||
|
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
||||||
|
|
||||||
|
return created_graph, library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def update_graph_in_library(
|
||||||
|
graph: graph_db.Graph,
|
||||||
|
user_id: str,
|
||||||
|
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||||
|
"""Create a new version of an existing graph and update the library entry."""
|
||||||
|
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
|
||||||
|
current_active_version = (
|
||||||
|
next((v for v in existing_versions if v.is_active), None)
|
||||||
|
if existing_versions
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
graph.version = (
|
||||||
|
max(v.version for v in existing_versions) + 1 if existing_versions else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_model = graph_db.make_graph_model(graph, user_id)
|
||||||
|
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||||
|
|
||||||
|
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||||
|
|
||||||
|
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
||||||
|
if not library_agent:
|
||||||
|
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
||||||
|
|
||||||
|
library_agent = await update_library_agent_version_and_settings(
|
||||||
|
user_id, created_graph
|
||||||
|
)
|
||||||
|
|
||||||
|
if created_graph.is_active:
|
||||||
|
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
||||||
|
await graph_db.set_graph_active_version(
|
||||||
|
graph_id=created_graph.id,
|
||||||
|
version=created_graph.version,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if current_active_version:
|
||||||
|
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||||
|
|
||||||
|
return created_graph, library_agent
|
||||||
|
|
||||||
|
|
||||||
|
async def update_library_agent_version_and_settings(
|
||||||
|
user_id: str, agent_graph: graph_db.GraphModel
|
||||||
|
) -> library_model.LibraryAgent:
|
||||||
|
"""Update library agent to point to new graph version and sync settings."""
|
||||||
|
library = await update_agent_version_in_library(
|
||||||
|
user_id, agent_graph.id, agent_graph.version
|
||||||
|
)
|
||||||
|
updated_settings = GraphSettings.from_graph(
|
||||||
|
graph=agent_graph,
|
||||||
|
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||||
|
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||||
|
)
|
||||||
|
if updated_settings != library.settings:
|
||||||
|
library = await update_library_agent(
|
||||||
|
library_agent_id=library.id,
|
||||||
|
user_id=user_id,
|
||||||
|
settings=updated_settings,
|
||||||
|
)
|
||||||
|
return library
|
||||||
|
|
||||||
|
|
||||||
async def update_library_agent(
|
async def update_library_agent(
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import pydantic
|
|||||||
from backend.data.block import BlockInput
|
from backend.data.block import BlockInput
|
||||||
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
||||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
||||||
|
from backend.util.json import loads as json_loads
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -16,10 +17,10 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class LibraryAgentStatus(str, Enum):
|
class LibraryAgentStatus(str, Enum):
|
||||||
COMPLETED = "COMPLETED" # All runs completed
|
COMPLETED = "COMPLETED"
|
||||||
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
|
HEALTHY = "HEALTHY"
|
||||||
WAITING = "WAITING" # Agent is queued or waiting to start
|
WAITING = "WAITING"
|
||||||
ERROR = "ERROR" # Agent is in an error state
|
ERROR = "ERROR"
|
||||||
|
|
||||||
|
|
||||||
class MarketplaceListingCreator(pydantic.BaseModel):
|
class MarketplaceListingCreator(pydantic.BaseModel):
|
||||||
@@ -39,6 +40,30 @@ class MarketplaceListing(pydantic.BaseModel):
|
|||||||
creator: MarketplaceListingCreator
|
creator: MarketplaceListingCreator
|
||||||
|
|
||||||
|
|
||||||
|
class RecentExecution(pydantic.BaseModel):
|
||||||
|
"""Summary of a recent execution for quality assessment.
|
||||||
|
|
||||||
|
Used by the LLM to understand the agent's recent performance with specific examples
|
||||||
|
rather than just aggregate statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: str
|
||||||
|
correctness_score: float | None = None
|
||||||
|
activity_summary: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_settings(settings: dict | str | None) -> GraphSettings:
|
||||||
|
"""Parse settings from database, handling both dict and string formats."""
|
||||||
|
if settings is None:
|
||||||
|
return GraphSettings()
|
||||||
|
try:
|
||||||
|
if isinstance(settings, str):
|
||||||
|
settings = json_loads(settings)
|
||||||
|
return GraphSettings.model_validate(settings)
|
||||||
|
except Exception:
|
||||||
|
return GraphSettings()
|
||||||
|
|
||||||
|
|
||||||
class LibraryAgent(pydantic.BaseModel):
|
class LibraryAgent(pydantic.BaseModel):
|
||||||
"""
|
"""
|
||||||
Represents an agent in the library, including metadata for display and
|
Represents an agent in the library, including metadata for display and
|
||||||
@@ -48,7 +73,7 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
graph_id: str
|
graph_id: str
|
||||||
graph_version: int
|
graph_version: int
|
||||||
owner_user_id: str # ID of user who owns/created this agent graph
|
owner_user_id: str
|
||||||
|
|
||||||
image_url: str | None
|
image_url: str | None
|
||||||
|
|
||||||
@@ -64,7 +89,7 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
description: str
|
description: str
|
||||||
instructions: str | None = None
|
instructions: str | None = None
|
||||||
|
|
||||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
input_schema: dict[str, Any]
|
||||||
output_schema: dict[str, Any]
|
output_schema: dict[str, Any]
|
||||||
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
|
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
|
||||||
description="Input schema for credentials required by the agent",
|
description="Input schema for credentials required by the agent",
|
||||||
@@ -81,25 +106,19 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
)
|
)
|
||||||
trigger_setup_info: Optional[GraphTriggerInfo] = None
|
trigger_setup_info: Optional[GraphTriggerInfo] = None
|
||||||
|
|
||||||
# Indicates whether there's a new output (based on recent runs)
|
|
||||||
new_output: bool
|
new_output: bool
|
||||||
|
execution_count: int = 0
|
||||||
# Whether the user can access the underlying graph
|
success_rate: float | None = None
|
||||||
|
avg_correctness_score: float | None = None
|
||||||
|
recent_executions: list[RecentExecution] = pydantic.Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="List of recent executions with status, score, and summary",
|
||||||
|
)
|
||||||
can_access_graph: bool
|
can_access_graph: bool
|
||||||
|
|
||||||
# Indicates if this agent is the latest version
|
|
||||||
is_latest_version: bool
|
is_latest_version: bool
|
||||||
|
|
||||||
# Whether the agent is marked as favorite by the user
|
|
||||||
is_favorite: bool
|
is_favorite: bool
|
||||||
|
|
||||||
# Recommended schedule cron (from marketplace agents)
|
|
||||||
recommended_schedule_cron: str | None = None
|
recommended_schedule_cron: str | None = None
|
||||||
|
|
||||||
# User-specific settings for this library agent
|
|
||||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||||
|
|
||||||
# Marketplace listing information if the agent has been published
|
|
||||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -123,7 +142,6 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
agent_updated_at = agent.AgentGraph.updatedAt
|
agent_updated_at = agent.AgentGraph.updatedAt
|
||||||
lib_agent_updated_at = agent.updatedAt
|
lib_agent_updated_at = agent.updatedAt
|
||||||
|
|
||||||
# Compute updated_at as the latest between library agent and graph
|
|
||||||
updated_at = (
|
updated_at = (
|
||||||
max(agent_updated_at, lib_agent_updated_at)
|
max(agent_updated_at, lib_agent_updated_at)
|
||||||
if agent_updated_at
|
if agent_updated_at
|
||||||
@@ -136,7 +154,6 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
creator_name = agent.Creator.name or "Unknown"
|
creator_name = agent.Creator.name or "Unknown"
|
||||||
creator_image_url = agent.Creator.avatarUrl or ""
|
creator_image_url = agent.Creator.avatarUrl or ""
|
||||||
|
|
||||||
# Logic to calculate status and new_output
|
|
||||||
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||||
days=7
|
days=7
|
||||||
)
|
)
|
||||||
@@ -145,13 +162,55 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
status = status_result.status
|
status = status_result.status
|
||||||
new_output = status_result.new_output
|
new_output = status_result.new_output
|
||||||
|
|
||||||
# Check if user can access the graph
|
execution_count = len(executions)
|
||||||
can_access_graph = agent.AgentGraph.userId == agent.userId
|
success_rate: float | None = None
|
||||||
|
avg_correctness_score: float | None = None
|
||||||
|
if execution_count > 0:
|
||||||
|
success_count = sum(
|
||||||
|
1
|
||||||
|
for e in executions
|
||||||
|
if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
|
||||||
|
)
|
||||||
|
success_rate = (success_count / execution_count) * 100
|
||||||
|
|
||||||
# Hard-coded to True until a method to check is implemented
|
correctness_scores = []
|
||||||
|
for e in executions:
|
||||||
|
if e.stats and isinstance(e.stats, dict):
|
||||||
|
score = e.stats.get("correctness_score")
|
||||||
|
if score is not None and isinstance(score, (int, float)):
|
||||||
|
correctness_scores.append(float(score))
|
||||||
|
if correctness_scores:
|
||||||
|
avg_correctness_score = sum(correctness_scores) / len(
|
||||||
|
correctness_scores
|
||||||
|
)
|
||||||
|
|
||||||
|
recent_executions: list[RecentExecution] = []
|
||||||
|
for e in executions:
|
||||||
|
exec_score: float | None = None
|
||||||
|
exec_summary: str | None = None
|
||||||
|
if e.stats and isinstance(e.stats, dict):
|
||||||
|
score = e.stats.get("correctness_score")
|
||||||
|
if score is not None and isinstance(score, (int, float)):
|
||||||
|
exec_score = float(score)
|
||||||
|
summary = e.stats.get("activity_status")
|
||||||
|
if summary is not None and isinstance(summary, str):
|
||||||
|
exec_summary = summary
|
||||||
|
exec_status = (
|
||||||
|
e.executionStatus.value
|
||||||
|
if hasattr(e.executionStatus, "value")
|
||||||
|
else str(e.executionStatus)
|
||||||
|
)
|
||||||
|
recent_executions.append(
|
||||||
|
RecentExecution(
|
||||||
|
status=exec_status,
|
||||||
|
correctness_score=exec_score,
|
||||||
|
activity_summary=exec_summary,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
can_access_graph = agent.AgentGraph.userId == agent.userId
|
||||||
is_latest_version = True
|
is_latest_version = True
|
||||||
|
|
||||||
# Build marketplace_listing if available
|
|
||||||
marketplace_listing_data = None
|
marketplace_listing_data = None
|
||||||
if store_listing and store_listing.ActiveVersion and profile:
|
if store_listing and store_listing.ActiveVersion and profile:
|
||||||
creator_data = MarketplaceListingCreator(
|
creator_data = MarketplaceListingCreator(
|
||||||
@@ -190,11 +249,15 @@ class LibraryAgent(pydantic.BaseModel):
|
|||||||
has_sensitive_action=graph.has_sensitive_action,
|
has_sensitive_action=graph.has_sensitive_action,
|
||||||
trigger_setup_info=graph.trigger_setup_info,
|
trigger_setup_info=graph.trigger_setup_info,
|
||||||
new_output=new_output,
|
new_output=new_output,
|
||||||
|
execution_count=execution_count,
|
||||||
|
success_rate=success_rate,
|
||||||
|
avg_correctness_score=avg_correctness_score,
|
||||||
|
recent_executions=recent_executions,
|
||||||
can_access_graph=can_access_graph,
|
can_access_graph=can_access_graph,
|
||||||
is_latest_version=is_latest_version,
|
is_latest_version=is_latest_version,
|
||||||
is_favorite=agent.isFavorite,
|
is_favorite=agent.isFavorite,
|
||||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||||
settings=GraphSettings.model_validate(agent.settings),
|
settings=_parse_settings(agent.settings),
|
||||||
marketplace_listing=marketplace_listing_data,
|
marketplace_listing=marketplace_listing_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -220,18 +283,15 @@ def _calculate_agent_status(
|
|||||||
if not executions:
|
if not executions:
|
||||||
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
|
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
|
||||||
|
|
||||||
# Track how many times each execution status appears
|
|
||||||
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
|
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
|
||||||
new_output = False
|
new_output = False
|
||||||
|
|
||||||
for execution in executions:
|
for execution in executions:
|
||||||
# Check if there's a completed run more recent than `recent_threshold`
|
|
||||||
if execution.createdAt >= recent_threshold:
|
if execution.createdAt >= recent_threshold:
|
||||||
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
|
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
|
||||||
new_output = True
|
new_output = True
|
||||||
status_counts[execution.executionStatus] += 1
|
status_counts[execution.executionStatus] += 1
|
||||||
|
|
||||||
# Determine the final status based on counts
|
|
||||||
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
|
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
|
||||||
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
|
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
|
||||||
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:
|
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import autogpt_libs.auth as autogpt_auth_lib
|
import autogpt_libs.auth as autogpt_auth_lib
|
||||||
@@ -6,15 +5,11 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
|||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from prisma.enums import OnboardingStep
|
from prisma.enums import OnboardingStep
|
||||||
|
|
||||||
import backend.api.features.store.exceptions as store_exceptions
|
|
||||||
from backend.data.onboarding import complete_onboarding_step
|
from backend.data.onboarding import complete_onboarding_step
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
|
||||||
|
|
||||||
from .. import db as library_db
|
from .. import db as library_db
|
||||||
from .. import model as library_model
|
from .. import model as library_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/agents",
|
prefix="/agents",
|
||||||
tags=["library", "private"],
|
tags=["library", "private"],
|
||||||
@@ -26,10 +21,6 @@ router = APIRouter(
|
|||||||
"",
|
"",
|
||||||
summary="List Library Agents",
|
summary="List Library Agents",
|
||||||
response_model=library_model.LibraryAgentResponse,
|
response_model=library_model.LibraryAgentResponse,
|
||||||
responses={
|
|
||||||
200: {"description": "List of library agents"},
|
|
||||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def list_library_agents(
|
async def list_library_agents(
|
||||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
@@ -53,22 +44,7 @@ async def list_library_agents(
|
|||||||
) -> library_model.LibraryAgentResponse:
|
) -> library_model.LibraryAgentResponse:
|
||||||
"""
|
"""
|
||||||
Get all agents in the user's library (both created and saved).
|
Get all agents in the user's library (both created and saved).
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
search_term: Optional search term to filter agents by name/description.
|
|
||||||
filter_by: List of filters to apply (favorites, created by user).
|
|
||||||
sort_by: List of sorting criteria (created date, updated date).
|
|
||||||
page: Page number to retrieve.
|
|
||||||
page_size: Number of agents per page.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A LibraryAgentResponse containing agents and pagination metadata.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return await library_db.list_library_agents(
|
return await library_db.list_library_agents(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_term=search_term,
|
search_term=search_term,
|
||||||
@@ -76,20 +52,11 @@ async def list_library_agents(
|
|||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Could not list library agents for user #{user_id}: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/favorites",
|
"/favorites",
|
||||||
summary="List Favorite Library Agents",
|
summary="List Favorite Library Agents",
|
||||||
responses={
|
|
||||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def list_favorite_library_agents(
|
async def list_favorite_library_agents(
|
||||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||||
@@ -106,30 +73,12 @@ async def list_favorite_library_agents(
|
|||||||
) -> library_model.LibraryAgentResponse:
|
) -> library_model.LibraryAgentResponse:
|
||||||
"""
|
"""
|
||||||
Get all favorite agents in the user's library.
|
Get all favorite agents in the user's library.
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
page: Page number to retrieve.
|
|
||||||
page_size: Number of agents per page.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A LibraryAgentResponse containing favorite agents and pagination metadata.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return await library_db.list_favorite_library_agents(
|
return await library_db.list_favorite_library_agents(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
||||||
@@ -162,10 +111,6 @@ async def get_library_agent_by_graph_id(
|
|||||||
summary="Get Agent By Store ID",
|
summary="Get Agent By Store ID",
|
||||||
tags=["store", "library"],
|
tags=["store", "library"],
|
||||||
response_model=library_model.LibraryAgent | None,
|
response_model=library_model.LibraryAgent | None,
|
||||||
responses={
|
|
||||||
200: {"description": "Library agent found"},
|
|
||||||
404: {"description": "Agent not found"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def get_library_agent_by_store_listing_version_id(
|
async def get_library_agent_by_store_listing_version_id(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
@@ -174,32 +119,15 @@ async def get_library_agent_by_store_listing_version_id(
|
|||||||
"""
|
"""
|
||||||
Get Library Agent from Store Listing Version ID.
|
Get Library Agent from Store Listing Version ID.
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return await library_db.get_library_agent_by_store_version_id(
|
return await library_db.get_library_agent_by_store_version_id(
|
||||||
store_listing_version_id, user_id
|
store_listing_version_id, user_id
|
||||||
)
|
)
|
||||||
except NotFoundError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=str(e),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Could not fetch library agent from store version ID: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"",
|
"",
|
||||||
summary="Add Marketplace Agent",
|
summary="Add Marketplace Agent",
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
responses={
|
|
||||||
201: {"description": "Agent added successfully"},
|
|
||||||
404: {"description": "Store listing version not found"},
|
|
||||||
500: {"description": "Server error"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def add_marketplace_agent_to_library(
|
async def add_marketplace_agent_to_library(
|
||||||
store_listing_version_id: str = Body(embed=True),
|
store_listing_version_id: str = Body(embed=True),
|
||||||
@@ -210,59 +138,19 @@ async def add_marketplace_agent_to_library(
|
|||||||
) -> library_model.LibraryAgent:
|
) -> library_model.LibraryAgent:
|
||||||
"""
|
"""
|
||||||
Add an agent from the marketplace to the user's library.
|
Add an agent from the marketplace to the user's library.
|
||||||
|
|
||||||
Args:
|
|
||||||
store_listing_version_id: ID of the store listing version to add.
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
library_model.LibraryAgent: Agent added to the library
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException(404): If the listing version is not found.
|
|
||||||
HTTPException(500): If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
agent = await library_db.add_store_agent_to_library(
|
agent = await library_db.add_store_agent_to_library(
|
||||||
store_listing_version_id=store_listing_version_id,
|
store_listing_version_id=store_listing_version_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if source != "onboarding":
|
if source != "onboarding":
|
||||||
await complete_onboarding_step(
|
await complete_onboarding_step(user_id, OnboardingStep.MARKETPLACE_ADD_AGENT)
|
||||||
user_id, OnboardingStep.MARKETPLACE_ADD_AGENT
|
|
||||||
)
|
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
except store_exceptions.AgentNotFoundError as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Could not find store listing version {store_listing_version_id} "
|
|
||||||
"to add to library"
|
|
||||||
)
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
|
||||||
except DatabaseError as e:
|
|
||||||
logger.error(f"Database error while adding agent to library: {e}", e)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail={"message": str(e), "hint": "Inspect DB logs for details."},
|
|
||||||
) from e
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error while adding agent to library: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail={
|
|
||||||
"message": str(e),
|
|
||||||
"hint": "Check server logs for more information.",
|
|
||||||
},
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
"/{library_agent_id}",
|
"/{library_agent_id}",
|
||||||
summary="Update Library Agent",
|
summary="Update Library Agent",
|
||||||
responses={
|
|
||||||
200: {"description": "Agent updated successfully"},
|
|
||||||
500: {"description": "Server error"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def update_library_agent(
|
async def update_library_agent(
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
@@ -271,16 +159,7 @@ async def update_library_agent(
|
|||||||
) -> library_model.LibraryAgent:
|
) -> library_model.LibraryAgent:
|
||||||
"""
|
"""
|
||||||
Update the library agent with the given fields.
|
Update the library agent with the given fields.
|
||||||
|
|
||||||
Args:
|
|
||||||
library_agent_id: ID of the library agent to update.
|
|
||||||
payload: Fields to update (auto_update_version, is_favorite, etc.).
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException(500): If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
return await library_db.update_library_agent(
|
return await library_db.update_library_agent(
|
||||||
library_agent_id=library_agent_id,
|
library_agent_id=library_agent_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -290,33 +169,11 @@ async def update_library_agent(
|
|||||||
is_archived=payload.is_archived,
|
is_archived=payload.is_archived,
|
||||||
settings=payload.settings,
|
settings=payload.settings,
|
||||||
)
|
)
|
||||||
except NotFoundError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
except DatabaseError as e:
|
|
||||||
logger.error(f"Database error while updating library agent: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail={"message": str(e), "hint": "Verify DB connection."},
|
|
||||||
) from e
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error while updating library agent: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail={"message": str(e), "hint": "Check server logs."},
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/{library_agent_id}",
|
"/{library_agent_id}",
|
||||||
summary="Delete Library Agent",
|
summary="Delete Library Agent",
|
||||||
responses={
|
|
||||||
204: {"description": "Agent deleted successfully"},
|
|
||||||
404: {"description": "Agent not found"},
|
|
||||||
500: {"description": "Server error"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
async def delete_library_agent(
|
async def delete_library_agent(
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
@@ -324,28 +181,11 @@ async def delete_library_agent(
|
|||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Soft-delete the specified library agent.
|
Soft-delete the specified library agent.
|
||||||
|
|
||||||
Args:
|
|
||||||
library_agent_id: ID of the library agent to delete.
|
|
||||||
user_id: ID of the authenticated user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
204 No Content if successful.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException(404): If the agent does not exist.
|
|
||||||
HTTPException(500): If a server/database error occurs.
|
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
await library_db.delete_library_agent(
|
await library_db.delete_library_agent(
|
||||||
library_agent_id=library_agent_id, user_id=user_id
|
library_agent_id=library_agent_id, user_id=user_id
|
||||||
)
|
)
|
||||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||||
except NotFoundError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=str(e),
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
|
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
|
||||||
|
|||||||
@@ -118,21 +118,6 @@ async def test_get_library_agents_success(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
|
||||||
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
|
|
||||||
mock_db_call.side_effect = Exception("Test error")
|
|
||||||
|
|
||||||
response = client.get("/agents?search_term=test")
|
|
||||||
assert response.status_code == 500
|
|
||||||
mock_db_call.assert_called_once_with(
|
|
||||||
user_id=test_user_id,
|
|
||||||
search_term="test",
|
|
||||||
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
|
|
||||||
page=1,
|
|
||||||
page_size=15,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_favorite_library_agents_success(
|
async def test_get_favorite_library_agents_success(
|
||||||
mocker: pytest_mock.MockFixture,
|
mocker: pytest_mock.MockFixture,
|
||||||
@@ -190,23 +175,6 @@ async def test_get_favorite_library_agents_success(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_favorite_library_agents_error(
|
|
||||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
|
||||||
):
|
|
||||||
mock_db_call = mocker.patch(
|
|
||||||
"backend.api.features.library.db.list_favorite_library_agents"
|
|
||||||
)
|
|
||||||
mock_db_call.side_effect = Exception("Test error")
|
|
||||||
|
|
||||||
response = client.get("/agents/favorites")
|
|
||||||
assert response.status_code == 500
|
|
||||||
mock_db_call.assert_called_once_with(
|
|
||||||
user_id=test_user_id,
|
|
||||||
page=1,
|
|
||||||
page_size=15,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_agent_to_library_success(
|
def test_add_agent_to_library_success(
|
||||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||||
):
|
):
|
||||||
@@ -258,19 +226,3 @@ def test_add_agent_to_library_success(
|
|||||||
store_listing_version_id="test-version-id", user_id=test_user_id
|
store_listing_version_id="test-version-id", user_id=test_user_id
|
||||||
)
|
)
|
||||||
mock_complete_onboarding.assert_awaited_once()
|
mock_complete_onboarding.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
|
|
||||||
mock_db_call = mocker.patch(
|
|
||||||
"backend.api.features.library.db.add_store_agent_to_library"
|
|
||||||
)
|
|
||||||
mock_db_call.side_effect = Exception("Test error")
|
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 500
|
|
||||||
assert "detail" in response.json() # Verify error response structure
|
|
||||||
mock_db_call.assert_called_once_with(
|
|
||||||
store_listing_version_id="test-version-id", user_id=test_user_id
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ async def get_store_agents(
|
|||||||
description=agent["description"],
|
description=agent["description"],
|
||||||
runs=agent["runs"],
|
runs=agent["runs"],
|
||||||
rating=agent["rating"],
|
rating=agent["rating"],
|
||||||
|
agent_graph_id=agent.get("agentGraphId", ""),
|
||||||
)
|
)
|
||||||
store_agents.append(store_agent)
|
store_agents.append(store_agent)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -170,6 +171,7 @@ async def get_store_agents(
|
|||||||
description=agent.description,
|
description=agent.description,
|
||||||
runs=agent.runs,
|
runs=agent.runs,
|
||||||
rating=agent.rating,
|
rating=agent.rating,
|
||||||
|
agent_graph_id=agent.agentGraphId,
|
||||||
)
|
)
|
||||||
# Add to the list only if creation was successful
|
# Add to the list only if creation was successful
|
||||||
store_agents.append(store_agent)
|
store_agents.append(store_agent)
|
||||||
|
|||||||
@@ -454,6 +454,7 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
|||||||
total_processed = 0
|
total_processed = 0
|
||||||
total_success = 0
|
total_success = 0
|
||||||
total_failed = 0
|
total_failed = 0
|
||||||
|
all_errors: dict[str, int] = {} # Aggregate errors across all content types
|
||||||
|
|
||||||
# Process content types in explicit order
|
# Process content types in explicit order
|
||||||
processing_order = [
|
processing_order = [
|
||||||
@@ -499,23 +500,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
|||||||
success = sum(1 for result in results if result is True)
|
success = sum(1 for result in results if result is True)
|
||||||
failed = len(results) - success
|
failed = len(results) - success
|
||||||
|
|
||||||
# Aggregate unique errors to avoid Sentry spam
|
# Aggregate errors across all content types
|
||||||
if failed > 0:
|
if failed > 0:
|
||||||
# Group errors by type and message
|
|
||||||
error_summary: dict[str, int] = {}
|
|
||||||
for result in results:
|
for result in results:
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
error_key = f"{type(result).__name__}: {str(result)}"
|
error_key = f"{type(result).__name__}: {str(result)}"
|
||||||
error_summary[error_key] = error_summary.get(error_key, 0) + 1
|
all_errors[error_key] = all_errors.get(error_key, 0) + 1
|
||||||
|
|
||||||
# Log aggregated error summary
|
|
||||||
error_details = ", ".join(
|
|
||||||
f"{error} ({count}x)" for error, count in error_summary.items()
|
|
||||||
)
|
|
||||||
logger.error(
|
|
||||||
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
|
|
||||||
f"Errors: {error_details}"
|
|
||||||
)
|
|
||||||
|
|
||||||
results_by_type[content_type.value] = {
|
results_by_type[content_type.value] = {
|
||||||
"processed": len(missing_items),
|
"processed": len(missing_items),
|
||||||
@@ -542,6 +532,13 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
|||||||
"error": str(e),
|
"error": str(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Log aggregated errors once at the end
|
||||||
|
if all_errors:
|
||||||
|
error_details = ", ".join(
|
||||||
|
f"{error} ({count}x)" for error, count in all_errors.items()
|
||||||
|
)
|
||||||
|
logger.error(f"Embedding backfill errors: {error_details}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"by_type": results_by_type,
|
"by_type": results_by_type,
|
||||||
"totals": {
|
"totals": {
|
||||||
|
|||||||
@@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination(
|
|||||||
cleanup_embeddings: list,
|
cleanup_embeddings: list,
|
||||||
):
|
):
|
||||||
"""Test unified search pagination works correctly."""
|
"""Test unified search pagination works correctly."""
|
||||||
|
# Use a unique search term to avoid matching other test data
|
||||||
|
unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
# Create multiple items
|
# Create multiple items
|
||||||
content_ids = []
|
content_ids = []
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
@@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination(
|
|||||||
content_type=ContentType.BLOCK,
|
content_type=ContentType.BLOCK,
|
||||||
content_id=content_id,
|
content_id=content_id,
|
||||||
embedding=mock_embedding,
|
embedding=mock_embedding,
|
||||||
searchable_text=f"pagination test item number {i}",
|
searchable_text=f"{unique_term} item number {i}",
|
||||||
metadata={"index": i},
|
metadata={"index": i},
|
||||||
user_id=None,
|
user_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get first page
|
# Get first page
|
||||||
page1_results, total1 = await unified_hybrid_search(
|
page1_results, total1 = await unified_hybrid_search(
|
||||||
query="pagination test",
|
query=unique_term,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
page_size=2,
|
page_size=2,
|
||||||
@@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination(
|
|||||||
|
|
||||||
# Get second page
|
# Get second page
|
||||||
page2_results, total2 = await unified_hybrid_search(
|
page2_results, total2 = await unified_hybrid_search(
|
||||||
query="pagination test",
|
query=unique_term,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=2,
|
page=2,
|
||||||
page_size=2,
|
page_size=2,
|
||||||
|
|||||||
@@ -600,6 +600,7 @@ async def hybrid_search(
|
|||||||
sa.featured,
|
sa.featured,
|
||||||
sa.is_available,
|
sa.is_available,
|
||||||
sa.updated_at,
|
sa.updated_at,
|
||||||
|
sa."agentGraphId",
|
||||||
-- Searchable text for BM25 reranking
|
-- Searchable text for BM25 reranking
|
||||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||||
-- Semantic score
|
-- Semantic score
|
||||||
@@ -659,6 +660,7 @@ async def hybrid_search(
|
|||||||
featured,
|
featured,
|
||||||
is_available,
|
is_available,
|
||||||
updated_at,
|
updated_at,
|
||||||
|
"agentGraphId",
|
||||||
searchable_text,
|
searchable_text,
|
||||||
semantic_score,
|
semantic_score,
|
||||||
lexical_score,
|
lexical_score,
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class StoreAgent(pydantic.BaseModel):
|
|||||||
description: str
|
description: str
|
||||||
runs: int
|
runs: int
|
||||||
rating: float
|
rating: float
|
||||||
|
agent_graph_id: str
|
||||||
|
|
||||||
|
|
||||||
class StoreAgentsResponse(pydantic.BaseModel):
|
class StoreAgentsResponse(pydantic.BaseModel):
|
||||||
|
|||||||
@@ -26,11 +26,13 @@ def test_store_agent():
|
|||||||
description="Test description",
|
description="Test description",
|
||||||
runs=50,
|
runs=50,
|
||||||
rating=4.5,
|
rating=4.5,
|
||||||
|
agent_graph_id="test-graph-id",
|
||||||
)
|
)
|
||||||
assert agent.slug == "test-agent"
|
assert agent.slug == "test-agent"
|
||||||
assert agent.agent_name == "Test Agent"
|
assert agent.agent_name == "Test Agent"
|
||||||
assert agent.runs == 50
|
assert agent.runs == 50
|
||||||
assert agent.rating == 4.5
|
assert agent.rating == 4.5
|
||||||
|
assert agent.agent_graph_id == "test-graph-id"
|
||||||
|
|
||||||
|
|
||||||
def test_store_agents_response():
|
def test_store_agents_response():
|
||||||
@@ -46,6 +48,7 @@ def test_store_agents_response():
|
|||||||
description="Test description",
|
description="Test description",
|
||||||
runs=50,
|
runs=50,
|
||||||
rating=4.5,
|
rating=4.5,
|
||||||
|
agent_graph_id="test-graph-id",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
pagination=store_model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ def test_get_agents_featured(
|
|||||||
description="Featured agent description",
|
description="Featured agent description",
|
||||||
runs=100,
|
runs=100,
|
||||||
rating=4.5,
|
rating=4.5,
|
||||||
|
agent_graph_id="test-graph-1",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
pagination=store_model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
@@ -127,6 +128,7 @@ def test_get_agents_by_creator(
|
|||||||
description="Creator agent description",
|
description="Creator agent description",
|
||||||
runs=50,
|
runs=50,
|
||||||
rating=4.0,
|
rating=4.0,
|
||||||
|
agent_graph_id="test-graph-2",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
pagination=store_model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
@@ -172,6 +174,7 @@ def test_get_agents_sorted(
|
|||||||
description="Top agent description",
|
description="Top agent description",
|
||||||
runs=1000,
|
runs=1000,
|
||||||
rating=5.0,
|
rating=5.0,
|
||||||
|
agent_graph_id="test-graph-3",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
pagination=store_model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
@@ -217,6 +220,7 @@ def test_get_agents_search(
|
|||||||
description="Specific search term description",
|
description="Specific search term description",
|
||||||
runs=75,
|
runs=75,
|
||||||
rating=4.2,
|
rating=4.2,
|
||||||
|
agent_graph_id="test-graph-search",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
pagination=store_model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
@@ -262,6 +266,7 @@ def test_get_agents_category(
|
|||||||
description="Category agent description",
|
description="Category agent description",
|
||||||
runs=60,
|
runs=60,
|
||||||
rating=4.1,
|
rating=4.1,
|
||||||
|
agent_graph_id="test-graph-category",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
pagination=store_model.Pagination(
|
pagination=store_model.Pagination(
|
||||||
@@ -306,6 +311,7 @@ def test_get_agents_pagination(
|
|||||||
description=f"Agent {i} description",
|
description=f"Agent {i} description",
|
||||||
runs=i * 10,
|
runs=i * 10,
|
||||||
rating=4.0,
|
rating=4.0,
|
||||||
|
agent_graph_id="test-graph-2",
|
||||||
)
|
)
|
||||||
for i in range(5)
|
for i in range(5)
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class TestCacheDeletion:
|
|||||||
description="Test description",
|
description="Test description",
|
||||||
runs=100,
|
runs=100,
|
||||||
rating=4.5,
|
rating=4.5,
|
||||||
|
agent_graph_id="test-graph-id",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
pagination=Pagination(
|
pagination=Pagination(
|
||||||
|
|||||||
@@ -101,7 +101,6 @@ from backend.util.timezone_utils import (
|
|||||||
from backend.util.virus_scanner import scan_content_safe
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
|
||||||
from .library import db as library_db
|
from .library import db as library_db
|
||||||
from .library import model as library_model
|
|
||||||
from .store.model import StoreAgentDetails
|
from .store.model import StoreAgentDetails
|
||||||
|
|
||||||
|
|
||||||
@@ -261,18 +260,36 @@ async def get_onboarding_agents(
|
|||||||
return await get_recommended_agents(user_id)
|
return await get_recommended_agents(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingStatusResponse(pydantic.BaseModel):
|
||||||
|
"""Response for onboarding status check."""
|
||||||
|
|
||||||
|
is_onboarding_enabled: bool
|
||||||
|
is_chat_enabled: bool
|
||||||
|
|
||||||
|
|
||||||
@v1_router.get(
|
@v1_router.get(
|
||||||
"/onboarding/enabled",
|
"/onboarding/enabled",
|
||||||
summary="Is onboarding enabled",
|
summary="Is onboarding enabled",
|
||||||
tags=["onboarding", "public"],
|
tags=["onboarding", "public"],
|
||||||
|
response_model=OnboardingStatusResponse,
|
||||||
)
|
)
|
||||||
async def is_onboarding_enabled(
|
async def is_onboarding_enabled(
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
) -> bool:
|
) -> OnboardingStatusResponse:
|
||||||
# If chat is enabled for user, skip legacy onboarding
|
# Check if chat is enabled for user
|
||||||
if await is_feature_enabled(Flag.CHAT, user_id, False):
|
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
|
||||||
return False
|
|
||||||
return await onboarding_enabled()
|
# If chat is enabled, skip legacy onboarding
|
||||||
|
if is_chat_enabled:
|
||||||
|
return OnboardingStatusResponse(
|
||||||
|
is_onboarding_enabled=False,
|
||||||
|
is_chat_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OnboardingStatusResponse(
|
||||||
|
is_onboarding_enabled=await onboarding_enabled(),
|
||||||
|
is_chat_enabled=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@v1_router.post(
|
@v1_router.post(
|
||||||
@@ -805,18 +822,16 @@ async def update_graph(
|
|||||||
graph: graph_db.Graph,
|
graph: graph_db.Graph,
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
) -> graph_db.GraphModel:
|
) -> graph_db.GraphModel:
|
||||||
# Sanity check
|
|
||||||
if graph.id and graph.id != graph_id:
|
if graph.id and graph.id != graph_id:
|
||||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||||
|
|
||||||
# Determine new version
|
|
||||||
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||||
if not existing_versions:
|
if not existing_versions:
|
||||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||||
latest_version_number = max(g.version for g in existing_versions)
|
|
||||||
graph.version = latest_version_number + 1
|
|
||||||
|
|
||||||
|
graph.version = max(g.version for g in existing_versions) + 1
|
||||||
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
||||||
|
|
||||||
graph = graph_db.make_graph_model(graph, user_id)
|
graph = graph_db.make_graph_model(graph, user_id)
|
||||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||||
graph.validate_graph(for_run=False)
|
graph.validate_graph(for_run=False)
|
||||||
@@ -824,27 +839,23 @@ async def update_graph(
|
|||||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||||
|
|
||||||
if new_graph_version.is_active:
|
if new_graph_version.is_active:
|
||||||
# Keep the library agent up to date with the new active version
|
await library_db.update_library_agent_version_and_settings(
|
||||||
await _update_library_agent_version_and_settings(user_id, new_graph_version)
|
user_id, new_graph_version
|
||||||
|
)
|
||||||
# Handle activation of the new graph first to ensure continuity
|
|
||||||
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
||||||
# Ensure new version is the only active version
|
|
||||||
await graph_db.set_graph_active_version(
|
await graph_db.set_graph_active_version(
|
||||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||||
)
|
)
|
||||||
if current_active_version:
|
if current_active_version:
|
||||||
# Handle deactivation of the previously active version
|
|
||||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||||
|
|
||||||
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
|
|
||||||
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
||||||
graph_id,
|
graph_id,
|
||||||
new_graph_version.version,
|
new_graph_version.version,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
include_subgraphs=True,
|
include_subgraphs=True,
|
||||||
)
|
)
|
||||||
assert new_graph_version_with_subgraphs # make type checker happy
|
assert new_graph_version_with_subgraphs
|
||||||
return new_graph_version_with_subgraphs
|
return new_graph_version_with_subgraphs
|
||||||
|
|
||||||
|
|
||||||
@@ -882,33 +893,15 @@ async def set_graph_active_version(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Keep the library agent up to date with the new active version
|
# Keep the library agent up to date with the new active version
|
||||||
await _update_library_agent_version_and_settings(user_id, new_active_graph)
|
await library_db.update_library_agent_version_and_settings(
|
||||||
|
user_id, new_active_graph
|
||||||
|
)
|
||||||
|
|
||||||
if current_active_graph and current_active_graph.version != new_active_version:
|
if current_active_graph and current_active_graph.version != new_active_version:
|
||||||
# Handle deactivation of the previously active version
|
# Handle deactivation of the previously active version
|
||||||
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
async def _update_library_agent_version_and_settings(
|
|
||||||
user_id: str, agent_graph: graph_db.GraphModel
|
|
||||||
) -> library_model.LibraryAgent:
|
|
||||||
library = await library_db.update_agent_version_in_library(
|
|
||||||
user_id, agent_graph.id, agent_graph.version
|
|
||||||
)
|
|
||||||
updated_settings = GraphSettings.from_graph(
|
|
||||||
graph=agent_graph,
|
|
||||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
|
||||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
|
||||||
)
|
|
||||||
if updated_settings != library.settings:
|
|
||||||
library = await library_db.update_library_agent(
|
|
||||||
library_agent_id=library.id,
|
|
||||||
user_id=user_id,
|
|
||||||
settings=updated_settings,
|
|
||||||
)
|
|
||||||
return library
|
|
||||||
|
|
||||||
|
|
||||||
@v1_router.patch(
|
@v1_router.patch(
|
||||||
path="/graphs/{graph_id}/settings",
|
path="/graphs/{graph_id}/settings",
|
||||||
summary="Update graph settings",
|
summary="Update graph settings",
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
# Workspace API feature module
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
"""
|
||||||
|
Workspace API routes for managing user file storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Annotated
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from backend.data.workspace import get_workspace, get_workspace_file
|
||||||
|
from backend.util.workspace_storage import get_workspace_storage
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_filename_for_header(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize filename for Content-Disposition header to prevent header injection.
|
||||||
|
|
||||||
|
Removes/replaces characters that could break the header or inject new headers.
|
||||||
|
Uses RFC5987 encoding for non-ASCII characters.
|
||||||
|
"""
|
||||||
|
# Remove CR, LF, and null bytes (header injection prevention)
|
||||||
|
sanitized = re.sub(r"[\r\n\x00]", "", filename)
|
||||||
|
# Escape quotes
|
||||||
|
sanitized = sanitized.replace('"', '\\"')
|
||||||
|
# For non-ASCII, use RFC5987 filename* parameter
|
||||||
|
# Check if filename has non-ASCII characters
|
||||||
|
try:
|
||||||
|
sanitized.encode("ascii")
|
||||||
|
return f'attachment; filename="{sanitized}"'
|
||||||
|
except UnicodeEncodeError:
|
||||||
|
# Use RFC5987 encoding for UTF-8 filenames
|
||||||
|
encoded = quote(sanitized, safe="")
|
||||||
|
return f"attachment; filename*=UTF-8''{encoded}"
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = fastapi.APIRouter(
|
||||||
|
dependencies=[fastapi.Security(requires_user)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_streaming_response(content: bytes, file) -> Response:
|
||||||
|
"""Create a streaming response for file content."""
|
||||||
|
return Response(
|
||||||
|
content=content,
|
||||||
|
media_type=file.mimeType,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||||
|
"Content-Length": str(len(content)),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_file_download_response(file) -> Response:
|
||||||
|
"""
|
||||||
|
Create a download response for a workspace file.
|
||||||
|
|
||||||
|
Handles both local storage (direct streaming) and GCS (signed URL redirect
|
||||||
|
with fallback to streaming).
|
||||||
|
"""
|
||||||
|
storage = await get_workspace_storage()
|
||||||
|
|
||||||
|
# For local storage, stream the file directly
|
||||||
|
if file.storagePath.startswith("local://"):
|
||||||
|
content = await storage.retrieve(file.storagePath)
|
||||||
|
return _create_streaming_response(content, file)
|
||||||
|
|
||||||
|
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||||
|
try:
|
||||||
|
url = await storage.get_download_url(file.storagePath, expires_in=300)
|
||||||
|
# If we got back an API path (fallback), stream directly instead
|
||||||
|
if url.startswith("/api/"):
|
||||||
|
content = await storage.retrieve(file.storagePath)
|
||||||
|
return _create_streaming_response(content, file)
|
||||||
|
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||||
|
except Exception as e:
|
||||||
|
# Log the signed URL failure with context
|
||||||
|
logger.error(
|
||||||
|
f"Failed to get signed URL for file {file.id} "
|
||||||
|
f"(storagePath={file.storagePath}): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# Fall back to streaming directly from GCS
|
||||||
|
try:
|
||||||
|
content = await storage.retrieve(file.storagePath)
|
||||||
|
return _create_streaming_response(content, file)
|
||||||
|
except Exception as fallback_error:
|
||||||
|
logger.error(
|
||||||
|
f"Fallback streaming also failed for file {file.id} "
|
||||||
|
f"(storagePath={file.storagePath}): {fallback_error}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/files/{file_id}/download",
|
||||||
|
summary="Download file by ID",
|
||||||
|
)
|
||||||
|
async def download_file(
|
||||||
|
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||||
|
file_id: str,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
Download a file by its ID.
|
||||||
|
|
||||||
|
Returns the file content directly or redirects to a signed URL for GCS.
|
||||||
|
"""
|
||||||
|
workspace = await get_workspace(user_id)
|
||||||
|
if workspace is None:
|
||||||
|
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||||
|
|
||||||
|
file = await get_workspace_file(file_id, workspace.id)
|
||||||
|
if file is None:
|
||||||
|
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
return await _create_file_download_response(file)
|
||||||
@@ -32,6 +32,7 @@ import backend.api.features.postmark.postmark
|
|||||||
import backend.api.features.store.model
|
import backend.api.features.store.model
|
||||||
import backend.api.features.store.routes
|
import backend.api.features.store.routes
|
||||||
import backend.api.features.v1
|
import backend.api.features.v1
|
||||||
|
import backend.api.features.workspace.routes as workspace_routes
|
||||||
import backend.data.block
|
import backend.data.block
|
||||||
import backend.data.db
|
import backend.data.db
|
||||||
import backend.data.graph
|
import backend.data.graph
|
||||||
@@ -39,6 +40,10 @@ import backend.data.user
|
|||||||
import backend.integrations.webhooks.utils
|
import backend.integrations.webhooks.utils
|
||||||
import backend.util.service
|
import backend.util.service
|
||||||
import backend.util.settings
|
import backend.util.settings
|
||||||
|
from backend.api.features.chat.completion_consumer import (
|
||||||
|
start_completion_consumer,
|
||||||
|
stop_completion_consumer,
|
||||||
|
)
|
||||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||||
from backend.data.model import Credentials
|
from backend.data.model import Credentials
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
@@ -52,6 +57,7 @@ from backend.util.exceptions import (
|
|||||||
)
|
)
|
||||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||||
from backend.util.service import UnhealthyServiceError
|
from backend.util.service import UnhealthyServiceError
|
||||||
|
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||||
|
|
||||||
from .external.fastapi_app import external_api
|
from .external.fastapi_app import external_api
|
||||||
from .features.analytics import router as analytics_router
|
from .features.analytics import router as analytics_router
|
||||||
@@ -116,14 +122,31 @@ async def lifespan_context(app: fastapi.FastAPI):
|
|||||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||||
|
|
||||||
|
# Start chat completion consumer for Redis Streams notifications
|
||||||
|
try:
|
||||||
|
await start_completion_consumer()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||||
|
|
||||||
with launch_darkly_context():
|
with launch_darkly_context():
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
# Stop chat completion consumer
|
||||||
|
try:
|
||||||
|
await stop_completion_consumer()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await shutdown_cloud_storage_handler()
|
await shutdown_cloud_storage_handler()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error shutting down cloud storage handler: {e}")
|
logger.warning(f"Error shutting down cloud storage handler: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await shutdown_workspace_storage()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error shutting down workspace storage: {e}")
|
||||||
|
|
||||||
await backend.data.db.disconnect()
|
await backend.data.db.disconnect()
|
||||||
|
|
||||||
|
|
||||||
@@ -315,6 +338,11 @@ app.include_router(
|
|||||||
tags=["v2", "chat"],
|
tags=["v2", "chat"],
|
||||||
prefix="/api/chat",
|
prefix="/api/chat",
|
||||||
)
|
)
|
||||||
|
app.include_router(
|
||||||
|
workspace_routes.router,
|
||||||
|
tags=["workspace"],
|
||||||
|
prefix="/api/workspace",
|
||||||
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.oauth.router,
|
backend.api.features.oauth.router,
|
||||||
tags=["oauth"],
|
tags=["oauth"],
|
||||||
|
|||||||
@@ -66,6 +66,8 @@ async def event_broadcaster(manager: ConnectionManager):
|
|||||||
execution_bus = AsyncRedisExecutionEventBus()
|
execution_bus = AsyncRedisExecutionEventBus()
|
||||||
notification_bus = AsyncRedisNotificationEventBus()
|
notification_bus = AsyncRedisNotificationEventBus()
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
async def execution_worker():
|
async def execution_worker():
|
||||||
async for event in execution_bus.listen("*"):
|
async for event in execution_bus.listen("*"):
|
||||||
await manager.send_execution_update(event)
|
await manager.send_execution_update(event)
|
||||||
@@ -78,6 +80,10 @@ async def event_broadcaster(manager: ConnectionManager):
|
|||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.gather(execution_worker(), notification_worker())
|
await asyncio.gather(execution_worker(), notification_worker())
|
||||||
|
finally:
|
||||||
|
# Ensure PubSub connections are closed on any exit to prevent leaks
|
||||||
|
await execution_bus.close()
|
||||||
|
await notification_bus.close()
|
||||||
|
|
||||||
|
|
||||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block):
|
|||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
test_output=[
|
test_output=[
|
||||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
|
# Use data URI to avoid HTTP requests during tests
|
||||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||||
"https://replicate.delivery/generated-image.jpg"
|
""
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
@@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
@@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block):
|
|||||||
processed_images = await asyncio.gather(
|
processed_images = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
store_media_file(
|
store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=img,
|
file=img,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True,
|
return_format="for_external_api", # Get content for Replicate API
|
||||||
)
|
)
|
||||||
for img in input_data.images
|
for img in input_data.images
|
||||||
)
|
)
|
||||||
@@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block):
|
|||||||
aspect_ratio=input_data.aspect_ratio.value,
|
aspect_ratio=input_data.aspect_ratio.value,
|
||||||
output_format=input_data.output_format.value,
|
output_format=input_data.output_format.value,
|
||||||
)
|
)
|
||||||
yield "image_url", result
|
|
||||||
|
# Store the generated image to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=result,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", stored_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield "error", str(e)
|
yield "error", str(e)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient
|
|||||||
from replicate.helpers import FileOutput
|
from replicate.helpers import FileOutput
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -13,6 +14,8 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.file import store_media_file
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
|
|
||||||
class ImageSize(str, Enum):
|
class ImageSize(str, Enum):
|
||||||
@@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block):
|
|||||||
test_output=[
|
test_output=[
|
||||||
(
|
(
|
||||||
"image_url",
|
"image_url",
|
||||||
"https://replicate.delivery/generated-image.webp",
|
# Test output is a data URI since we now store images
|
||||||
|
lambda x: x.startswith(""
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block):
|
|||||||
style_text = style_map.get(style, "")
|
style_text = style_map.get(style, "")
|
||||||
return f"{style_text} of" if style_text else ""
|
return f"{style_text} of" if style_text else ""
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
url = await self.generate_image(input_data, credentials)
|
url = await self.generate_image(input_data, credentials)
|
||||||
if url:
|
if url:
|
||||||
yield "image_url", url
|
# Store the generated image to the user's workspace/execution folder
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", stored_url
|
||||||
else:
|
else:
|
||||||
yield "error", "Image generation returned an empty result."
|
yield "error", "Image generation returned an empty result."
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -21,7 +22,9 @@ from backend.data.model import (
|
|||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.exceptions import BlockExecutionError
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
@@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
"voice": Voice.LILY,
|
"voice": Voice.LILY,
|
||||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/video.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/video.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create a new Webhook.site URL
|
# Create a new Webhook.site URL
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
@@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block):
|
|||||||
)
|
)
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
logger.debug(f"Video ready: {video_url}")
|
logger.debug(f"Video ready: {video_url}")
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|
||||||
|
|
||||||
class AIAdMakerVideoCreatorBlock(Block):
|
class AIAdMakerVideoCreatorBlock(Block):
|
||||||
@@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/ad.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block):
|
|||||||
raise RuntimeError("Failed to create video: No project ID returned")
|
raise RuntimeError("Failed to create video: No project ID returned")
|
||||||
|
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|
||||||
|
|
||||||
class AIScreenshotToVideoAdBlock(Block):
|
class AIScreenshotToVideoAdBlock(Block):
|
||||||
@@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
"script": "Amazing numbers!",
|
"script": "Amazing numbers!",
|
||||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||||
},
|
},
|
||||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
test_output=(
|
||||||
|
"video_url",
|
||||||
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
|
),
|
||||||
test_mock={
|
test_mock={
|
||||||
"create_webhook": lambda *args, **kwargs: (
|
"create_webhook": lambda *args, **kwargs: (
|
||||||
"test_uuid",
|
"test_uuid",
|
||||||
@@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||||
"check_video_status": lambda *args, **kwargs: {
|
"check_video_status": lambda *args, **kwargs: {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"videoUrl": "https://example.com/screenshot.mp4",
|
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
webhook_token, webhook_url = await self.create_webhook()
|
webhook_token, webhook_url = await self.create_webhook()
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
|||||||
raise RuntimeError("Failed to create video: No project ID returned")
|
raise RuntimeError("Failed to create video: No project ID returned")
|
||||||
|
|
||||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.sdk import (
|
from backend.sdk import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
Block,
|
Block,
|
||||||
@@ -17,6 +18,8 @@ from backend.sdk import (
|
|||||||
Requests,
|
Requests,
|
||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
|
from backend.util.file import store_media_file
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
from ._config import bannerbear
|
from ._config import bannerbear
|
||||||
|
|
||||||
@@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block):
|
|||||||
},
|
},
|
||||||
test_output=[
|
test_output=[
|
||||||
("success", True),
|
("success", True),
|
||||||
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
("uid", "test-uid-123"),
|
("uid", "test-uid-123"),
|
||||||
("status", "completed"),
|
("status", "completed"),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
|
# Use data URI to avoid HTTP requests during tests
|
||||||
"_make_api_request": lambda *args, **kwargs: {
|
"_make_api_request": lambda *args, **kwargs: {
|
||||||
"uid": "test-uid-123",
|
"uid": "test-uid-123",
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
|
"image_url": "",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
@@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block):
|
|||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Build the modifications array
|
# Build the modifications array
|
||||||
modifications = []
|
modifications = []
|
||||||
@@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block):
|
|||||||
|
|
||||||
# Synchronous request - image should be ready
|
# Synchronous request - image should be ready
|
||||||
yield "success", True
|
yield "success", True
|
||||||
yield "image_url", data.get("image_url", "")
|
|
||||||
|
# Store the generated image to workspace for persistence
|
||||||
|
image_url = data.get("image_url", "")
|
||||||
|
if image_url:
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(image_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "image_url", stored_url
|
||||||
|
else:
|
||||||
|
yield "image_url", ""
|
||||||
|
|
||||||
yield "uid", data.get("uid", "")
|
yield "uid", data.get("uid", "")
|
||||||
yield "status", data.get("status", "completed")
|
yield "status", data.get("status", "completed")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.type import MediaFileType, convert
|
from backend.util.type import MediaFileType, convert
|
||||||
@@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert
|
|||||||
class FileStoreBlock(Block):
|
class FileStoreBlock(Block):
|
||||||
class Input(BlockSchemaInput):
|
class Input(BlockSchemaInput):
|
||||||
file_in: MediaFileType = SchemaField(
|
file_in: MediaFileType = SchemaField(
|
||||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
description="The file to download and store. Can be a URL (https://...), data URI, or local path."
|
||||||
)
|
)
|
||||||
base_64: bool = SchemaField(
|
base_64: bool = SchemaField(
|
||||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
|
||||||
default=False,
|
default=False,
|
||||||
advanced=True,
|
advanced=True,
|
||||||
title="Produce Base64 Output",
|
title="Produce Base64 Output",
|
||||||
@@ -28,13 +29,18 @@ class FileStoreBlock(Block):
|
|||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
file_out: MediaFileType = SchemaField(
|
file_out: MediaFileType = SchemaField(
|
||||||
description="The relative path to the stored file in the temporary directory."
|
description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
|
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
|
||||||
description="Stores the input file in the temporary directory.",
|
description=(
|
||||||
|
"Downloads and stores a file from a URL, data URI, or local path. "
|
||||||
|
"Use this to fetch images, documents, or other files for processing. "
|
||||||
|
"In CoPilot: saves to workspace (use list_workspace_files to see it). "
|
||||||
|
"In graphs: outputs a data URI to pass to other blocks."
|
||||||
|
),
|
||||||
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
|
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
|
||||||
input_schema=FileStoreBlock.Input,
|
input_schema=FileStoreBlock.Input,
|
||||||
output_schema=FileStoreBlock.Output,
|
output_schema=FileStoreBlock.Output,
|
||||||
@@ -45,15 +51,18 @@ class FileStoreBlock(Block):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
# Determine return format based on user preference
|
||||||
|
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
|
||||||
|
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
|
||||||
|
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
|
||||||
|
|
||||||
yield "file_out", await store_media_file(
|
yield "file_out", await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.file_in,
|
file=input_data.file_in,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.base_64,
|
return_format=return_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import APIKeyCredentials, SchemaField
|
from backend.data.model import APIKeyCredentials, SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
@@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block):
|
|||||||
file: MediaFileType,
|
file: MediaFileType,
|
||||||
filename: str,
|
filename: str,
|
||||||
message_content: str,
|
message_content: str,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
intents.guilds = True
|
intents.guilds = True
|
||||||
@@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block):
|
|||||||
# Local file path - read from stored media file
|
# Local file path - read from stored media file
|
||||||
# This would be a path from a previous block's output
|
# This would be a path from a previous block's output
|
||||||
stored_file = await store_media_file(
|
stored_file = await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=file,
|
file=file,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True, # Get as data URI
|
return_format="for_external_api", # Get content to send to Discord
|
||||||
)
|
)
|
||||||
# Now process as data URI
|
# Now process as data URI
|
||||||
header, encoded = stored_file.split(",", 1)
|
header, encoded = stored_file.split(",", 1)
|
||||||
@@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
@@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block):
|
|||||||
file=input_data.file,
|
file=input_data.file,
|
||||||
filename=input_data.filename,
|
filename=input_data.filename,
|
||||||
message_content=input_data.message_content,
|
message_content=input_data.message_content,
|
||||||
graph_exec_id=graph_exec_id,
|
execution_context=execution_context,
|
||||||
user_id=user_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield "status", result.get("status", "Unknown error")
|
yield "status", result.get("status", "Unknown error")
|
||||||
|
|||||||
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""ElevenLabs integration blocks - test credentials and shared utilities."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="elevenlabs",
|
||||||
|
api_key=SecretStr("mock-elevenlabs-api-key"),
|
||||||
|
title="Mock ElevenLabs API key",
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_CREDENTIALS.provider,
|
||||||
|
"id": TEST_CREDENTIALS.id,
|
||||||
|
"type": TEST_CREDENTIALS.type,
|
||||||
|
"title": TEST_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
ElevenLabsCredentials = APIKeyCredentials
|
||||||
|
ElevenLabsCredentialsInput = CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
|
||||||
|
]
|
||||||
77
autogpt_platform/backend/backend/blocks/encoder_block.py
Normal file
77
autogpt_platform/backend/backend/blocks/encoder_block.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""Text encoding block for converting special characters to escape sequences."""
|
||||||
|
|
||||||
|
import codecs
|
||||||
|
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoderBlock(Block):
|
||||||
|
"""
|
||||||
|
Encodes a string by converting special characters into escape sequences.
|
||||||
|
|
||||||
|
This block is the inverse of TextDecoderBlock. It takes text containing
|
||||||
|
special characters (like newlines, tabs, etc.) and converts them into
|
||||||
|
their escape sequence representations (e.g., newline becomes \\n).
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
"""Input schema for TextEncoderBlock."""
|
||||||
|
|
||||||
|
text: str = SchemaField(
|
||||||
|
description="A string containing special characters to be encoded",
|
||||||
|
placeholder="Your text with newlines and quotes to encode",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
"""Output schema for TextEncoderBlock."""
|
||||||
|
|
||||||
|
encoded_text: str = SchemaField(
|
||||||
|
description="The encoded text with special characters converted to escape sequences"
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if encoding fails")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
|
||||||
|
description="Encodes a string by converting special characters into escape sequences",
|
||||||
|
categories={BlockCategory.TEXT},
|
||||||
|
input_schema=TextEncoderBlock.Input,
|
||||||
|
output_schema=TextEncoderBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"text": """Hello
|
||||||
|
World!
|
||||||
|
This is a "quoted" string."""
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"encoded_text",
|
||||||
|
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||||
|
"""
|
||||||
|
Encode the input text by converting special characters to escape sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data: The input containing the text to encode.
|
||||||
|
**kwargs: Additional keyword arguments (unused).
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The encoded text with escape sequences, or an error message if encoding fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
yield "encoded_text", encoded_text
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", f"Encoding error: {str(e)}"
|
||||||
@@ -17,8 +17,11 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import ClientResponseError, Requests
|
from backend.util.request import ClientResponseError, Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block):
|
|||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
|
test_output=[
|
||||||
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("video_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block):
|
|||||||
raise RuntimeError(f"API request failed: {str(e)}")
|
raise RuntimeError(f"API request failed: {str(e)}")
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: FalCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: FalCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
video_url = await self.generate_video(input_data, credentials)
|
video_url = await self.generate_video(input_data, credentials)
|
||||||
yield "video_url", video_url
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
yield "error", error_message
|
yield "error", error_message
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -121,10 +122,12 @@ class AIImageEditorBlock(Block):
|
|||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
test_output=[
|
test_output=[
|
||||||
("output_image", "https://replicate.com/output/edited-image.png"),
|
# Output will be a workspace ref or data URI depending on context
|
||||||
|
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
|
# Use data URI to avoid HTTP requests during tests
|
||||||
|
"run_model": lambda *args, **kwargs: "",
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
)
|
)
|
||||||
@@ -134,8 +137,7 @@ class AIImageEditorBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
result = await self.run_model(
|
result = await self.run_model(
|
||||||
@@ -144,20 +146,25 @@ class AIImageEditorBlock(Block):
|
|||||||
prompt=input_data.prompt,
|
prompt=input_data.prompt,
|
||||||
input_image_b64=(
|
input_image_b64=(
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.input_image,
|
file=input_data.input_image,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True,
|
return_format="for_external_api", # Get content for Replicate API
|
||||||
)
|
)
|
||||||
if input_data.input_image
|
if input_data.input_image
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
aspect_ratio=input_data.aspect_ratio.value,
|
aspect_ratio=input_data.aspect_ratio.value,
|
||||||
seed=input_data.seed,
|
seed=input_data.seed,
|
||||||
user_id=user_id,
|
user_id=execution_context.user_id or "",
|
||||||
graph_exec_id=graph_exec_id,
|
graph_exec_id=execution_context.graph_exec_id or "",
|
||||||
)
|
)
|
||||||
yield "output_image", result
|
# Store the generated image to the user's workspace for persistence
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=result,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "output_image", stored_url
|
||||||
|
|
||||||
async def run_model(
|
async def run_model(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
@@ -95,8 +96,7 @@ def _make_mime_text(
|
|||||||
|
|
||||||
async def create_mime_message(
|
async def create_mime_message(
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||||
|
|
||||||
@@ -117,12 +117,12 @@ async def create_mime_message(
|
|||||||
if input_data.attachments:
|
if input_data.attachments:
|
||||||
for attach in input_data.attachments:
|
for attach in input_data.attachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=attach,
|
file=attach,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
@@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._send_email(
|
result = await self._send_email(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "result", result
|
yield "result", result
|
||||||
|
|
||||||
async def _send_email(
|
async def _send_email(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to or not input_data.subject or not input_data.body:
|
if not input_data.to or not input_data.subject or not input_data.body:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"At least one recipient, subject, and body are required for sending an email"
|
"At least one recipient, subject, and body are required for sending an email"
|
||||||
)
|
)
|
||||||
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
raw_message = await create_mime_message(input_data, execution_context)
|
||||||
sent_message = await asyncio.to_thread(
|
sent_message = await asyncio.to_thread(
|
||||||
lambda: service.users()
|
lambda: service.users()
|
||||||
.messages()
|
.messages()
|
||||||
@@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._create_draft(
|
result = await self._create_draft(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "result", GmailDraftResult(
|
yield "result", GmailDraftResult(
|
||||||
id=result["id"], message_id=result["message"]["id"], status="draft_created"
|
id=result["id"], message_id=result["message"]["id"], status="draft_created"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _create_draft(
|
async def _create_draft(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to or not input_data.subject:
|
if not input_data.to or not input_data.subject:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"At least one recipient and subject are required for creating a draft"
|
"At least one recipient and subject are required for creating a draft"
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
raw_message = await create_mime_message(input_data, execution_context)
|
||||||
draft = await asyncio.to_thread(
|
draft = await asyncio.to_thread(
|
||||||
lambda: service.users()
|
lambda: service.users()
|
||||||
.drafts()
|
.drafts()
|
||||||
@@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase):
|
|||||||
|
|
||||||
|
|
||||||
async def _build_reply_message(
|
async def _build_reply_message(
|
||||||
service, input_data, graph_exec_id: str, user_id: str
|
service, input_data, execution_context: ExecutionContext
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Builds a reply MIME message for Gmail threads.
|
Builds a reply MIME message for Gmail threads.
|
||||||
@@ -1190,12 +1186,12 @@ async def _build_reply_message(
|
|||||||
# Handle attachments
|
# Handle attachments
|
||||||
for attach in input_data.attachments:
|
for attach in input_data.attachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=attach,
|
file=attach,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
@@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
message = await self._reply(
|
message = await self._reply(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "messageId", message["id"]
|
yield "messageId", message["id"]
|
||||||
yield "threadId", message.get("threadId", input_data.threadId)
|
yield "threadId", message.get("threadId", input_data.threadId)
|
||||||
@@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase):
|
|||||||
yield "email", email
|
yield "email", email
|
||||||
|
|
||||||
async def _reply(
|
async def _reply(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Build the reply message using the shared helper
|
# Build the reply message using the shared helper
|
||||||
raw, thread_id = await _build_reply_message(
|
raw, thread_id = await _build_reply_message(
|
||||||
service, input_data, graph_exec_id, user_id
|
service, input_data, execution_context
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send the message
|
# Send the message
|
||||||
@@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
draft = await self._create_draft_reply(
|
draft = await self._create_draft_reply(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "draftId", draft["id"]
|
yield "draftId", draft["id"]
|
||||||
yield "messageId", draft["message"]["id"]
|
yield "messageId", draft["message"]["id"]
|
||||||
@@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase):
|
|||||||
yield "status", "draft_created"
|
yield "status", "draft_created"
|
||||||
|
|
||||||
async def _create_draft_reply(
|
async def _create_draft_reply(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Build the reply message using the shared helper
|
# Build the reply message using the shared helper
|
||||||
raw, thread_id = await _build_reply_message(
|
raw, thread_id = await _build_reply_message(
|
||||||
service, input_data, graph_exec_id, user_id
|
service, input_data, execution_context
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create draft with proper thread association
|
# Create draft with proper thread association
|
||||||
@@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._forward_message(
|
result = await self._forward_message(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
graph_exec_id,
|
execution_context,
|
||||||
user_id,
|
|
||||||
)
|
)
|
||||||
yield "messageId", result["id"]
|
yield "messageId", result["id"]
|
||||||
yield "threadId", result.get("threadId", "")
|
yield "threadId", result.get("threadId", "")
|
||||||
yield "status", "forwarded"
|
yield "status", "forwarded"
|
||||||
|
|
||||||
async def _forward_message(
|
async def _forward_message(
|
||||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
self, service, input_data: Input, execution_context: ExecutionContext
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to:
|
if not input_data.to:
|
||||||
raise ValueError("At least one recipient is required for forwarding")
|
raise ValueError("At least one recipient is required for forwarding")
|
||||||
@@ -1727,12 +1717,12 @@ To: {original_to}
|
|||||||
# Add any additional attachments
|
# Add any additional attachments
|
||||||
for attach in input_data.additionalAttachments:
|
for attach in input_data.additionalAttachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=attach,
|
file=attach,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
@@ -116,10 +117,9 @@ class SendWebRequestBlock(Block):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _prepare_files(
|
async def _prepare_files(
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
files_name: str,
|
files_name: str,
|
||||||
files: list[MediaFileType],
|
files: list[MediaFileType],
|
||||||
user_id: str,
|
|
||||||
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
|
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
|
||||||
"""
|
"""
|
||||||
Prepare files for the request by storing them and reading their content.
|
Prepare files for the request by storing them and reading their content.
|
||||||
@@ -127,11 +127,16 @@ class SendWebRequestBlock(Block):
|
|||||||
(files_name, (filename, BytesIO, mime_type))
|
(files_name, (filename, BytesIO, mime_type))
|
||||||
"""
|
"""
|
||||||
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
if graph_exec_id is None:
|
||||||
|
raise ValueError("graph_exec_id is required for file operations")
|
||||||
|
|
||||||
for media in files:
|
for media in files:
|
||||||
# Normalise to a list so we can repeat the same key
|
# Normalise to a list so we can repeat the same key
|
||||||
rel_path = await store_media_file(
|
rel_path = await store_media_file(
|
||||||
graph_exec_id, media, user_id, return_content=False
|
file=media,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
abs_path = get_exec_file_path(graph_exec_id, rel_path)
|
abs_path = get_exec_file_path(graph_exec_id, rel_path)
|
||||||
async with aiofiles.open(abs_path, "rb") as f:
|
async with aiofiles.open(abs_path, "rb") as f:
|
||||||
@@ -143,7 +148,7 @@ class SendWebRequestBlock(Block):
|
|||||||
return files_payload
|
return files_payload
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
|
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# ─── Parse/normalise body ────────────────────────────────────
|
# ─── Parse/normalise body ────────────────────────────────────
|
||||||
body = input_data.body
|
body = input_data.body
|
||||||
@@ -174,7 +179,7 @@ class SendWebRequestBlock(Block):
|
|||||||
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
||||||
if use_files:
|
if use_files:
|
||||||
files_payload = await self._prepare_files(
|
files_payload = await self._prepare_files(
|
||||||
graph_exec_id, input_data.files_name, input_data.files, user_id
|
execution_context, input_data.files_name, input_data.files
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enforce body format rules
|
# Enforce body format rules
|
||||||
@@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
credentials: HostScopedCredentials,
|
credentials: HostScopedCredentials,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
||||||
@@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
|||||||
|
|
||||||
# Use parent class run method
|
# Use parent class run method
|
||||||
async for output_name, output_data in super().run(
|
async for output_name, output_data in super().run(
|
||||||
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
|
base_input, execution_context=execution_context, **kwargs
|
||||||
):
|
):
|
||||||
yield output_name, output_data
|
yield output_name, output_data
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockType,
|
BlockType,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.mock import MockObject
|
from backend.util.mock import MockObject
|
||||||
@@ -462,18 +463,21 @@ class AgentFileInputBlock(AgentInputBlock):
|
|||||||
self,
|
self,
|
||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
if not input_data.value:
|
if not input_data.value:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Determine return format based on user preference
|
||||||
|
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
|
||||||
|
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
|
||||||
|
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
|
||||||
|
|
||||||
yield "result", await store_media_file(
|
yield "result", await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.value,
|
file=input_data.value,
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=input_data.base_64,
|
return_format=return_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -162,8 +162,16 @@ class LinearClient:
|
|||||||
"searchTerm": team_name,
|
"searchTerm": team_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
team_id = await self.query(query, variables)
|
result = await self.query(query, variables)
|
||||||
return team_id["teams"]["nodes"][0]["id"]
|
nodes = result["teams"]["nodes"]
|
||||||
|
|
||||||
|
if not nodes:
|
||||||
|
raise LinearAPIException(
|
||||||
|
f"Team '{team_name}' not found. Check the team name or key and try again.",
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nodes[0]["id"]
|
||||||
except LinearAPIException as e:
|
except LinearAPIException as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@@ -240,17 +248,44 @@ class LinearClient:
|
|||||||
except LinearAPIException as e:
|
except LinearAPIException as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def try_search_issues(self, term: str) -> list[Issue]:
|
async def try_search_issues(
|
||||||
|
self,
|
||||||
|
term: str,
|
||||||
|
max_results: int = 10,
|
||||||
|
team_id: str | None = None,
|
||||||
|
) -> list[Issue]:
|
||||||
try:
|
try:
|
||||||
query = """
|
query = """
|
||||||
query SearchIssues($term: String!, $includeComments: Boolean!) {
|
query SearchIssues(
|
||||||
searchIssues(term: $term, includeComments: $includeComments) {
|
$term: String!,
|
||||||
|
$first: Int,
|
||||||
|
$teamId: String
|
||||||
|
) {
|
||||||
|
searchIssues(
|
||||||
|
term: $term,
|
||||||
|
first: $first,
|
||||||
|
teamId: $teamId
|
||||||
|
) {
|
||||||
nodes {
|
nodes {
|
||||||
id
|
id
|
||||||
identifier
|
identifier
|
||||||
title
|
title
|
||||||
description
|
description
|
||||||
priority
|
priority
|
||||||
|
createdAt
|
||||||
|
state {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
type
|
||||||
|
}
|
||||||
|
project {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
}
|
||||||
|
assignee {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -258,7 +293,8 @@ class LinearClient:
|
|||||||
|
|
||||||
variables: dict[str, Any] = {
|
variables: dict[str, Any] = {
|
||||||
"term": term,
|
"term": term,
|
||||||
"includeComments": True,
|
"first": max_results,
|
||||||
|
"teamId": team_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
issues = await self.query(query, variables)
|
issues = await self.query(query, variables)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from ._config import (
|
|||||||
LinearScope,
|
LinearScope,
|
||||||
linear,
|
linear,
|
||||||
)
|
)
|
||||||
from .models import CreateIssueResponse, Issue
|
from .models import CreateIssueResponse, Issue, State
|
||||||
|
|
||||||
|
|
||||||
class LinearCreateIssueBlock(Block):
|
class LinearCreateIssueBlock(Block):
|
||||||
@@ -135,9 +135,20 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
description="Linear credentials with read permissions",
|
description="Linear credentials with read permissions",
|
||||||
required_scopes={LinearScope.READ},
|
required_scopes={LinearScope.READ},
|
||||||
)
|
)
|
||||||
|
max_results: int = SchemaField(
|
||||||
|
description="Maximum number of results to return",
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
)
|
||||||
|
team_name: str | None = SchemaField(
|
||||||
|
description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
issues: list[Issue] = SchemaField(description="List of issues")
|
issues: list[Issue] = SchemaField(description="List of issues")
|
||||||
|
error: str = SchemaField(description="Error message if the search failed")
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -145,8 +156,11 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
description="Searches for issues on Linear",
|
description="Searches for issues on Linear",
|
||||||
input_schema=self.Input,
|
input_schema=self.Input,
|
||||||
output_schema=self.Output,
|
output_schema=self.Output,
|
||||||
|
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
||||||
test_input={
|
test_input={
|
||||||
"term": "Test issue",
|
"term": "Test issue",
|
||||||
|
"max_results": 10,
|
||||||
|
"team_name": None,
|
||||||
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS_OAUTH,
|
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||||
@@ -156,10 +170,14 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
[
|
[
|
||||||
Issue(
|
Issue(
|
||||||
id="abc123",
|
id="abc123",
|
||||||
identifier="abc123",
|
identifier="TST-123",
|
||||||
title="Test issue",
|
title="Test issue",
|
||||||
description="Test description",
|
description="Test description",
|
||||||
priority=1,
|
priority=1,
|
||||||
|
state=State(
|
||||||
|
id="state1", name="In Progress", type="started"
|
||||||
|
),
|
||||||
|
createdAt="2026-01-15T10:00:00.000Z",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -168,10 +186,12 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
"search_issues": lambda *args, **kwargs: [
|
"search_issues": lambda *args, **kwargs: [
|
||||||
Issue(
|
Issue(
|
||||||
id="abc123",
|
id="abc123",
|
||||||
identifier="abc123",
|
identifier="TST-123",
|
||||||
title="Test issue",
|
title="Test issue",
|
||||||
description="Test description",
|
description="Test description",
|
||||||
priority=1,
|
priority=1,
|
||||||
|
state=State(id="state1", name="In Progress", type="started"),
|
||||||
|
createdAt="2026-01-15T10:00:00.000Z",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -181,10 +201,22 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
async def search_issues(
|
async def search_issues(
|
||||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||||
term: str,
|
term: str,
|
||||||
|
max_results: int = 10,
|
||||||
|
team_name: str | None = None,
|
||||||
) -> list[Issue]:
|
) -> list[Issue]:
|
||||||
client = LinearClient(credentials=credentials)
|
client = LinearClient(credentials=credentials)
|
||||||
response: list[Issue] = await client.try_search_issues(term=term)
|
|
||||||
return response
|
# Resolve team name to ID if provided
|
||||||
|
# Raises LinearAPIException with descriptive message if team not found
|
||||||
|
team_id: str | None = None
|
||||||
|
if team_name:
|
||||||
|
team_id = await client.try_get_team_by_name(team_name=team_name)
|
||||||
|
|
||||||
|
return await client.try_search_issues(
|
||||||
|
term=term,
|
||||||
|
max_results=max_results,
|
||||||
|
team_id=team_id,
|
||||||
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
@@ -196,7 +228,10 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
"""Execute the issue search"""
|
"""Execute the issue search"""
|
||||||
try:
|
try:
|
||||||
issues = await self.search_issues(
|
issues = await self.search_issues(
|
||||||
credentials=credentials, term=input_data.term
|
credentials=credentials,
|
||||||
|
term=input_data.term,
|
||||||
|
max_results=input_data.max_results,
|
||||||
|
team_name=input_data.team_name,
|
||||||
)
|
)
|
||||||
yield "issues", issues
|
yield "issues", issues
|
||||||
except LinearAPIException as e:
|
except LinearAPIException as e:
|
||||||
|
|||||||
@@ -36,12 +36,21 @@ class Project(BaseModel):
|
|||||||
content: str | None = None
|
content: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class State(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
type: str | None = (
|
||||||
|
None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Issue(BaseModel):
|
class Issue(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
identifier: str
|
identifier: str
|
||||||
title: str
|
title: str
|
||||||
description: str | None
|
description: str | None
|
||||||
priority: int
|
priority: int
|
||||||
|
state: State | None = None
|
||||||
project: Project | None = None
|
project: Project | None = None
|
||||||
createdAt: str | None = None
|
createdAt: str | None = None
|
||||||
comments: list[Comment] | None = None
|
comments: list[Comment] | None = None
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from backend.data.model import (
|
|||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
from backend.util.logging import TruncatedLogger
|
from backend.util.logging import TruncatedLogger
|
||||||
from backend.util.prompt import compress_prompt, estimate_token_count
|
from backend.util.prompt import compress_context, estimate_token_count
|
||||||
from backend.util.text import TextFormatter
|
from backend.util.text import TextFormatter
|
||||||
|
|
||||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||||
@@ -115,7 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
||||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||||
# AI/ML API models
|
# AI/ML API models
|
||||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||||
@@ -271,6 +271,9 @@ MODEL_METADATA = {
|
|||||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
||||||
), # claude-4-sonnet-20250514
|
), # claude-4-sonnet-20250514
|
||||||
|
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
|
||||||
|
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
|
||||||
|
), # claude-opus-4-6
|
||||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||||
), # claude-opus-4-5-20251101
|
), # claude-opus-4-5-20251101
|
||||||
@@ -280,9 +283,6 @@ MODEL_METADATA = {
|
|||||||
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
|
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
|
||||||
), # claude-haiku-4-5-20251001
|
), # claude-haiku-4-5-20251001
|
||||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
|
||||||
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
|
|
||||||
), # claude-3-7-sonnet-20250219
|
|
||||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
||||||
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
|
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
|
||||||
), # claude-3-haiku-20240307
|
), # claude-3-haiku-20240307
|
||||||
@@ -638,11 +638,18 @@ async def llm_call(
|
|||||||
context_window = llm_model.context_window
|
context_window = llm_model.context_window
|
||||||
|
|
||||||
if compress_prompt_to_fit:
|
if compress_prompt_to_fit:
|
||||||
prompt = compress_prompt(
|
result = await compress_context(
|
||||||
messages=prompt,
|
messages=prompt,
|
||||||
target_tokens=llm_model.context_window // 2,
|
target_tokens=llm_model.context_window // 2,
|
||||||
lossy_ok=True,
|
client=None, # Truncation-only, no LLM summarization
|
||||||
|
reserve=0, # Caller handles response token budget separately
|
||||||
)
|
)
|
||||||
|
if result.error:
|
||||||
|
logger.warning(
|
||||||
|
f"Prompt compression did not meet target: {result.error}. "
|
||||||
|
f"Proceeding with {result.token_count} tokens."
|
||||||
|
)
|
||||||
|
prompt = result.messages
|
||||||
|
|
||||||
# Calculate available tokens based on context window and input length
|
# Calculate available tokens based on context window and input length
|
||||||
estimated_input_tokens = estimate_token_count(prompt)
|
estimated_input_tokens = estimate_token_count(prompt)
|
||||||
|
|||||||
@@ -1,251 +0,0 @@
|
|||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.fx.Loop import Loop
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class MediaDurationBlock(Block):
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
media_in: MediaFileType = SchemaField(
|
|
||||||
description="Media input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
is_video: bool = SchemaField(
|
|
||||||
description="Whether the media is a video (True) or audio (False).",
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
duration: float = SchemaField(
|
|
||||||
description="Duration of the media file (in seconds)."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
|
||||||
description="Block to get the duration of a media file.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=MediaDurationBlock.Input,
|
|
||||||
output_schema=MediaDurationBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
graph_exec_id: str,
|
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# 1) Store the input media locally
|
|
||||||
local_media_path = await store_media_file(
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.media_in,
|
|
||||||
user_id=user_id,
|
|
||||||
return_content=False,
|
|
||||||
)
|
|
||||||
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
|
|
||||||
|
|
||||||
# 2) Load the clip
|
|
||||||
if input_data.is_video:
|
|
||||||
clip = VideoFileClip(media_abspath)
|
|
||||||
else:
|
|
||||||
clip = AudioFileClip(media_abspath)
|
|
||||||
|
|
||||||
yield "duration", clip.duration
|
|
||||||
|
|
||||||
|
|
||||||
class LoopVideoBlock(Block):
|
|
||||||
"""
|
|
||||||
Block for looping (repeating) a video clip until a given duration or number of loops.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="The input video (can be a URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
|
|
||||||
duration: Optional[float] = SchemaField(
|
|
||||||
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
|
|
||||||
default=None,
|
|
||||||
ge=0.0,
|
|
||||||
)
|
|
||||||
n_loops: Optional[int] = SchemaField(
|
|
||||||
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
|
|
||||||
default=None,
|
|
||||||
ge=1,
|
|
||||||
)
|
|
||||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
|
||||||
description="How to return the output video. Either a relative path or base64 data URI.",
|
|
||||||
default="file_path",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: str = SchemaField(
|
|
||||||
description="Looped video returned either as a relative path or a data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
|
||||||
description="Block to loop a video to a given duration or number of repeats.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=LoopVideoBlock.Input,
|
|
||||||
output_schema=LoopVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
node_exec_id: str,
|
|
||||||
graph_exec_id: str,
|
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# 1) Store the input video locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.video_in,
|
|
||||||
user_id=user_id,
|
|
||||||
return_content=False,
|
|
||||||
)
|
|
||||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
|
||||||
|
|
||||||
# 2) Load the clip
|
|
||||||
clip = VideoFileClip(input_abspath)
|
|
||||||
|
|
||||||
# 3) Apply the loop effect
|
|
||||||
looped_clip = clip
|
|
||||||
if input_data.duration:
|
|
||||||
# Loop until we reach the specified duration
|
|
||||||
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
|
||||||
elif input_data.n_loops:
|
|
||||||
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
|
||||||
else:
|
|
||||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
|
||||||
|
|
||||||
assert isinstance(looped_clip, VideoFileClip)
|
|
||||||
|
|
||||||
# 4) Save the looped output
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
|
||||||
)
|
|
||||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
|
||||||
|
|
||||||
looped_clip = looped_clip.with_audio(clip.audio)
|
|
||||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
|
||||||
|
|
||||||
# Return as data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=output_filename,
|
|
||||||
user_id=user_id,
|
|
||||||
return_content=input_data.output_return_type == "data_uri",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
|
|
||||||
|
|
||||||
class AddAudioToVideoBlock(Block):
|
|
||||||
"""
|
|
||||||
Block that adds (attaches) an audio track to an existing video.
|
|
||||||
Optionally scale the volume of the new track.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Video input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
audio_in: MediaFileType = SchemaField(
|
|
||||||
description="Audio input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
volume: float = SchemaField(
|
|
||||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
|
||||||
default=1.0,
|
|
||||||
)
|
|
||||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
|
||||||
description="Return the final output as a relative path or base64 data URI.",
|
|
||||||
default="file_path",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Final video (with attached audio), as a path or data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
|
||||||
description="Block to attach an audio file to a video file using moviepy.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=AddAudioToVideoBlock.Input,
|
|
||||||
output_schema=AddAudioToVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
node_exec_id: str,
|
|
||||||
graph_exec_id: str,
|
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# 1) Store the inputs locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.video_in,
|
|
||||||
user_id=user_id,
|
|
||||||
return_content=False,
|
|
||||||
)
|
|
||||||
local_audio_path = await store_media_file(
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.audio_in,
|
|
||||||
user_id=user_id,
|
|
||||||
return_content=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
|
||||||
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
|
||||||
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
|
||||||
|
|
||||||
# 2) Load video + audio with moviepy
|
|
||||||
video_clip = VideoFileClip(video_abspath)
|
|
||||||
audio_clip = AudioFileClip(audio_abspath)
|
|
||||||
# Optionally scale volume
|
|
||||||
if input_data.volume != 1.0:
|
|
||||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
|
||||||
|
|
||||||
# 3) Attach the new audio track
|
|
||||||
final_clip = video_clip.with_audio(audio_clip)
|
|
||||||
|
|
||||||
# 4) Write to output file
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
|
||||||
)
|
|
||||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
|
||||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
|
||||||
|
|
||||||
# 5) Return either path or data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=output_filename,
|
|
||||||
user_id=user_id,
|
|
||||||
return_content=input_data.output_return_type == "data_uri",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
@@ -11,6 +11,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def take_screenshot(
|
async def take_screenshot(
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
url: str,
|
url: str,
|
||||||
viewport_width: int,
|
viewport_width: int,
|
||||||
viewport_height: int,
|
viewport_height: int,
|
||||||
@@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block):
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"image": await store_media_file(
|
"image": await store_media_file(
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=MediaFileType(
|
file=MediaFileType(
|
||||||
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
|
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
|
||||||
),
|
),
|
||||||
user_id=user_id,
|
execution_context=execution_context,
|
||||||
return_content=True,
|
return_format="for_block_output",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: APIKeyCredentials,
|
credentials: APIKeyCredentials,
|
||||||
graph_exec_id: str,
|
execution_context: ExecutionContext,
|
||||||
user_id: str,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
try:
|
||||||
screenshot_data = await self.take_screenshot(
|
screenshot_data = await self.take_screenshot(
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
graph_exec_id=graph_exec_id,
|
execution_context=execution_context,
|
||||||
user_id=user_id,
|
|
||||||
url=input_data.url,
|
url=input_data.url,
|
||||||
viewport_width=input_data.viewport_width,
|
viewport_width=input_data.viewport_width,
|
||||||
viewport_height=input_data.viewport_height,
|
viewport_height=input_data.viewport_height,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import ContributorDetails, SchemaField
|
from backend.data.model import ContributorDetails, SchemaField
|
||||||
from backend.util.file import get_exec_file_path, store_media_file
|
from backend.util.file import get_exec_file_path, store_media_file
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
@@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
|
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
import csv
|
import csv
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
@@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block):
|
|||||||
# Determine data source - prefer file_input if provided, otherwise use contents
|
# Determine data source - prefer file_input if provided, otherwise use contents
|
||||||
if input_data.file_input:
|
if input_data.file_input:
|
||||||
stored_file_path = await store_media_file(
|
stored_file_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.file_input,
|
file=input_data.file_input,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get full file path
|
# Get full file path
|
||||||
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
|
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||||
|
file_path = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, stored_file_path
|
||||||
|
)
|
||||||
if not Path(file_path).exists():
|
if not Path(file_path).exists():
|
||||||
raise ValueError(f"File does not exist: {file_path}")
|
raise ValueError(f"File does not exist: {file_path}")
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
|
|||||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||||
|
|
||||||
# Anthropic
|
# Anthropic
|
||||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_name(self) -> str:
|
def provider_name(self) -> str:
|
||||||
@@ -137,7 +137,7 @@ class StagehandObserveBlock(Block):
|
|||||||
model: StagehandRecommendedLlmModel = SchemaField(
|
model: StagehandRecommendedLlmModel = SchemaField(
|
||||||
title="LLM Model",
|
title="LLM Model",
|
||||||
description="LLM to use for Stagehand (provider is inferred)",
|
description="LLM to use for Stagehand (provider is inferred)",
|
||||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||||
advanced=False,
|
advanced=False,
|
||||||
)
|
)
|
||||||
model_credentials: AICredentials = AICredentialsField()
|
model_credentials: AICredentials = AICredentialsField()
|
||||||
@@ -182,10 +182,7 @@ class StagehandObserveBlock(Block):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
|
||||||
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
|
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
|
||||||
logger.info(
|
|
||||||
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with disable_signal_handling():
|
with disable_signal_handling():
|
||||||
stagehand = Stagehand(
|
stagehand = Stagehand(
|
||||||
@@ -230,7 +227,7 @@ class StagehandActBlock(Block):
|
|||||||
model: StagehandRecommendedLlmModel = SchemaField(
|
model: StagehandRecommendedLlmModel = SchemaField(
|
||||||
title="LLM Model",
|
title="LLM Model",
|
||||||
description="LLM to use for Stagehand (provider is inferred)",
|
description="LLM to use for Stagehand (provider is inferred)",
|
||||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||||
advanced=False,
|
advanced=False,
|
||||||
)
|
)
|
||||||
model_credentials: AICredentials = AICredentialsField()
|
model_credentials: AICredentials = AICredentialsField()
|
||||||
@@ -282,10 +279,7 @@ class StagehandActBlock(Block):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
|
||||||
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
|
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
|
||||||
logger.info(
|
|
||||||
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with disable_signal_handling():
|
with disable_signal_handling():
|
||||||
stagehand = Stagehand(
|
stagehand = Stagehand(
|
||||||
@@ -330,7 +324,7 @@ class StagehandExtractBlock(Block):
|
|||||||
model: StagehandRecommendedLlmModel = SchemaField(
|
model: StagehandRecommendedLlmModel = SchemaField(
|
||||||
title="LLM Model",
|
title="LLM Model",
|
||||||
description="LLM to use for Stagehand (provider is inferred)",
|
description="LLM to use for Stagehand (provider is inferred)",
|
||||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||||
advanced=False,
|
advanced=False,
|
||||||
)
|
)
|
||||||
model_credentials: AICredentials = AICredentialsField()
|
model_credentials: AICredentials = AICredentialsField()
|
||||||
@@ -370,10 +364,7 @@ class StagehandExtractBlock(Block):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
|
||||||
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
|
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
|
||||||
logger.info(
|
|
||||||
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with disable_signal_handling():
|
with disable_signal_handling():
|
||||||
stagehand = Stagehand(
|
stagehand = Stagehand(
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
APIKeyCredentials,
|
APIKeyCredentials,
|
||||||
CredentialsField,
|
CredentialsField,
|
||||||
@@ -17,7 +18,9 @@ from backend.data.model import (
|
|||||||
SchemaField,
|
SchemaField,
|
||||||
)
|
)
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.file import store_media_file
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
TEST_CREDENTIALS = APIKeyCredentials(
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
@@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
test_output=[
|
test_output=[
|
||||||
(
|
(
|
||||||
"video_url",
|
"video_url",
|
||||||
"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
|
lambda x: x.startswith(("workspace://", "data:")),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
@@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
|
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
|
||||||
"status": "created",
|
"status": "created",
|
||||||
},
|
},
|
||||||
|
# Use data URI to avoid HTTP requests during tests
|
||||||
"get_clip_status": lambda *args, **kwargs: {
|
"get_clip_status": lambda *args, **kwargs: {
|
||||||
"status": "done",
|
"status": "done",
|
||||||
"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
|
"result_url": "data:video/mp4;base64,AAAA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS,
|
test_credentials=TEST_CREDENTIALS,
|
||||||
@@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Create the clip
|
# Create the clip
|
||||||
payload = {
|
payload = {
|
||||||
@@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block):
|
|||||||
for _ in range(input_data.max_polling_attempts):
|
for _ in range(input_data.max_polling_attempts):
|
||||||
status_response = await self.get_clip_status(credentials.api_key, clip_id)
|
status_response = await self.get_clip_status(credentials.api_key, clip_id)
|
||||||
if status_response["status"] == "done":
|
if status_response["status"] == "done":
|
||||||
yield "video_url", status_response["result_url"]
|
# Store the generated video to the user's workspace for persistence
|
||||||
|
video_url = status_response["result_url"]
|
||||||
|
stored_url = await store_media_file(
|
||||||
|
file=MediaFileType(video_url),
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
yield "video_url", stored_url
|
||||||
return
|
return
|
||||||
elif status_response["status"] == "error":
|
elif status_response["status"] == "error":
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock
|
|||||||
from backend.blocks.llm import AITextSummarizerBlock
|
from backend.blocks.llm import AITextSummarizerBlock
|
||||||
from backend.blocks.text import ExtractTextInformationBlock
|
from backend.blocks.text import ExtractTextInformationBlock
|
||||||
from backend.blocks.xml_parser import XMLParserBlock
|
from backend.blocks.xml_parser import XMLParserBlock
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.util.file import store_media_file
|
from backend.util.file import store_media_file
|
||||||
from backend.util.type import MediaFileType
|
from backend.util.type import MediaFileType
|
||||||
|
|
||||||
@@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity:
|
|||||||
|
|
||||||
with pytest.raises(ValueError, match="File too large"):
|
with pytest.raises(ValueError, match="File too large"):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id="test",
|
|
||||||
file=MediaFileType(large_data_uri),
|
file=MediaFileType(large_data_uri),
|
||||||
|
execution_context=ExecutionContext(
|
||||||
user_id="test_user",
|
user_id="test_user",
|
||||||
|
graph_exec_id="test",
|
||||||
|
),
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("backend.util.file.Path")
|
@patch("backend.util.file.Path")
|
||||||
@@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity:
|
|||||||
# Should raise an error when directory size exceeds limit
|
# Should raise an error when directory size exceeds limit
|
||||||
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
|
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
graph_exec_id="test",
|
|
||||||
file=MediaFileType(
|
file=MediaFileType(
|
||||||
"data:text/plain;base64,dGVzdA=="
|
"data:text/plain;base64,dGVzdA=="
|
||||||
), # Small test file
|
), # Small test file
|
||||||
|
execution_context=ExecutionContext(
|
||||||
user_id="test_user",
|
user_id="test_user",
|
||||||
|
graph_exec_id="test",
|
||||||
|
),
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,10 +11,22 @@ from backend.blocks.http import (
|
|||||||
HttpMethod,
|
HttpMethod,
|
||||||
SendAuthenticatedWebRequestBlock,
|
SendAuthenticatedWebRequestBlock,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import HostScopedCredentials
|
from backend.data.model import HostScopedCredentials
|
||||||
from backend.util.request import Response
|
from backend.util.request import Response
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_context(
|
||||||
|
graph_exec_id: str = "test-exec-id",
|
||||||
|
user_id: str = "test-user-id",
|
||||||
|
) -> ExecutionContext:
|
||||||
|
"""Helper to create test ExecutionContext."""
|
||||||
|
return ExecutionContext(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestHttpBlockWithHostScopedCredentials:
|
class TestHttpBlockWithHostScopedCredentials:
|
||||||
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
||||||
|
|
||||||
@@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=exact_match_credentials,
|
credentials=exact_match_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=wildcard_credentials,
|
credentials=wildcard_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=non_matching_credentials,
|
credentials=non_matching_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=exact_match_credentials,
|
credentials=exact_match_credentials,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=auto_discovered_creds, # Execution manager found these
|
credentials=auto_discovered_creds, # Execution manager found these
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=multi_header_creds,
|
credentials=multi_header_creds,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
@@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
|||||||
async for output_name, output_data in http_block.run(
|
async for output_name, output_data in http_block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=test_creds,
|
credentials=test_creds,
|
||||||
graph_exec_id="test-exec-id",
|
execution_context=make_test_context(),
|
||||||
user_id="test-user-id",
|
|
||||||
):
|
):
|
||||||
result.append((output_name, output_data))
|
result.append((output_name, output_data))
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,77 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.blocks.encoder_block import TextEncoderBlock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_basic():
|
||||||
|
"""Test basic encoding of newlines and special characters."""
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "encoded_text"
|
||||||
|
assert result[0][1] == "Hello\\nWorld"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_multiple_escapes():
|
||||||
|
"""Test encoding of multiple escape sequences."""
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
async for output in block.run(
|
||||||
|
TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
|
||||||
|
):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "encoded_text"
|
||||||
|
assert "\\n" in result[0][1]
|
||||||
|
assert "\\t" in result[0][1]
|
||||||
|
assert "\\r" in result[0][1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_unicode():
|
||||||
|
"""Test that unicode characters are handled correctly."""
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "encoded_text"
|
||||||
|
# Unicode characters should be escaped as \uXXXX sequences
|
||||||
|
assert "\\n" in result[0][1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_empty_string():
|
||||||
|
"""Test encoding of an empty string."""
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
async for output in block.run(TextEncoderBlock.Input(text="")):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "encoded_text"
|
||||||
|
assert result[0][1] == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_error_handling():
|
||||||
|
"""Test that encoding errors are handled gracefully."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
|
||||||
|
with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
|
||||||
|
async for output in block.run(TextEncoderBlock.Input(text="test")):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "error"
|
||||||
|
assert "Mocked encoding error" in result[0][1]
|
||||||
@@ -11,6 +11,7 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util import json, text
|
from backend.util import json, text
|
||||||
from backend.util.file import get_exec_file_path, store_media_file
|
from backend.util.file import get_exec_file_path, store_media_file
|
||||||
@@ -444,18 +445,21 @@ class FileReadBlock(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
|
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
# Store the media file properly (handles URLs, data URIs, etc.)
|
# Store the media file properly (handles URLs, data URIs, etc.)
|
||||||
stored_file_path = await store_media_file(
|
stored_file_path = await store_media_file(
|
||||||
user_id=user_id,
|
|
||||||
graph_exec_id=graph_exec_id,
|
|
||||||
file=input_data.file_input,
|
file=input_data.file_input,
|
||||||
return_content=False,
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get full file path
|
# Get full file path (graph_exec_id validated by store_media_file above)
|
||||||
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
|
if not execution_context.graph_exec_id:
|
||||||
|
raise ValueError("execution_context.graph_exec_id is required")
|
||||||
|
file_path = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, stored_file_path
|
||||||
|
)
|
||||||
|
|
||||||
if not Path(file_path).exists():
|
if not Path(file_path).exists():
|
||||||
raise ValueError(f"File does not exist: {file_path}")
|
raise ValueError(f"File does not exist: {file_path}")
|
||||||
|
|||||||
37
autogpt_platform/backend/backend/blocks/video/__init__.py
Normal file
37
autogpt_platform/backend/backend/blocks/video/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Video editing blocks for AutoGPT Platform.
|
||||||
|
|
||||||
|
This module provides blocks for:
|
||||||
|
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
|
||||||
|
- Clipping/trimming video segments
|
||||||
|
- Concatenating multiple videos
|
||||||
|
- Adding text overlays
|
||||||
|
- Adding AI-generated narration
|
||||||
|
- Getting media duration
|
||||||
|
- Looping videos
|
||||||
|
- Adding audio to videos
|
||||||
|
|
||||||
|
Dependencies:
|
||||||
|
- yt-dlp: For video downloading
|
||||||
|
- moviepy: For video editing operations
|
||||||
|
- elevenlabs: For AI narration (optional)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from backend.blocks.video.add_audio import AddAudioToVideoBlock
|
||||||
|
from backend.blocks.video.clip import VideoClipBlock
|
||||||
|
from backend.blocks.video.concat import VideoConcatBlock
|
||||||
|
from backend.blocks.video.download import VideoDownloadBlock
|
||||||
|
from backend.blocks.video.duration import MediaDurationBlock
|
||||||
|
from backend.blocks.video.loop import LoopVideoBlock
|
||||||
|
from backend.blocks.video.narration import VideoNarrationBlock
|
||||||
|
from backend.blocks.video.text_overlay import VideoTextOverlayBlock
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AddAudioToVideoBlock",
|
||||||
|
"LoopVideoBlock",
|
||||||
|
"MediaDurationBlock",
|
||||||
|
"VideoClipBlock",
|
||||||
|
"VideoConcatBlock",
|
||||||
|
"VideoDownloadBlock",
|
||||||
|
"VideoNarrationBlock",
|
||||||
|
"VideoTextOverlayBlock",
|
||||||
|
]
|
||||||
131
autogpt_platform/backend/backend/blocks/video/_utils.py
Normal file
131
autogpt_platform/backend/backend/blocks/video/_utils.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Shared utilities for video blocks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Known operation tags added by video blocks
|
||||||
|
_VIDEO_OPS = (
|
||||||
|
r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
|
||||||
|
_BLOCK_PREFIX_RE = re.compile(
|
||||||
|
r"^[a-zA-Z0-9_-]*"
|
||||||
|
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||||
|
r"[a-zA-Z0-9_-]*"
|
||||||
|
r"_" + _VIDEO_OPS + r"_"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
|
||||||
|
_UUID_PREFIX_RE = re.compile(
|
||||||
|
r"^[a-zA-Z0-9_-]*"
|
||||||
|
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||||
|
r"[a-zA-Z0-9_-]*_"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_source_name(input_path: str, max_length: int = 50) -> str:
|
||||||
|
"""Extract the original source filename by stripping block-generated prefixes.
|
||||||
|
|
||||||
|
Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
|
||||||
|
when chaining video blocks, recovering the original human-readable name.
|
||||||
|
|
||||||
|
Safe for plain filenames (no UUID -> no stripping).
|
||||||
|
Falls back to "video" if everything is stripped.
|
||||||
|
"""
|
||||||
|
stem = Path(input_path).stem
|
||||||
|
|
||||||
|
# Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
|
||||||
|
while _BLOCK_PREFIX_RE.match(stem):
|
||||||
|
stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)
|
||||||
|
|
||||||
|
# Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
|
||||||
|
if _UUID_PREFIX_RE.match(stem):
|
||||||
|
stem = _UUID_PREFIX_RE.sub("", stem, count=1)
|
||||||
|
|
||||||
|
if not stem:
|
||||||
|
return "video"
|
||||||
|
|
||||||
|
return stem[:max_length]
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_codecs(output_path: str) -> tuple[str, str]:
|
||||||
|
"""Get appropriate video and audio codecs based on output file extension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_path: Path to the output file (used to determine extension)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (video_codec, audio_codec)
|
||||||
|
|
||||||
|
Codec mappings:
|
||||||
|
- .mp4: H.264 + AAC (universal compatibility)
|
||||||
|
- .webm: VP8 + Vorbis (web streaming)
|
||||||
|
- .mkv: H.264 + AAC (container supports many codecs)
|
||||||
|
- .mov: H.264 + AAC (Apple QuickTime, widely compatible)
|
||||||
|
- .m4v: H.264 + AAC (Apple iTunes/devices)
|
||||||
|
- .avi: MPEG-4 + MP3 (legacy Windows)
|
||||||
|
"""
|
||||||
|
ext = os.path.splitext(output_path)[1].lower()
|
||||||
|
|
||||||
|
codec_map: dict[str, tuple[str, str]] = {
|
||||||
|
".mp4": ("libx264", "aac"),
|
||||||
|
".webm": ("libvpx", "libvorbis"),
|
||||||
|
".mkv": ("libx264", "aac"),
|
||||||
|
".mov": ("libx264", "aac"),
|
||||||
|
".m4v": ("libx264", "aac"),
|
||||||
|
".avi": ("mpeg4", "libmp3lame"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return codec_map.get(ext, ("libx264", "aac"))
|
||||||
|
|
||||||
|
|
||||||
|
def strip_chapters_inplace(video_path: str) -> None:
|
||||||
|
"""Strip chapter metadata from a media file in-place using ffmpeg.
|
||||||
|
|
||||||
|
MoviePy 2.x crashes with IndexError when parsing files with embedded
|
||||||
|
chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
|
||||||
|
This strips chapters without re-encoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Absolute path to the media file to strip chapters from.
|
||||||
|
"""
|
||||||
|
base, ext = os.path.splitext(video_path)
|
||||||
|
tmp_path = base + ".tmp" + ext
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
"ffmpeg",
|
||||||
|
"-y",
|
||||||
|
"-i",
|
||||||
|
video_path,
|
||||||
|
"-map_chapters",
|
||||||
|
"-1",
|
||||||
|
"-codec",
|
||||||
|
"copy",
|
||||||
|
tmp_path,
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
logger.warning(
|
||||||
|
"ffmpeg chapter strip failed (rc=%d): %s",
|
||||||
|
result.returncode,
|
||||||
|
result.stderr,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
os.replace(tmp_path, video_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.warning("ffmpeg not found; skipping chapter strip")
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp_path):
|
||||||
|
os.unlink(tmp_path)
|
||||||
113
autogpt_platform/backend/backend/blocks/video/add_audio.py
Normal file
113
autogpt_platform/backend/backend/blocks/video/add_audio.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
|
||||||
|
|
||||||
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class AddAudioToVideoBlock(Block):
|
||||||
|
"""Add (attach) an audio track to an existing video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Video input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
audio_in: MediaFileType = SchemaField(
|
||||||
|
description="Audio input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
volume: float = SchemaField(
|
||||||
|
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||||
|
default=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Final video (with attached audio), as a path or data URI."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||||
|
description="Block to attach an audio file to a video file using moviepy.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=AddAudioToVideoBlock.Input,
|
||||||
|
output_schema=AddAudioToVideoBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
|
# 1) Store the inputs locally
|
||||||
|
local_video_path = await store_media_file(
|
||||||
|
file=input_data.video_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
local_audio_path = await store_media_file(
|
||||||
|
file=input_data.audio_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||||
|
audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)
|
||||||
|
|
||||||
|
# 2) Load video + audio with moviepy
|
||||||
|
strip_chapters_inplace(video_abspath)
|
||||||
|
strip_chapters_inplace(audio_abspath)
|
||||||
|
video_clip = None
|
||||||
|
audio_clip = None
|
||||||
|
final_clip = None
|
||||||
|
try:
|
||||||
|
video_clip = VideoFileClip(video_abspath)
|
||||||
|
audio_clip = AudioFileClip(audio_abspath)
|
||||||
|
# Optionally scale volume
|
||||||
|
if input_data.volume != 1.0:
|
||||||
|
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||||
|
|
||||||
|
# 3) Attach the new audio track
|
||||||
|
final_clip = video_clip.with_audio(audio_clip)
|
||||||
|
|
||||||
|
# 4) Write to output file
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
|
||||||
|
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||||
|
final_clip.write_videofile(
|
||||||
|
output_abspath, codec="libx264", audio_codec="aac"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if final_clip:
|
||||||
|
final_clip.close()
|
||||||
|
if audio_clip:
|
||||||
|
audio_clip.close()
|
||||||
|
if video_clip:
|
||||||
|
video_clip.close()
|
||||||
|
|
||||||
|
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
|
video_out = await store_media_file(
|
||||||
|
file=output_filename,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
167
autogpt_platform/backend/backend/blocks/video/clip.py
Normal file
167
autogpt_platform/backend/backend/blocks/video/clip.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""VideoClipBlock - Extract a segment from a video file."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import (
|
||||||
|
extract_source_name,
|
||||||
|
get_video_codecs,
|
||||||
|
strip_chapters_inplace,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoClipBlock(Block):
|
||||||
|
"""Extract a time segment from a video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Input video (URL, data URI, or local path)"
|
||||||
|
)
|
||||||
|
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
|
||||||
|
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
|
||||||
|
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||||
|
description="Output format", default="mp4", advanced=True
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Clipped video file (path or data URI)"
|
||||||
|
)
|
||||||
|
duration: float = SchemaField(description="Clip duration in seconds")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
|
||||||
|
description="Extract a time segment from a video",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
test_input={
|
||||||
|
"video_in": "/tmp/test.mp4",
|
||||||
|
"start_time": 0.0,
|
||||||
|
"end_time": 10.0,
|
||||||
|
},
|
||||||
|
test_output=[("video_out", str), ("duration", float)],
|
||||||
|
test_mock={
|
||||||
|
"_clip_video": lambda *args: 10.0,
|
||||||
|
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_input_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store input video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clip_video(
|
||||||
|
self,
|
||||||
|
video_abspath: str,
|
||||||
|
output_abspath: str,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
) -> float:
|
||||||
|
"""Extract a clip from a video. Extracted for testability."""
|
||||||
|
clip = None
|
||||||
|
subclip = None
|
||||||
|
try:
|
||||||
|
strip_chapters_inplace(video_abspath)
|
||||||
|
clip = VideoFileClip(video_abspath)
|
||||||
|
subclip = clip.subclipped(start_time, end_time)
|
||||||
|
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||||
|
subclip.write_videofile(
|
||||||
|
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||||
|
)
|
||||||
|
return subclip.duration
|
||||||
|
finally:
|
||||||
|
if subclip:
|
||||||
|
subclip.close()
|
||||||
|
if clip:
|
||||||
|
clip.close()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# Validate time range
|
||||||
|
if input_data.end_time <= input_data.start_time:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Store the input video locally
|
||||||
|
local_video_path = await self._store_input_video(
|
||||||
|
execution_context, input_data.video_in
|
||||||
|
)
|
||||||
|
video_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_video_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build output path
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(
|
||||||
|
f"{node_exec_id}_clip_{source}.{input_data.output_format}"
|
||||||
|
)
|
||||||
|
output_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
duration = self._clip_video(
|
||||||
|
video_abspath,
|
||||||
|
output_abspath,
|
||||||
|
input_data.start_time,
|
||||||
|
input_data.end_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
yield "duration", duration
|
||||||
|
|
||||||
|
except BlockExecutionError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to clip video: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
227
autogpt_platform/backend/backend/blocks/video/concat.py
Normal file
227
autogpt_platform/backend/backend/blocks/video/concat.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""VideoConcatBlock - Concatenate multiple video clips into one."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from moviepy import concatenate_videoclips
|
||||||
|
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import (
|
||||||
|
extract_source_name,
|
||||||
|
get_video_codecs,
|
||||||
|
strip_chapters_inplace,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoConcatBlock(Block):
|
||||||
|
"""Merge multiple video clips into one continuous video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
videos: list[MediaFileType] = SchemaField(
|
||||||
|
description="List of video files to concatenate (in order)"
|
||||||
|
)
|
||||||
|
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
|
||||||
|
description="Transition between clips", default="none"
|
||||||
|
)
|
||||||
|
transition_duration: int = SchemaField(
|
||||||
|
description="Transition duration in seconds",
|
||||||
|
default=1,
|
||||||
|
ge=0,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||||
|
description="Output format", default="mp4", advanced=True
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Concatenated video file (path or data URI)"
|
||||||
|
)
|
||||||
|
total_duration: float = SchemaField(description="Total duration in seconds")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
|
||||||
|
description="Merge multiple video clips into one continuous video",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
test_input={
|
||||||
|
"videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
("video_out", str),
|
||||||
|
("total_duration", float),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"_concat_videos": lambda *args: 20.0,
|
||||||
|
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_input_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store input video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _concat_videos(
|
||||||
|
self,
|
||||||
|
video_abspaths: list[str],
|
||||||
|
output_abspath: str,
|
||||||
|
transition: str,
|
||||||
|
transition_duration: int,
|
||||||
|
) -> float:
|
||||||
|
"""Concatenate videos. Extracted for testability.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total duration of the concatenated video.
|
||||||
|
"""
|
||||||
|
clips = []
|
||||||
|
faded_clips = []
|
||||||
|
final = None
|
||||||
|
try:
|
||||||
|
# Load clips
|
||||||
|
for v in video_abspaths:
|
||||||
|
strip_chapters_inplace(v)
|
||||||
|
clips.append(VideoFileClip(v))
|
||||||
|
|
||||||
|
# Validate transition_duration against shortest clip
|
||||||
|
if transition in {"crossfade", "fade_black"} and transition_duration > 0:
|
||||||
|
min_duration = min(c.duration for c in clips)
|
||||||
|
if transition_duration >= min_duration:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=(
|
||||||
|
f"transition_duration ({transition_duration}s) must be "
|
||||||
|
f"shorter than the shortest clip ({min_duration:.2f}s)"
|
||||||
|
),
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
if transition == "crossfade":
|
||||||
|
for i, clip in enumerate(clips):
|
||||||
|
effects = []
|
||||||
|
if i > 0:
|
||||||
|
effects.append(CrossFadeIn(transition_duration))
|
||||||
|
if i < len(clips) - 1:
|
||||||
|
effects.append(CrossFadeOut(transition_duration))
|
||||||
|
if effects:
|
||||||
|
clip = clip.with_effects(effects)
|
||||||
|
faded_clips.append(clip)
|
||||||
|
final = concatenate_videoclips(
|
||||||
|
faded_clips,
|
||||||
|
method="compose",
|
||||||
|
padding=-transition_duration,
|
||||||
|
)
|
||||||
|
elif transition == "fade_black":
|
||||||
|
for clip in clips:
|
||||||
|
faded = clip.with_effects(
|
||||||
|
[FadeIn(transition_duration), FadeOut(transition_duration)]
|
||||||
|
)
|
||||||
|
faded_clips.append(faded)
|
||||||
|
final = concatenate_videoclips(faded_clips)
|
||||||
|
else:
|
||||||
|
final = concatenate_videoclips(clips)
|
||||||
|
|
||||||
|
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||||
|
final.write_videofile(
|
||||||
|
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||||
|
)
|
||||||
|
|
||||||
|
return final.duration
|
||||||
|
finally:
|
||||||
|
if final:
|
||||||
|
final.close()
|
||||||
|
for clip in faded_clips:
|
||||||
|
clip.close()
|
||||||
|
for clip in clips:
|
||||||
|
clip.close()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# Validate minimum clips
|
||||||
|
if len(input_data.videos) < 2:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message="At least 2 videos are required for concatenation",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Store all input videos locally
|
||||||
|
video_abspaths = []
|
||||||
|
for video in input_data.videos:
|
||||||
|
local_path = await self._store_input_video(execution_context, video)
|
||||||
|
video_abspaths.append(
|
||||||
|
get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build output path
|
||||||
|
source = (
|
||||||
|
extract_source_name(video_abspaths[0]) if video_abspaths else "video"
|
||||||
|
)
|
||||||
|
output_filename = MediaFileType(
|
||||||
|
f"{node_exec_id}_concat_{source}.{input_data.output_format}"
|
||||||
|
)
|
||||||
|
output_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
total_duration = self._concat_videos(
|
||||||
|
video_abspaths,
|
||||||
|
output_abspath,
|
||||||
|
input_data.transition,
|
||||||
|
input_data.transition_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
yield "total_duration", total_duration
|
||||||
|
|
||||||
|
except BlockExecutionError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to concatenate videos: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
172
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
172
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import typing
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import yt_dlp
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from yt_dlp import _Params
|
||||||
|
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoDownloadBlock(Block):
|
||||||
|
"""Download video from URL using yt-dlp."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
url: str = SchemaField(
|
||||||
|
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
|
||||||
|
placeholder="https://www.youtube.com/watch?v=...",
|
||||||
|
)
|
||||||
|
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
|
||||||
|
description="Video quality preference", default="720p"
|
||||||
|
)
|
||||||
|
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
|
||||||
|
description="Output video format", default="mp4", advanced=True
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_file: MediaFileType = SchemaField(
|
||||||
|
description="Downloaded video (path or data URI)"
|
||||||
|
)
|
||||||
|
duration: float = SchemaField(description="Video duration in seconds")
|
||||||
|
title: str = SchemaField(description="Video title from source")
|
||||||
|
source_url: str = SchemaField(description="Original source URL")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
|
||||||
|
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
disabled=True, # Disable until we can sandbox yt-dlp and handle security implications
|
||||||
|
test_input={
|
||||||
|
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
||||||
|
"quality": "480p",
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
("video_file", str),
|
||||||
|
("duration", float),
|
||||||
|
("title", str),
|
||||||
|
("source_url", str),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"_download_video": lambda *args: (
|
||||||
|
"video.mp4",
|
||||||
|
212.0,
|
||||||
|
"Test Video",
|
||||||
|
),
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "video.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_format_string(self, quality: str) -> str:
|
||||||
|
formats = {
|
||||||
|
"best": "bestvideo+bestaudio/best",
|
||||||
|
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
|
||||||
|
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
|
||||||
|
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
|
||||||
|
"audio_only": "bestaudio/best",
|
||||||
|
}
|
||||||
|
return formats.get(quality, formats["720p"])
|
||||||
|
|
||||||
|
def _download_video(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
quality: str,
|
||||||
|
output_format: str,
|
||||||
|
output_dir: str,
|
||||||
|
node_exec_id: str,
|
||||||
|
) -> tuple[str, float, str]:
|
||||||
|
"""Download video. Extracted for testability."""
|
||||||
|
output_template = os.path.join(
|
||||||
|
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
ydl_opts: "_Params" = {
|
||||||
|
"format": f"{self._get_format_string(quality)}/best",
|
||||||
|
"outtmpl": output_template,
|
||||||
|
"merge_output_format": output_format,
|
||||||
|
"quiet": True,
|
||||||
|
"no_warnings": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||||
|
info = ydl.extract_info(url, download=True)
|
||||||
|
video_path = ydl.prepare_filename(info)
|
||||||
|
|
||||||
|
# Handle format conversion in filename
|
||||||
|
if not video_path.endswith(f".{output_format}"):
|
||||||
|
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
|
||||||
|
|
||||||
|
# Return just the filename, not the full path
|
||||||
|
filename = os.path.basename(video_path)
|
||||||
|
|
||||||
|
return (
|
||||||
|
filename,
|
||||||
|
info.get("duration") or 0.0,
|
||||||
|
info.get("title") or "Unknown",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Get the exec file directory
|
||||||
|
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
filename, duration, title = self._download_video(
|
||||||
|
input_data.url,
|
||||||
|
input_data.quality,
|
||||||
|
input_data.output_format,
|
||||||
|
output_dir,
|
||||||
|
node_exec_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, MediaFileType(filename)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_file", video_out
|
||||||
|
yield "duration", duration
|
||||||
|
yield "title", title
|
||||||
|
yield "source_url", input_data.url
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to download video: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
77
autogpt_platform/backend/backend/blocks/video/duration.py
Normal file
77
autogpt_platform/backend/backend/blocks/video/duration.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""MediaDurationBlock - Get the duration of a media file."""
|
||||||
|
|
||||||
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import strip_chapters_inplace
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class MediaDurationBlock(Block):
|
||||||
|
"""Get the duration of a media file (video or audio)."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
media_in: MediaFileType = SchemaField(
|
||||||
|
description="Media input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
is_video: bool = SchemaField(
|
||||||
|
description="Whether the media is a video (True) or audio (False).",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
duration: float = SchemaField(
|
||||||
|
description="Duration of the media file (in seconds)."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
||||||
|
description="Block to get the duration of a media file.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=MediaDurationBlock.Input,
|
||||||
|
output_schema=MediaDurationBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# 1) Store the input media locally
|
||||||
|
local_media_path = await store_media_file(
|
||||||
|
file=input_data.media_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
media_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_media_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Strip chapters to avoid MoviePy crash, then load the clip
|
||||||
|
strip_chapters_inplace(media_abspath)
|
||||||
|
clip = None
|
||||||
|
try:
|
||||||
|
if input_data.is_video:
|
||||||
|
clip = VideoFileClip(media_abspath)
|
||||||
|
else:
|
||||||
|
clip = AudioFileClip(media_abspath)
|
||||||
|
|
||||||
|
duration = clip.duration
|
||||||
|
finally:
|
||||||
|
if clip:
|
||||||
|
clip.close()
|
||||||
|
|
||||||
|
yield "duration", duration
|
||||||
115
autogpt_platform/backend/backend/blocks/video/loop.py
Normal file
115
autogpt_platform/backend/backend/blocks/video/loop.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from moviepy.video.fx.Loop import Loop
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class LoopVideoBlock(Block):
|
||||||
|
"""Loop (repeat) a video clip until a given duration or number of loops."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="The input video (can be a URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
duration: Optional[float] = SchemaField(
|
||||||
|
description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
|
||||||
|
default=None,
|
||||||
|
ge=0.0,
|
||||||
|
le=3600.0, # Max 1 hour to prevent disk exhaustion
|
||||||
|
)
|
||||||
|
n_loops: Optional[int] = SchemaField(
|
||||||
|
description="Number of times to repeat the video. Either n_loops or duration must be provided.",
|
||||||
|
default=None,
|
||||||
|
ge=1,
|
||||||
|
le=10, # Max 10 loops to prevent disk exhaustion
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Looped video returned either as a relative path or a data URI."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||||
|
description="Block to loop a video to a given duration or number of repeats.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=LoopVideoBlock.Input,
|
||||||
|
output_schema=LoopVideoBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
|
# 1) Store the input video locally
|
||||||
|
local_video_path = await store_media_file(
|
||||||
|
file=input_data.video_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||||
|
|
||||||
|
# 2) Load the clip
|
||||||
|
strip_chapters_inplace(input_abspath)
|
||||||
|
clip = None
|
||||||
|
looped_clip = None
|
||||||
|
try:
|
||||||
|
clip = VideoFileClip(input_abspath)
|
||||||
|
|
||||||
|
# 3) Apply the loop effect
|
||||||
|
if input_data.duration:
|
||||||
|
# Loop until we reach the specified duration
|
||||||
|
looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
|
||||||
|
elif input_data.n_loops:
|
||||||
|
looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
|
||||||
|
else:
|
||||||
|
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||||
|
|
||||||
|
assert isinstance(looped_clip, VideoFileClip)
|
||||||
|
|
||||||
|
# 4) Save the looped output
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
|
||||||
|
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||||
|
|
||||||
|
looped_clip = looped_clip.with_audio(clip.audio)
|
||||||
|
looped_clip.write_videofile(
|
||||||
|
output_abspath, codec="libx264", audio_codec="aac"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if looped_clip:
|
||||||
|
looped_clip.close()
|
||||||
|
if clip:
|
||||||
|
clip.close()
|
||||||
|
|
||||||
|
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
|
video_out = await store_media_file(
|
||||||
|
file=output_filename,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
267
autogpt_platform/backend/backend/blocks/video/narration.py
Normal file
267
autogpt_platform/backend/backend/blocks/video/narration.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from elevenlabs import ElevenLabs
|
||||||
|
from moviepy import CompositeAudioClip
|
||||||
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.elevenlabs._auth import (
|
||||||
|
TEST_CREDENTIALS,
|
||||||
|
TEST_CREDENTIALS_INPUT,
|
||||||
|
ElevenLabsCredentials,
|
||||||
|
ElevenLabsCredentialsInput,
|
||||||
|
)
|
||||||
|
from backend.blocks.video._utils import (
|
||||||
|
extract_source_name,
|
||||||
|
get_video_codecs,
|
||||||
|
strip_chapters_inplace,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoNarrationBlock(Block):
|
||||||
|
"""Generate AI narration and add to video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: ElevenLabsCredentialsInput = CredentialsField(
|
||||||
|
description="ElevenLabs API key for voice synthesis"
|
||||||
|
)
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Input video (URL, data URI, or local path)"
|
||||||
|
)
|
||||||
|
script: str = SchemaField(description="Narration script text")
|
||||||
|
voice_id: str = SchemaField(
|
||||||
|
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||||
|
)
|
||||||
|
model_id: Literal[
|
||||||
|
"eleven_multilingual_v2",
|
||||||
|
"eleven_flash_v2_5",
|
||||||
|
"eleven_turbo_v2_5",
|
||||||
|
"eleven_turbo_v2",
|
||||||
|
] = SchemaField(
|
||||||
|
description="ElevenLabs TTS model",
|
||||||
|
default="eleven_multilingual_v2",
|
||||||
|
)
|
||||||
|
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
|
||||||
|
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
|
||||||
|
default="ducking",
|
||||||
|
)
|
||||||
|
narration_volume: float = SchemaField(
|
||||||
|
description="Narration volume (0.0 to 2.0)",
|
||||||
|
default=1.0,
|
||||||
|
ge=0.0,
|
||||||
|
le=2.0,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
original_volume: float = SchemaField(
|
||||||
|
description="Original audio volume when mixing (0.0 to 1.0)",
|
||||||
|
default=0.3,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Video with narration (path or data URI)"
|
||||||
|
)
|
||||||
|
audio_file: MediaFileType = SchemaField(
|
||||||
|
description="Generated audio file (path or data URI)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="3d036b53-859c-4b17-9826-ca340f736e0e",
|
||||||
|
description="Generate AI narration and add to video",
|
||||||
|
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
test_input={
|
||||||
|
"video_in": "/tmp/test.mp4",
|
||||||
|
"script": "Hello world",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[("video_out", str), ("audio_file", str)],
|
||||||
|
test_mock={
|
||||||
|
"_generate_narration_audio": lambda *args: b"mock audio content",
|
||||||
|
"_add_narration_to_video": lambda *args: None,
|
||||||
|
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_input_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store input video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_narration_audio(
|
||||||
|
self, api_key: str, script: str, voice_id: str, model_id: str
|
||||||
|
) -> bytes:
|
||||||
|
"""Generate narration audio via ElevenLabs API."""
|
||||||
|
client = ElevenLabs(api_key=api_key)
|
||||||
|
audio_generator = client.text_to_speech.convert(
|
||||||
|
voice_id=voice_id,
|
||||||
|
text=script,
|
||||||
|
model_id=model_id,
|
||||||
|
)
|
||||||
|
# The SDK returns a generator, collect all chunks
|
||||||
|
return b"".join(audio_generator)
|
||||||
|
|
||||||
|
def _add_narration_to_video(
|
||||||
|
self,
|
||||||
|
video_abspath: str,
|
||||||
|
audio_abspath: str,
|
||||||
|
output_abspath: str,
|
||||||
|
mix_mode: str,
|
||||||
|
narration_volume: float,
|
||||||
|
original_volume: float,
|
||||||
|
) -> None:
|
||||||
|
"""Add narration audio to video. Extracted for testability."""
|
||||||
|
video = None
|
||||||
|
final = None
|
||||||
|
narration_original = None
|
||||||
|
narration_scaled = None
|
||||||
|
original = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
strip_chapters_inplace(video_abspath)
|
||||||
|
video = VideoFileClip(video_abspath)
|
||||||
|
narration_original = AudioFileClip(audio_abspath)
|
||||||
|
narration_scaled = narration_original.with_volume_scaled(narration_volume)
|
||||||
|
narration = narration_scaled
|
||||||
|
|
||||||
|
if mix_mode == "replace":
|
||||||
|
final_audio = narration
|
||||||
|
elif mix_mode == "mix":
|
||||||
|
if video.audio:
|
||||||
|
original = video.audio.with_volume_scaled(original_volume)
|
||||||
|
final_audio = CompositeAudioClip([original, narration])
|
||||||
|
else:
|
||||||
|
final_audio = narration
|
||||||
|
else: # ducking - apply stronger attenuation
|
||||||
|
if video.audio:
|
||||||
|
# Ducking uses a much lower volume for original audio
|
||||||
|
ducking_volume = original_volume * 0.3
|
||||||
|
original = video.audio.with_volume_scaled(ducking_volume)
|
||||||
|
final_audio = CompositeAudioClip([original, narration])
|
||||||
|
else:
|
||||||
|
final_audio = narration
|
||||||
|
|
||||||
|
final = video.with_audio(final_audio)
|
||||||
|
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||||
|
final.write_videofile(
|
||||||
|
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if original:
|
||||||
|
original.close()
|
||||||
|
if narration_scaled:
|
||||||
|
narration_scaled.close()
|
||||||
|
if narration_original:
|
||||||
|
narration_original.close()
|
||||||
|
if final:
|
||||||
|
final.close()
|
||||||
|
if video:
|
||||||
|
video.close()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: ElevenLabsCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Store the input video locally
|
||||||
|
local_video_path = await self._store_input_video(
|
||||||
|
execution_context, input_data.video_in
|
||||||
|
)
|
||||||
|
video_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_video_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate narration audio via ElevenLabs
|
||||||
|
audio_content = self._generate_narration_audio(
|
||||||
|
credentials.api_key.get_secret_value(),
|
||||||
|
input_data.script,
|
||||||
|
input_data.voice_id,
|
||||||
|
input_data.model_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save audio to exec file path
|
||||||
|
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
|
||||||
|
audio_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, audio_filename
|
||||||
|
)
|
||||||
|
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
|
||||||
|
with open(audio_abspath, "wb") as f:
|
||||||
|
f.write(audio_content)
|
||||||
|
|
||||||
|
# Add narration to video
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
|
||||||
|
output_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
self._add_narration_to_video(
|
||||||
|
video_abspath,
|
||||||
|
audio_abspath,
|
||||||
|
output_abspath,
|
||||||
|
input_data.mix_mode,
|
||||||
|
input_data.narration_volume,
|
||||||
|
input_data.original_volume,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, output_filename
|
||||||
|
)
|
||||||
|
audio_out = await self._store_output_video(
|
||||||
|
execution_context, audio_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
yield "audio_file", audio_out
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to add narration: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
231
autogpt_platform/backend/backend/blocks/video/text_overlay.py
Normal file
231
autogpt_platform/backend/backend/blocks/video/text_overlay.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
"""VideoTextOverlayBlock - Add text overlay to video."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from moviepy import CompositeVideoClip, TextClip
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import (
|
||||||
|
extract_source_name,
|
||||||
|
get_video_codecs,
|
||||||
|
strip_chapters_inplace,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoTextOverlayBlock(Block):
|
||||||
|
"""Add text overlay/caption to video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Input video (URL, data URI, or local path)"
|
||||||
|
)
|
||||||
|
text: str = SchemaField(description="Text to overlay on video")
|
||||||
|
position: Literal[
|
||||||
|
"top",
|
||||||
|
"center",
|
||||||
|
"bottom",
|
||||||
|
"top-left",
|
||||||
|
"top-right",
|
||||||
|
"bottom-left",
|
||||||
|
"bottom-right",
|
||||||
|
] = SchemaField(description="Position of text on screen", default="bottom")
|
||||||
|
start_time: float | None = SchemaField(
|
||||||
|
description="When to show text (seconds). None = entire video",
|
||||||
|
default=None,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
end_time: float | None = SchemaField(
|
||||||
|
description="When to hide text (seconds). None = until end",
|
||||||
|
default=None,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
font_size: int = SchemaField(
|
||||||
|
description="Font size", default=48, ge=12, le=200, advanced=True
|
||||||
|
)
|
||||||
|
font_color: str = SchemaField(
|
||||||
|
description="Font color (hex or name)", default="white", advanced=True
|
||||||
|
)
|
||||||
|
bg_color: str | None = SchemaField(
|
||||||
|
description="Background color behind text (None for transparent)",
|
||||||
|
default=None,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Video with text overlay (path or data URI)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8ef14de6-cc90-430a-8cfa-3a003be92454",
|
||||||
|
description="Add text overlay/caption to video",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
disabled=True, # Disable until we can lockdown imagemagick security policy
|
||||||
|
test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
|
||||||
|
test_output=[("video_out", str)],
|
||||||
|
test_mock={
|
||||||
|
"_add_text_overlay": lambda *args: None,
|
||||||
|
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_input_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store input video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_text_overlay(
|
||||||
|
self,
|
||||||
|
video_abspath: str,
|
||||||
|
output_abspath: str,
|
||||||
|
text: str,
|
||||||
|
position: str,
|
||||||
|
start_time: float | None,
|
||||||
|
end_time: float | None,
|
||||||
|
font_size: int,
|
||||||
|
font_color: str,
|
||||||
|
bg_color: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Add text overlay to video. Extracted for testability."""
|
||||||
|
video = None
|
||||||
|
final = None
|
||||||
|
txt_clip = None
|
||||||
|
try:
|
||||||
|
strip_chapters_inplace(video_abspath)
|
||||||
|
video = VideoFileClip(video_abspath)
|
||||||
|
|
||||||
|
txt_clip = TextClip(
|
||||||
|
text=text,
|
||||||
|
font_size=font_size,
|
||||||
|
color=font_color,
|
||||||
|
bg_color=bg_color,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Position mapping
|
||||||
|
pos_map = {
|
||||||
|
"top": ("center", "top"),
|
||||||
|
"center": ("center", "center"),
|
||||||
|
"bottom": ("center", "bottom"),
|
||||||
|
"top-left": ("left", "top"),
|
||||||
|
"top-right": ("right", "top"),
|
||||||
|
"bottom-left": ("left", "bottom"),
|
||||||
|
"bottom-right": ("right", "bottom"),
|
||||||
|
}
|
||||||
|
|
||||||
|
txt_clip = txt_clip.with_position(pos_map[position])
|
||||||
|
|
||||||
|
# Set timing
|
||||||
|
start = start_time or 0
|
||||||
|
end = end_time or video.duration
|
||||||
|
duration = max(0, end - start)
|
||||||
|
txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)
|
||||||
|
|
||||||
|
final = CompositeVideoClip([video, txt_clip])
|
||||||
|
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||||
|
final.write_videofile(
|
||||||
|
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if txt_clip:
|
||||||
|
txt_clip.close()
|
||||||
|
if final:
|
||||||
|
final.close()
|
||||||
|
if video:
|
||||||
|
video.close()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# Validate time range if both are provided
|
||||||
|
if (
|
||||||
|
input_data.start_time is not None
|
||||||
|
and input_data.end_time is not None
|
||||||
|
and input_data.end_time <= input_data.start_time
|
||||||
|
):
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Store the input video locally
|
||||||
|
local_video_path = await self._store_input_video(
|
||||||
|
execution_context, input_data.video_in
|
||||||
|
)
|
||||||
|
video_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_video_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build output path
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
|
||||||
|
output_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
self._add_text_overlay(
|
||||||
|
video_abspath,
|
||||||
|
output_abspath,
|
||||||
|
input_data.text,
|
||||||
|
input_data.position,
|
||||||
|
input_data.start_time,
|
||||||
|
input_data.end_time,
|
||||||
|
input_data.font_size,
|
||||||
|
input_data.font_color,
|
||||||
|
input_data.bg_color,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
|
||||||
|
except BlockExecutionError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to add text overlay: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
@@ -165,10 +165,13 @@ class TranscribeYoutubeVideoBlock(Block):
|
|||||||
credentials: WebshareProxyCredentials,
|
credentials: WebshareProxyCredentials,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
video_id = self.extract_video_id(input_data.youtube_url)
|
video_id = self.extract_video_id(input_data.youtube_url)
|
||||||
yield "video_id", video_id
|
|
||||||
|
|
||||||
transcript = self.get_transcript(video_id, credentials)
|
transcript = self.get_transcript(video_id, credentials)
|
||||||
transcript_text = self.format_transcript(transcript=transcript)
|
transcript_text = self.format_transcript(transcript=transcript)
|
||||||
|
|
||||||
|
# Only yield after all operations succeed
|
||||||
|
yield "video_id", video_id
|
||||||
yield "transcript", transcript_text
|
yield "transcript", transcript_text
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|||||||
@@ -873,14 +873,13 @@ def is_block_auth_configured(
|
|||||||
|
|
||||||
|
|
||||||
async def initialize_blocks() -> None:
|
async def initialize_blocks() -> None:
|
||||||
# First, sync all provider costs to blocks
|
|
||||||
# Imported here to avoid circular import
|
|
||||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||||
|
from backend.util.retry import func_retry
|
||||||
|
|
||||||
sync_all_provider_costs()
|
sync_all_provider_costs()
|
||||||
|
|
||||||
for cls in get_blocks().values():
|
@func_retry
|
||||||
block = cls()
|
async def sync_block_to_db(block: Block) -> None:
|
||||||
existing_block = await AgentBlock.prisma().find_first(
|
existing_block = await AgentBlock.prisma().find_first(
|
||||||
where={"OR": [{"id": block.id}, {"name": block.name}]}
|
where={"OR": [{"id": block.id}, {"name": block.name}]}
|
||||||
)
|
)
|
||||||
@@ -893,7 +892,7 @@ async def initialize_blocks() -> None:
|
|||||||
outputSchema=json.dumps(block.output_schema.jsonschema()),
|
outputSchema=json.dumps(block.output_schema.jsonschema()),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
continue
|
return
|
||||||
|
|
||||||
input_schema = json.dumps(block.input_schema.jsonschema())
|
input_schema = json.dumps(block.input_schema.jsonschema())
|
||||||
output_schema = json.dumps(block.output_schema.jsonschema())
|
output_schema = json.dumps(block.output_schema.jsonschema())
|
||||||
@@ -913,6 +912,25 @@ async def initialize_blocks() -> None:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
failed_blocks: list[str] = []
|
||||||
|
for cls in get_blocks().values():
|
||||||
|
block = cls()
|
||||||
|
try:
|
||||||
|
await sync_block_to_db(block)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to sync block {block.name} to database: {e}. "
|
||||||
|
"Block is still available in memory.",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
failed_blocks.append(block.name)
|
||||||
|
|
||||||
|
if failed_blocks:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to sync {len(failed_blocks)} block(s) to database: "
|
||||||
|
f"{', '.join(failed_blocks)}. These blocks are still available in memory."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
||||||
def get_block(block_id: str) -> AnyBlockSchema | None:
|
def get_block(block_id: str) -> AnyBlockSchema | None:
|
||||||
|
|||||||
@@ -36,12 +36,14 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
|||||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||||
|
from backend.blocks.video.narration import VideoNarrationBlock
|
||||||
from backend.data.block import Block, BlockCost, BlockCostType
|
from backend.data.block import Block, BlockCost, BlockCostType
|
||||||
from backend.integrations.credentials_store import (
|
from backend.integrations.credentials_store import (
|
||||||
aiml_api_credentials,
|
aiml_api_credentials,
|
||||||
anthropic_credentials,
|
anthropic_credentials,
|
||||||
apollo_credentials,
|
apollo_credentials,
|
||||||
did_credentials,
|
did_credentials,
|
||||||
|
elevenlabs_credentials,
|
||||||
enrichlayer_credentials,
|
enrichlayer_credentials,
|
||||||
groq_credentials,
|
groq_credentials,
|
||||||
ideogram_credentials,
|
ideogram_credentials,
|
||||||
@@ -78,10 +80,10 @@ MODEL_COST: dict[LlmModel, int] = {
|
|||||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||||
LlmModel.CLAUDE_4_OPUS: 21,
|
LlmModel.CLAUDE_4_OPUS: 21,
|
||||||
LlmModel.CLAUDE_4_SONNET: 5,
|
LlmModel.CLAUDE_4_SONNET: 5,
|
||||||
|
LlmModel.CLAUDE_4_6_OPUS: 14,
|
||||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
|
||||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
||||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
||||||
@@ -640,4 +642,16 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
VideoNarrationBlock: [
|
||||||
|
BlockCost(
|
||||||
|
cost_amount=5, # ElevenLabs TTS cost
|
||||||
|
cost_filter={
|
||||||
|
"credentials": {
|
||||||
|
"id": elevenlabs_credentials.id,
|
||||||
|
"provider": elevenlabs_credentials.provider,
|
||||||
|
"type": elevenlabs_credentials.type,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,6 +134,16 @@ async def test_block_credit_reset(server: SpinTestServer):
|
|||||||
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
||||||
user_credit.time_now = lambda: month1
|
user_credit.time_now = lambda: month1
|
||||||
|
|
||||||
|
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
|
||||||
|
# in a different month than month1 (January). This fixes a timing bug
|
||||||
|
# where if the test runs in early February, 35 days ago would be January,
|
||||||
|
# matching the mocked month1 and preventing the refill from triggering.
|
||||||
|
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
|
||||||
|
await UserBalance.prisma().update(
|
||||||
|
where={"userId": DEFAULT_USER_ID},
|
||||||
|
data={"updatedAt": dec_previous_year},
|
||||||
|
)
|
||||||
|
|
||||||
# First call in month 1 should trigger refill
|
# First call in month 1 should trigger refill
|
||||||
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
assert balance == REFILL_VALUE # Should get 1000 credits
|
assert balance == REFILL_VALUE # Should get 1000 credits
|
||||||
|
|||||||
@@ -133,10 +133,23 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
|
|||||||
|
|
||||||
|
|
||||||
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
||||||
|
def __init__(self):
|
||||||
|
self._pubsub: AsyncPubSub | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
async def connection(self) -> redis.AsyncRedis:
|
async def connection(self) -> redis.AsyncRedis:
|
||||||
return await redis.get_redis_async()
|
return await redis.get_redis_async()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the PubSub connection if it exists."""
|
||||||
|
if self._pubsub is not None:
|
||||||
|
try:
|
||||||
|
await self._pubsub.close()
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to close PubSub connection", exc_info=True)
|
||||||
|
finally:
|
||||||
|
self._pubsub = None
|
||||||
|
|
||||||
async def publish_event(self, event: M, channel_key: str):
|
async def publish_event(self, event: M, channel_key: str):
|
||||||
"""
|
"""
|
||||||
Publish an event to Redis. Gracefully handles connection failures
|
Publish an event to Redis. Gracefully handles connection failures
|
||||||
@@ -157,6 +170,7 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
|||||||
await self.connection, channel_key
|
await self.connection, channel_key
|
||||||
)
|
)
|
||||||
assert isinstance(pubsub, AsyncPubSub)
|
assert isinstance(pubsub, AsyncPubSub)
|
||||||
|
self._pubsub = pubsub
|
||||||
|
|
||||||
if "*" in channel_key:
|
if "*" in channel_key:
|
||||||
await pubsub.psubscribe(full_channel_name)
|
await pubsub.psubscribe(full_channel_name)
|
||||||
|
|||||||
@@ -83,12 +83,29 @@ class ExecutionContext(BaseModel):
|
|||||||
|
|
||||||
model_config = {"extra": "ignore"}
|
model_config = {"extra": "ignore"}
|
||||||
|
|
||||||
|
# Execution identity
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
graph_id: Optional[str] = None
|
||||||
|
graph_exec_id: Optional[str] = None
|
||||||
|
graph_version: Optional[int] = None
|
||||||
|
node_id: Optional[str] = None
|
||||||
|
node_exec_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Safety settings
|
||||||
human_in_the_loop_safe_mode: bool = True
|
human_in_the_loop_safe_mode: bool = True
|
||||||
sensitive_action_safe_mode: bool = False
|
sensitive_action_safe_mode: bool = False
|
||||||
|
|
||||||
|
# User settings
|
||||||
user_timezone: str = "UTC"
|
user_timezone: str = "UTC"
|
||||||
|
|
||||||
|
# Execution hierarchy
|
||||||
root_execution_id: Optional[str] = None
|
root_execution_id: Optional[str] = None
|
||||||
parent_execution_id: Optional[str] = None
|
parent_execution_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Workspace
|
||||||
|
workspace_id: Optional[str] = None
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# -------------------------- Models -------------------------- #
|
# -------------------------- Models -------------------------- #
|
||||||
|
|
||||||
|
|||||||
@@ -1028,6 +1028,39 @@ async def get_graph(
|
|||||||
return GraphModel.from_db(graph, for_export)
|
return GraphModel.from_db(graph, for_export)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
|
||||||
|
"""Batch-fetch multiple store-listed graphs by their IDs.
|
||||||
|
|
||||||
|
Only returns graphs that have approved store listings (publicly available).
|
||||||
|
Does not require permission checks since store-listed graphs are public.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*graph_ids: Variable number of graph IDs to fetch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping graph_id to GraphModel for graphs with approved store listings
|
||||||
|
"""
|
||||||
|
if not graph_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
store_listings = await StoreListingVersion.prisma().find_many(
|
||||||
|
where={
|
||||||
|
"agentGraphId": {"in": list(graph_ids)},
|
||||||
|
"submissionStatus": SubmissionStatus.APPROVED,
|
||||||
|
"isDeleted": False,
|
||||||
|
},
|
||||||
|
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
|
||||||
|
distinct=["agentGraphId"],
|
||||||
|
order={"agentGraphVersion": "desc"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
listing.agentGraphId: GraphModel.from_db(listing.AgentGraph)
|
||||||
|
for listing in store_listings
|
||||||
|
if listing.AgentGraph
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def get_graph_as_admin(
|
async def get_graph_as_admin(
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
version: int | None = None,
|
version: int | None = None,
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
get_args,
|
get_args,
|
||||||
)
|
)
|
||||||
from urllib.parse import urlparse
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from prisma.enums import CreditTransactionType, OnboardingStep
|
from prisma.enums import CreditTransactionType, OnboardingStep
|
||||||
@@ -42,6 +41,7 @@ from typing_extensions import TypedDict
|
|||||||
|
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.json import loads as json_loads
|
from backend.util.json import loads as json_loads
|
||||||
|
from backend.util.request import parse_url
|
||||||
from backend.util.settings import Secrets
|
from backend.util.settings import Secrets
|
||||||
|
|
||||||
# Type alias for any provider name (including custom ones)
|
# Type alias for any provider name (including custom ones)
|
||||||
@@ -397,19 +397,25 @@ class HostScopedCredentials(_BaseCredentials):
|
|||||||
def matches_url(self, url: str) -> bool:
|
def matches_url(self, url: str) -> bool:
|
||||||
"""Check if this credential should be applied to the given URL."""
|
"""Check if this credential should be applied to the given URL."""
|
||||||
|
|
||||||
parsed_url = urlparse(url)
|
request_host, request_port = _extract_host_from_url(url)
|
||||||
# Extract hostname without port
|
cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
|
||||||
request_host = parsed_url.hostname
|
|
||||||
if not request_host:
|
if not request_host:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Simple host matching - exact match or wildcard subdomain match
|
# If a port is specified in credential host, the request host port must match
|
||||||
if self.host == request_host:
|
if cred_scope_port is not None and request_port != cred_scope_port:
|
||||||
|
return False
|
||||||
|
# Non-standard ports are only allowed if explicitly specified in credential host
|
||||||
|
elif cred_scope_port is None and request_port not in (80, 443, None):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Simple host matching
|
||||||
|
if cred_scope_host == request_host:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||||
if self.host.startswith("*."):
|
if cred_scope_host.startswith("*."):
|
||||||
domain = self.host[2:] # Remove "*."
|
domain = cred_scope_host[2:] # Remove "*."
|
||||||
return request_host.endswith(f".{domain}") or request_host == domain
|
return request_host.endswith(f".{domain}") or request_host == domain
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -551,13 +557,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _extract_host_from_url(url: str) -> str:
|
def _extract_host_from_url(url: str) -> tuple[str, int | None]:
|
||||||
"""Extract host from URL for grouping host-scoped credentials."""
|
"""Extract host and port from URL for grouping host-scoped credentials."""
|
||||||
try:
|
try:
|
||||||
parsed = urlparse(url)
|
parsed = parse_url(url)
|
||||||
return parsed.hostname or url
|
return parsed.hostname or url, parsed.port
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return "", None
|
||||||
|
|
||||||
|
|
||||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||||
@@ -606,7 +612,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
providers = frozenset(
|
providers = frozenset(
|
||||||
[cast(CP, "http")]
|
[cast(CP, "http")]
|
||||||
+ [
|
+ [
|
||||||
cast(CP, _extract_host_from_url(str(value)))
|
cast(CP, parse_url(str(value)).netloc)
|
||||||
for value in field.discriminator_values
|
for value in field.discriminator_values
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -666,10 +672,16 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
if not (self.discriminator and self.discriminator_mapping):
|
if not (self.discriminator and self.discriminator_mapping):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider = self.discriminator_mapping[discriminator_value]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{discriminator_value}' is not supported. "
|
||||||
|
"It may have been deprecated. Please update your agent configuration."
|
||||||
|
)
|
||||||
|
|
||||||
return CredentialsFieldInfo(
|
return CredentialsFieldInfo(
|
||||||
credentials_provider=frozenset(
|
credentials_provider=frozenset([provider]),
|
||||||
[self.discriminator_mapping[discriminator_value]]
|
|
||||||
),
|
|
||||||
credentials_types=self.supported_types,
|
credentials_types=self.supported_types,
|
||||||
credentials_scopes=self.required_scopes,
|
credentials_scopes=self.required_scopes,
|
||||||
discriminator=self.discriminator,
|
discriminator=self.discriminator,
|
||||||
|
|||||||
@@ -79,10 +79,23 @@ class TestHostScopedCredentials:
|
|||||||
headers={"Authorization": SecretStr("Bearer token")},
|
headers={"Authorization": SecretStr("Bearer token")},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
# Non-standard ports require explicit port in credential host
|
||||||
|
assert not creds.matches_url("http://localhost:8080/api/v1")
|
||||||
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
||||||
assert creds.matches_url("http://localhost/simple")
|
assert creds.matches_url("http://localhost/simple")
|
||||||
|
|
||||||
|
def test_matches_url_with_explicit_port(self):
|
||||||
|
"""Test URL matching with explicit port in credential host."""
|
||||||
|
creds = HostScopedCredentials(
|
||||||
|
provider="custom",
|
||||||
|
host="localhost:8080",
|
||||||
|
headers={"Authorization": SecretStr("Bearer token")},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||||
|
assert not creds.matches_url("http://localhost:3000/api/v1")
|
||||||
|
assert not creds.matches_url("http://localhost/simple")
|
||||||
|
|
||||||
def test_empty_headers_dict(self):
|
def test_empty_headers_dict(self):
|
||||||
"""Test HostScopedCredentials with empty headers."""
|
"""Test HostScopedCredentials with empty headers."""
|
||||||
creds = HostScopedCredentials(
|
creds = HostScopedCredentials(
|
||||||
@@ -128,8 +141,20 @@ class TestHostScopedCredentials:
|
|||||||
("*.example.com", "https://sub.api.example.com/test", True),
|
("*.example.com", "https://sub.api.example.com/test", True),
|
||||||
("*.example.com", "https://example.com/test", True),
|
("*.example.com", "https://example.com/test", True),
|
||||||
("*.example.com", "https://example.org/test", False),
|
("*.example.com", "https://example.org/test", False),
|
||||||
("localhost", "http://localhost:3000/test", True),
|
# Non-standard ports require explicit port in credential host
|
||||||
|
("localhost", "http://localhost:3000/test", False),
|
||||||
|
("localhost:3000", "http://localhost:3000/test", True),
|
||||||
("localhost", "http://127.0.0.1:3000/test", False),
|
("localhost", "http://127.0.0.1:3000/test", False),
|
||||||
|
# IPv6 addresses (frontend stores with brackets via URL.hostname)
|
||||||
|
("[::1]", "http://[::1]/test", True),
|
||||||
|
("[::1]", "http://[::1]:80/test", True),
|
||||||
|
("[::1]", "https://[::1]:443/test", True),
|
||||||
|
("[::1]", "http://[::1]:8080/test", False), # Non-standard port
|
||||||
|
("[::1]:8080", "http://[::1]:8080/test", True),
|
||||||
|
("[::1]:8080", "http://[::1]:9090/test", False),
|
||||||
|
("[2001:db8::1]", "http://[2001:db8::1]/path", True),
|
||||||
|
("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
|
||||||
|
("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
||||||
|
|||||||
276
autogpt_platform/backend/backend/data/workspace.py
Normal file
276
autogpt_platform/backend/backend/data/workspace.py
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
"""
|
||||||
|
Database CRUD operations for User Workspace.
|
||||||
|
|
||||||
|
This module provides functions for managing user workspaces and workspace files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from prisma.models import UserWorkspace, UserWorkspaceFile
|
||||||
|
from prisma.types import UserWorkspaceFileWhereInput
|
||||||
|
|
||||||
|
from backend.util.json import SafeJson
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
|
||||||
|
"""
|
||||||
|
Get user's workspace, creating one if it doesn't exist.
|
||||||
|
|
||||||
|
Uses upsert to handle race conditions when multiple concurrent requests
|
||||||
|
attempt to create a workspace for the same user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspace instance
|
||||||
|
"""
|
||||||
|
workspace = await UserWorkspace.prisma().upsert(
|
||||||
|
where={"userId": user_id},
|
||||||
|
data={
|
||||||
|
"create": {"userId": user_id},
|
||||||
|
"update": {}, # No updates needed if exists
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return workspace
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
|
||||||
|
"""
|
||||||
|
Get user's workspace if it exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspace instance or None
|
||||||
|
"""
|
||||||
|
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def create_workspace_file(
|
||||||
|
workspace_id: str,
|
||||||
|
file_id: str,
|
||||||
|
name: str,
|
||||||
|
path: str,
|
||||||
|
storage_path: str,
|
||||||
|
mime_type: str,
|
||||||
|
size_bytes: int,
|
||||||
|
checksum: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
) -> UserWorkspaceFile:
|
||||||
|
"""
|
||||||
|
Create a new workspace file record.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
file_id: The file ID (same as used in storage path for consistency)
|
||||||
|
name: User-visible filename
|
||||||
|
path: Virtual path (e.g., "/documents/report.pdf")
|
||||||
|
storage_path: Actual storage path (GCS or local)
|
||||||
|
mime_type: MIME type of the file
|
||||||
|
size_bytes: File size in bytes
|
||||||
|
checksum: Optional SHA256 checksum
|
||||||
|
metadata: Optional additional metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created UserWorkspaceFile instance
|
||||||
|
"""
|
||||||
|
# Normalize path to start with /
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
|
||||||
|
file = await UserWorkspaceFile.prisma().create(
|
||||||
|
data={
|
||||||
|
"id": file_id,
|
||||||
|
"workspaceId": workspace_id,
|
||||||
|
"name": name,
|
||||||
|
"path": path,
|
||||||
|
"storagePath": storage_path,
|
||||||
|
"mimeType": mime_type,
|
||||||
|
"sizeBytes": size_bytes,
|
||||||
|
"checksum": checksum,
|
||||||
|
"metadata": SafeJson(metadata or {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created workspace file {file.id} at path {path} "
|
||||||
|
f"in workspace {workspace_id}"
|
||||||
|
)
|
||||||
|
return file
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_file(
|
||||||
|
file_id: str,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get a workspace file by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file ID
|
||||||
|
workspace_id: Optional workspace ID for validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
where_clause: dict = {"id": file_id, "isDeleted": False}
|
||||||
|
if workspace_id:
|
||||||
|
where_clause["workspaceId"] = workspace_id
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_file_by_path(
|
||||||
|
workspace_id: str,
|
||||||
|
path: str,
|
||||||
|
) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Get a workspace file by its virtual path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
path: Virtual path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserWorkspaceFile instance or None
|
||||||
|
"""
|
||||||
|
# Normalize path
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = f"/{path}"
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().find_first(
|
||||||
|
where={
|
||||||
|
"workspaceId": workspace_id,
|
||||||
|
"path": path,
|
||||||
|
"isDeleted": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_workspace_files(
|
||||||
|
workspace_id: str,
|
||||||
|
path_prefix: Optional[str] = None,
|
||||||
|
include_deleted: bool = False,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
List files in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
path_prefix: Optional path prefix to filter (e.g., "/documents/")
|
||||||
|
include_deleted: Whether to include soft-deleted files
|
||||||
|
limit: Maximum number of files to return
|
||||||
|
offset: Number of files to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of UserWorkspaceFile instances
|
||||||
|
"""
|
||||||
|
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
|
||||||
|
|
||||||
|
if not include_deleted:
|
||||||
|
where_clause["isDeleted"] = False
|
||||||
|
|
||||||
|
if path_prefix:
|
||||||
|
# Normalize prefix
|
||||||
|
if not path_prefix.startswith("/"):
|
||||||
|
path_prefix = f"/{path_prefix}"
|
||||||
|
where_clause["path"] = {"startswith": path_prefix}
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().find_many(
|
||||||
|
where=where_clause,
|
||||||
|
order={"createdAt": "desc"},
|
||||||
|
take=limit,
|
||||||
|
skip=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def count_workspace_files(
|
||||||
|
workspace_id: str,
|
||||||
|
path_prefix: Optional[str] = None,
|
||||||
|
include_deleted: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count files in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
|
||||||
|
include_deleted: Whether to include soft-deleted files
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files
|
||||||
|
"""
|
||||||
|
where_clause: dict = {"workspaceId": workspace_id}
|
||||||
|
if not include_deleted:
|
||||||
|
where_clause["isDeleted"] = False
|
||||||
|
|
||||||
|
if path_prefix:
|
||||||
|
# Normalize prefix
|
||||||
|
if not path_prefix.startswith("/"):
|
||||||
|
path_prefix = f"/{path_prefix}"
|
||||||
|
where_clause["path"] = {"startswith": path_prefix}
|
||||||
|
|
||||||
|
return await UserWorkspaceFile.prisma().count(where=where_clause)
|
||||||
|
|
||||||
|
|
||||||
|
async def soft_delete_workspace_file(
|
||||||
|
file_id: str,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
) -> Optional[UserWorkspaceFile]:
|
||||||
|
"""
|
||||||
|
Soft-delete a workspace file.
|
||||||
|
|
||||||
|
The path is modified to include a deletion timestamp to free up the original
|
||||||
|
path for new files while preserving the record for potential recovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file ID
|
||||||
|
workspace_id: Optional workspace ID for validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated UserWorkspaceFile instance or None if not found
|
||||||
|
"""
|
||||||
|
# First verify the file exists and belongs to workspace
|
||||||
|
file = await get_workspace_file(file_id, workspace_id)
|
||||||
|
if file is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
deleted_at = datetime.now(timezone.utc)
|
||||||
|
# Modify path to free up the unique constraint for new files at original path
|
||||||
|
# Format: {original_path}__deleted__{timestamp}
|
||||||
|
deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
|
||||||
|
|
||||||
|
updated = await UserWorkspaceFile.prisma().update(
|
||||||
|
where={"id": file_id},
|
||||||
|
data={
|
||||||
|
"isDeleted": True,
|
||||||
|
"deletedAt": deleted_at,
|
||||||
|
"path": deleted_path,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Soft-deleted workspace file {file_id}")
|
||||||
|
return updated
|
||||||
|
|
||||||
|
|
||||||
|
async def get_workspace_total_size(workspace_id: str) -> int:
|
||||||
|
"""
|
||||||
|
Get the total size of all files in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: The workspace ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total size in bytes
|
||||||
|
"""
|
||||||
|
files = await list_workspace_files(workspace_id)
|
||||||
|
return sum(file.sizeBytes for file in files)
|
||||||
@@ -17,6 +17,7 @@ from backend.data.analytics import (
|
|||||||
get_accuracy_trends_and_alerts,
|
get_accuracy_trends_and_alerts,
|
||||||
get_marketplace_graphs_for_monitoring,
|
get_marketplace_graphs_for_monitoring,
|
||||||
)
|
)
|
||||||
|
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||||
from backend.data.execution import (
|
from backend.data.execution import (
|
||||||
create_graph_execution,
|
create_graph_execution,
|
||||||
@@ -219,6 +220,9 @@ class DatabaseManager(AppService):
|
|||||||
# Onboarding
|
# Onboarding
|
||||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||||
|
|
||||||
|
# OAuth
|
||||||
|
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
|
||||||
|
|
||||||
# Store
|
# Store
|
||||||
get_store_agents = _(get_store_agents)
|
get_store_agents = _(get_store_agents)
|
||||||
get_store_agent_details = _(get_store_agent_details)
|
get_store_agent_details = _(get_store_agent_details)
|
||||||
@@ -349,6 +353,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
# Onboarding
|
# Onboarding
|
||||||
increment_onboarding_runs = d.increment_onboarding_runs
|
increment_onboarding_runs = d.increment_onboarding_runs
|
||||||
|
|
||||||
|
# OAuth
|
||||||
|
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
|
||||||
|
|
||||||
# Store
|
# Store
|
||||||
get_store_agents = d.get_store_agents
|
get_store_agents = d.get_store_agents
|
||||||
get_store_agent_details = d.get_store_agent_details
|
get_store_agent_details = d.get_store_agent_details
|
||||||
|
|||||||
@@ -236,7 +236,14 @@ async def execute_node(
|
|||||||
input_size = len(input_data_str)
|
input_size = len(input_data_str)
|
||||||
log_metadata.debug("Executed node with input", input=input_data_str)
|
log_metadata.debug("Executed node with input", input=input_data_str)
|
||||||
|
|
||||||
|
# Create node-specific execution context to avoid race conditions
|
||||||
|
# (multiple nodes can execute concurrently and would otherwise mutate shared state)
|
||||||
|
execution_context = execution_context.model_copy(
|
||||||
|
update={"node_id": node_id, "node_exec_id": node_exec_id}
|
||||||
|
)
|
||||||
|
|
||||||
# Inject extra execution arguments for the blocks via kwargs
|
# Inject extra execution arguments for the blocks via kwargs
|
||||||
|
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||||
extra_exec_kwargs: dict = {
|
extra_exec_kwargs: dict = {
|
||||||
"graph_id": graph_id,
|
"graph_id": graph_id,
|
||||||
"graph_version": graph_version,
|
"graph_version": graph_version,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user