Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-01-29 17:08:01 -05:00

Compare commits: update-bra...master (4 commits)

| SHA1 |
|---|
| 9538992eaf |
| 27b72062f2 |
| 9a79a8d257 |
| a9bf08748b |

@@ -29,7 +29,8 @@
   "postCreateCmd": [
     "cd autogpt_platform/autogpt_libs && poetry install",
     "cd autogpt_platform/backend && poetry install && poetry run prisma generate",
-    "cd autogpt_platform/frontend && pnpm install"
+    "cd autogpt_platform/frontend && pnpm install",
+    "cd docs && pip install -r requirements.txt"
   ],
   "terminalCommand": "code .",
   "deleteBranchWithWorktree": false

.github/copilot-instructions.md (vendored, 6 lines changed)

@@ -160,7 +160,7 @@ pnpm storybook # Start component development server

 **Backend Entry Points:**

-- `backend/backend/api/rest_api.py` - FastAPI application setup
+- `backend/backend/server/server.py` - FastAPI application setup
 - `backend/backend/data/` - Database models and user management
 - `backend/blocks/` - Agent execution blocks and logic

@@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s

 ### API Development

-1. Update routes in `/backend/backend/api/features/`
+1. Update routes in `/backend/backend/server/routers/`
 2. Add/update Pydantic models in same directory
 3. Write tests alongside route files
 4. For `data/*.py` changes, validate user ID checks

@@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s

 ### Security Guidelines

-**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`):
+**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):

 - Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
 - Uses allow list approach for cacheable paths (static assets, health checks, public pages)

.gitignore (vendored, 2 lines changed)

@@ -178,5 +178,5 @@ autogpt_platform/backend/settings.py
 *.ign.*
 .test-contents
 .claude/settings.local.json
-CLAUDE.local.md
 /autogpt_platform/backend/logs
+.next

AGENTS.md (24 lines changed)

@@ -16,6 +16,7 @@ See `docs/content/platform/getting-started.md` for setup instructions.
 - Format Python code with `poetry run format`.
 - Format frontend code using `pnpm format`.

 ## Frontend guidelines:

 See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:

@@ -32,17 +33,14 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
 4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
 5. **Testing**: Add Storybook stories for new components, Playwright for E2E
 6. **Code conventions**: Function declarations (not arrow functions) for components/handlers

 - Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
 - Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
 - Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
 - Avoid large hooks, abstract logic into `helpers.ts` files when sensible
 - Use function declarations for components, arrow functions only for callbacks
 - No barrel files or `index.ts` re-exports
+- Do not use `useCallback` or `useMemo` unless strictly needed
 - Avoid comments at all times unless the code is very complex
-- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
-- Do not type hook returns, let Typescript infer as much as possible
-- Never type with `any`, if not types available use `unknown`

 ## Testing

@@ -51,8 +49,22 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
 Always run the relevant linters and tests before committing.
 Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
-Types: - feat - fix - refactor - ci - dx (developer experience)
-Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
+Types:
+- feat
+- fix
+- refactor
+- ci
+- dx (developer experience)
+Scopes:
+- platform
+- platform/library
+- platform/marketplace
+- backend
+- backend/executor
+- frontend
+- frontend/library
+- frontend/marketplace
+- blocks

 ## Pull requests

|||||||
@@ -6,30 +6,152 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
|||||||
|
|
||||||
AutoGPT Platform is a monorepo containing:
|
AutoGPT Platform is a monorepo containing:
|
||||||
|
|
||||||
- **Backend** (`backend`): Python FastAPI server with async support
|
- **Backend** (`/backend`): Python FastAPI server with async support
|
||||||
- **Frontend** (`frontend`): Next.js React application
|
- **Frontend** (`/frontend`): Next.js React application
|
||||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
||||||
|
|
||||||
## Component Documentation
|
## Essential Commands
|
||||||
|
|
||||||
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
### Backend Development
|
||||||
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
|
||||||
|
|
||||||
## Key Concepts
|
```bash
|
||||||
|
# Install dependencies
|
||||||
|
cd backend && poetry install
|
||||||
|
|
||||||
|
# Run database migrations
|
||||||
|
poetry run prisma migrate dev
|
||||||
|
|
||||||
|
# Start all services (database, redis, rabbitmq, clamav)
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Run the backend server
|
||||||
|
poetry run serve
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
poetry run test
|
||||||
|
|
||||||
|
# Run specific test
|
||||||
|
poetry run pytest path/to/test_file.py::test_function_name
|
||||||
|
|
||||||
|
# Run block tests (tests that validate all blocks work correctly)
|
||||||
|
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||||
|
|
||||||
|
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||||
|
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||||
|
|
||||||
|
# Lint and format
|
||||||
|
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||||
|
poetry run format # Black + isort
|
||||||
|
poetry run lint # ruff
|
||||||
|
```
|
||||||
|
|
||||||
|
More details can be found in TESTING.md
|
||||||
|
|
||||||
|
#### Creating/Updating Snapshots
|
||||||
|
|
||||||
|
When you first write a test or when the expected output changes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry run pytest path/to/test.py --snapshot-update
|
||||||
|
```
|
||||||
|
|
||||||
|
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||||
|
|
||||||
|
### Frontend Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies
|
||||||
|
cd frontend && pnpm i
|
||||||
|
|
||||||
|
# Generate API client from OpenAPI spec
|
||||||
|
pnpm generate:api
|
||||||
|
|
||||||
|
# Start development server
|
||||||
|
pnpm dev
|
||||||
|
|
||||||
|
# Run E2E tests
|
||||||
|
pnpm test
|
||||||
|
|
||||||
|
# Run Storybook for component development
|
||||||
|
pnpm storybook
|
||||||
|
|
||||||
|
# Build production
|
||||||
|
pnpm build
|
||||||
|
|
||||||
|
# Format and lint
|
||||||
|
pnpm format
|
||||||
|
|
||||||
|
# Type checking
|
||||||
|
pnpm types
|
||||||
|
```
|
||||||
|
|
||||||
|
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
|
||||||
|
|
||||||
|
**Key Frontend Conventions:**
|
||||||
|
|
||||||
|
- Separate render logic from data/behavior in components
|
||||||
|
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||||
|
- Use function declarations (not arrow functions) for components/handlers
|
||||||
|
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||||
|
- Only use Phosphor Icons
|
||||||
|
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
|
||||||
|
|
||||||
|
## Architecture Overview
|
||||||
|
|
||||||
|
### Backend Architecture
|
||||||
|
|
||||||
|
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||||
|
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||||
|
- **Queue System**: RabbitMQ for async task processing
|
||||||
|
- **Execution Engine**: Separate executor service processes agent workflows
|
||||||
|
- **Authentication**: JWT-based with Supabase integration
|
||||||
|
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||||
|
|
||||||
|
### Frontend Architecture
|
||||||
|
|
||||||
|
- **Framework**: Next.js 15 App Router (client-first approach)
|
||||||
|
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
||||||
|
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
||||||
|
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
||||||
|
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||||
|
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
||||||
|
- **Icons**: Phosphor Icons only
|
||||||
|
- **Feature Flags**: LaunchDarkly integration
|
||||||
|
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
||||||
|
- **Testing**: Playwright for E2E, Storybook for component development
|
||||||
|
|
||||||
|
### Key Concepts
|
||||||
|
|
||||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||||
3. **Integrations**: OAuth and API connections stored per user
|
3. **Integrations**: OAuth and API connections stored per user
|
||||||
4. **Store**: Marketplace for sharing agent templates
|
4. **Store**: Marketplace for sharing agent templates
|
||||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||||
|
|
||||||
|
### Testing Approach
|
||||||
|
|
||||||
|
- Backend uses pytest with snapshot testing for API responses
|
||||||
|
- Test files are colocated with source files (`*_test.py`)
|
||||||
|
- Frontend uses Playwright for E2E tests
|
||||||
|
- Component testing via Storybook
|
||||||
|
|
||||||
|
### Database Schema
|
||||||
|
|
||||||
|
Key models (defined in `/backend/schema.prisma`):
|
||||||
|
|
||||||
|
- `User`: Authentication and profile data
|
||||||
|
- `AgentGraph`: Workflow definitions with version control
|
||||||
|
- `AgentGraphExecution`: Execution history and results
|
||||||
|
- `AgentNode`: Individual nodes in a workflow
|
||||||
|
- `StoreListing`: Marketplace listings for sharing agents
|
||||||
|
|
||||||
### Environment Configuration
|
### Environment Configuration
|
||||||
|
|
||||||
#### Configuration Files
|
#### Configuration Files
|
||||||
|
|
||||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
||||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
||||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
||||||
|
|
||||||
#### Docker Environment Loading Order
|
#### Docker Environment Loading Order
|
||||||
|
|
||||||
@@ -45,12 +167,83 @@ AutoGPT Platform is a monorepo containing:
|
|||||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||||
|
|
||||||
|
### Common Development Tasks
|
||||||
|
|
||||||
|
**Adding a new block:**
|
||||||
|
|
||||||
|
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||||
|
|
||||||
|
- Provider configuration with `ProviderBuilder`
|
||||||
|
- Block schema definition
|
||||||
|
- Authentication (API keys, OAuth, webhooks)
|
||||||
|
- Testing and validation
|
||||||
|
- File organization
|
||||||
|
|
||||||
|
Quick steps:
|
||||||
|
|
||||||
|
1. Create new file in `/backend/backend/blocks/`
|
||||||
|
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||||
|
3. Inherit from `Block` base class
|
||||||
|
4. Define input/output schemas using `BlockSchema`
|
||||||
|
5. Implement async `run` method
|
||||||
|
6. Generate unique block ID using `uuid.uuid4()`
|
||||||
|
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||||
|
|
||||||
|
Note: when creating several new blocks, review their interfaces together and consider whether they would connect well in a graph-based editor, or whether they would struggle to compose productively.
|
||||||
|
For example: do the outputs of one block line up with the inputs of the next?
|
||||||
|
|
||||||
|
If you get pushback or hit complex block conditions, check the new_blocks guide in the docs.
|
||||||
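
To make the quick steps above concrete, here is a rough sketch of a minimal block. The class name, field names, and the exact `Block`/`BlockSchema`/`SchemaField` constructor arguments are illustrative assumptions and may not match the current SDK exactly; treat it as a shape reference, not copy-paste code.

```python
# Hypothetical minimal block; import paths and base-class arguments are assumptions.
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class GreetingBlock(Block):
    class Input(BlockSchema):
        name: str = SchemaField(description="Name of the person to greet")

    class Output(BlockSchema):
        greeting: str = SchemaField(description="The generated greeting")

    def __init__(self):
        super().__init__(
            # Step 6: generate once with uuid.uuid4() and paste the literal value
            id="00000000-0000-4000-8000-000000000000",
            description="Produces a simple greeting",
            input_schema=GreetingBlock.Input,
            output_schema=GreetingBlock.Output,
        )

    # Step 5: async run method that yields (output_name, value) pairs
    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "greeting", f"Hello, {input_data.name}!"
```
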
|
|
||||||
|
**Modifying the API:**
|
||||||
|
|
||||||
|
1. Update route in `/backend/backend/server/routers/`
|
||||||
|
2. Add/update Pydantic models in same directory
|
||||||
|
3. Write tests alongside the route file
|
||||||
|
4. Run `poetry run test` to verify
|
||||||
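
As a hedged illustration of steps 1-2, a new route plus its Pydantic model might look roughly like the sketch below; the file location follows the path above, but the router wiring, endpoint path, and model fields are assumptions rather than the repository's actual code.

```python
# Hypothetical route sketch for a file under /backend/backend/server/routers/;
# router prefix, endpoint, and model fields are illustrative only.
import fastapi
from pydantic import BaseModel

router = fastapi.APIRouter()


class AgentSummary(BaseModel):
    id: str
    name: str
    description: str | None = None


@router.get("/agents/{agent_id}", response_model=AgentSummary)
async def get_agent_summary(agent_id: str) -> AgentSummary:
    # Real handlers resolve the authenticated user and enforce ownership checks
    # before calling into data helpers; tests live alongside this file (step 3).
    return AgentSummary(id=agent_id, name="Example agent")
```
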
|
|
||||||
|
### Frontend guidelines:
|
||||||
|
|
||||||
|
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||||
|
|
||||||
|
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||||
|
- Add `usePageName.ts` hook for logic
|
||||||
|
- Put sub-components in local `components/` folder
|
||||||
|
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||||
|
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||||
|
- Never use `src/components/__legacy__/*`
|
||||||
|
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||||
|
- Regenerate with `pnpm generate:api`
|
||||||
|
- Pattern: `use{Method}{Version}{OperationName}`
|
||||||
|
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||||
|
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||||
|
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||||
|
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||||
|
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||||
|
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||||
|
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||||
|
- Use function declarations for components, arrow functions only for callbacks
|
||||||
|
- No barrel files or `index.ts` re-exports
|
||||||
|
- Do not use `useCallback` or `useMemo` unless strictly needed
|
||||||
|
- Avoid comments at all times unless the code is very complex
|
||||||
|
|
||||||
|
### Security Implementation
|
||||||
|
|
||||||
|
**Cache Protection Middleware:**
|
||||||
|
|
||||||
|
- Located in `/backend/backend/server/middleware/security.py`
|
||||||
|
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||||
|
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||||
|
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
||||||
|
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||||
|
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||||
|
- Applied to both main API server and external API applications
|
||||||
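
A minimal sketch, assuming a standard Starlette `BaseHTTPMiddleware`, of how an allow-list cache policy like the one described above can be enforced; the path list and wiring are illustrative, not the contents of `security.py`:

```python
# Illustrative allow-list cache middleware; not the repository's exact code.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

CACHEABLE_PATHS = ("/static/", "/_next/static/", "/health")


class CacheControlMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        if any(request.url.path.startswith(p) for p in CACHEABLE_PATHS):
            return response  # explicitly allowed paths keep their caching headers
        # Default: never cache, so tokens and user data don't land in shared caches
        response.headers["Cache-Control"] = (
            "no-store, no-cache, must-revalidate, private"
        )
        return response
```
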
|
|
||||||
### Creating Pull Requests
|
### Creating Pull Requests
|
||||||
|
|
||||||
- Create the PR against the `dev` branch of the repository.
|
- Create the PR aginst the `dev` branch of the repository.
|
||||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
|
||||||
- Use conventional commit messages (see below)
|
- Use conventional commit messages (see below)/
|
||||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
|
||||||
- Run the github pre-commit hooks to ensure code quality.
|
- Run the github pre-commit hooks to ensure code quality.
|
||||||
|
|
||||||
### Reviewing/Revising Pull Requests
|
### Reviewing/Revising Pull Requests
|
||||||
|
|||||||
@@ -1,170 +0,0 @@
|
|||||||
# CLAUDE.md - Backend
|
|
||||||
|
|
||||||
This file provides guidance to Claude Code when working with the backend.
|
|
||||||
|
|
||||||
## Essential Commands
|
|
||||||
|
|
||||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Install dependencies
|
|
||||||
poetry install
|
|
||||||
|
|
||||||
# Run database migrations
|
|
||||||
poetry run prisma migrate dev
|
|
||||||
|
|
||||||
# Start all services (database, redis, rabbitmq, clamav)
|
|
||||||
docker compose up -d
|
|
||||||
|
|
||||||
# Run the backend as a whole
|
|
||||||
poetry run app
|
|
||||||
|
|
||||||
# Run tests
|
|
||||||
poetry run test
|
|
||||||
|
|
||||||
# Run specific test
|
|
||||||
poetry run pytest path/to/test_file.py::test_function_name
|
|
||||||
|
|
||||||
# Run block tests (tests that validate all blocks work correctly)
|
|
||||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
|
||||||
|
|
||||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
|
||||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
|
||||||
|
|
||||||
# Lint and format
|
|
||||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
|
||||||
poetry run format # Black + isort
|
|
||||||
poetry run lint # ruff
|
|
||||||
```
|
|
||||||
|
|
||||||
More details can be found in @TESTING.md
|
|
||||||
|
|
||||||
### Creating/Updating Snapshots
|
|
||||||
|
|
||||||
When you first write a test or when the expected output changes:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
poetry run pytest path/to/test.py --snapshot-update
|
|
||||||
```
|
|
||||||
|
|
||||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
|
||||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
|
||||||
- **Queue System**: RabbitMQ for async task processing
|
|
||||||
- **Execution Engine**: Separate executor service processes agent workflows
|
|
||||||
- **Authentication**: JWT-based with Supabase integration
|
|
||||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
|
||||||
|
|
||||||
## Testing Approach
|
|
||||||
|
|
||||||
- Uses pytest with snapshot testing for API responses
|
|
||||||
- Test files are colocated with source files (`*_test.py`)
|
|
||||||
|
|
||||||
## Database Schema
|
|
||||||
|
|
||||||
Key models (defined in `schema.prisma`):
|
|
||||||
|
|
||||||
- `User`: Authentication and profile data
|
|
||||||
- `AgentGraph`: Workflow definitions with version control
|
|
||||||
- `AgentGraphExecution`: Execution history and results
|
|
||||||
- `AgentNode`: Individual nodes in a workflow
|
|
||||||
- `StoreListing`: Marketplace listings for sharing agents
|
|
||||||
|
|
||||||
## Environment Configuration
|
|
||||||
|
|
||||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
|
||||||
|
|
||||||
## Common Development Tasks
|
|
||||||
|
|
||||||
### Adding a new block
|
|
||||||
|
|
||||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
|
||||||
|
|
||||||
- Provider configuration with `ProviderBuilder`
|
|
||||||
- Block schema definition
|
|
||||||
- Authentication (API keys, OAuth, webhooks)
|
|
||||||
- Testing and validation
|
|
||||||
- File organization
|
|
||||||
|
|
||||||
Quick steps:
|
|
||||||
|
|
||||||
1. Create new file in `backend/blocks/`
|
|
||||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
|
||||||
3. Inherit from `Block` base class
|
|
||||||
4. Define input/output schemas using `BlockSchema`
|
|
||||||
5. Implement async `run` method
|
|
||||||
6. Generate unique block ID using `uuid.uuid4()`
|
|
||||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
|
||||||
|
|
||||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
|
||||||
ex: do the inputs and outputs tie well together?
|
|
||||||
|
|
||||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
|
||||||
|
|
||||||
#### Handling files in blocks with `store_media_file()`
|
|
||||||
|
|
||||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
|
||||||
|
|
||||||
| Format | Use When | Returns |
|
|
||||||
|--------|----------|---------|
|
|
||||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
|
||||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
|
||||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
|
||||||
|
|
||||||
**Examples:**
|
|
||||||
|
|
||||||
```python
|
|
||||||
# INPUT: Need to process file locally with ffmpeg
|
|
||||||
local_path = await store_media_file(
|
|
||||||
file=input_data.video,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
|
||||||
|
|
||||||
# INPUT: Need to send to external API like Replicate
|
|
||||||
image_b64 = await store_media_file(
|
|
||||||
file=input_data.image,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_external_api",
|
|
||||||
)
|
|
||||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
|
||||||
|
|
||||||
# OUTPUT: Returning result from block
|
|
||||||
result_url = await store_media_file(
|
|
||||||
file=generated_image_url,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
yield "image_url", result_url
|
|
||||||
# In CoPilot: result_url = "workspace://abc123"
|
|
||||||
# In graphs: result_url = "data:image/png;base64,..."
|
|
||||||
```
|
|
||||||
|
|
||||||
**Key points:**
|
|
||||||
|
|
||||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
|
||||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
|
||||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
|
||||||
|
|
||||||
### Modifying the API
|
|
||||||
|
|
||||||
1. Update route in `backend/api/features/`
|
|
||||||
2. Add/update Pydantic models in same directory
|
|
||||||
3. Write tests alongside the route file
|
|
||||||
4. Run `poetry run test` to verify
|
|
||||||
|
|
||||||
## Security Implementation
|
|
||||||
|
|
||||||
### Cache Protection Middleware
|
|
||||||
|
|
||||||
- Located in `backend/api/middleware/security.py`
|
|
||||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
|
||||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
|
||||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
|
||||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
|
||||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
|
||||||
- Applied to both main API server and external API applications
|
|
||||||
@@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as
|
|||||||
|
|
||||||
#### Using Global Auth Fixtures
|
#### Using Global Auth Fixtures
|
||||||
|
|
||||||
Two global auth fixtures are provided by `backend/api/conftest.py`:
|
Two global auth fixtures are provided by `backend/server/conftest.py`:
|
||||||
|
|
||||||
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
||||||
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ router = fastapi.APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Taken from backend/api/features/store/db.py
|
# Taken from backend/server/v2/store/db.py
|
||||||
def sanitize_query(query: str | None) -> str | None:
|
def sanitize_query(query: str | None) -> str | None:
|
||||||
if query is None:
|
if query is None:
|
||||||
return query
|
return query
|
||||||
|
|||||||
@@ -1,79 +0,0 @@
|
|||||||
# CoPilot Tools - Future Ideas
|
|
||||||
|
|
||||||
## Multimodal Image Support for CoPilot
|
|
||||||
|
|
||||||
**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).
|
|
||||||
|
|
||||||
**Backend Solution:**
|
|
||||||
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Before sending to LLM, scan for workspace image references
|
|
||||||
# and inject them as image content parts
|
|
||||||
|
|
||||||
# Example message transformation:
|
|
||||||
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
|
|
||||||
# TO: {"role": "assistant", "content": [
|
|
||||||
# {"type": "text", "text": "Generated image: workspace://abc123"},
|
|
||||||
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
|
|
||||||
# ]}
|
|
||||||
```
|
|
||||||
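
A rough sketch of that preprocessing step, assuming helper callables for workspace lookups (`read_file`, `get_mime`) and an OpenAI-style message shape; none of these names come from the existing handler:

```python
# Illustrative preprocessing: expand workspace:// image refs into image_url parts.
import base64
import re

WORKSPACE_REF = re.compile(r"workspace://[\w-]+")


async def inject_workspace_images(message: dict, read_file, get_mime) -> dict:
    """read_file/get_mime are placeholders for workspace lookups by file ID."""
    text = message.get("content")
    if not isinstance(text, str):
        return message
    parts = [{"type": "text", "text": text}]
    for ref in WORKSPACE_REF.findall(text):
        file_id = ref.removeprefix("workspace://")
        mime = await get_mime(file_id)
        if not mime.startswith("image/"):
            continue  # only inline actual images
        data = base64.b64encode(await read_file(file_id)).decode()
        parts.append(
            {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{data}"}}
        )
    return {**message, "content": parts} if len(parts) > 1 else message
```
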
|
|
||||||
**Where to implement:**
|
|
||||||
- In the chat stream handler before calling the LLM
|
|
||||||
- Or in a message preprocessing step
|
|
||||||
- Need to fetch image from workspace, convert to base64, add as image content
|
|
||||||
|
|
||||||
**Considerations:**
|
|
||||||
- Only do this for image MIME types (image/png, image/jpeg, etc.)
|
|
||||||
- May want a size limit (don't pass 10MB images)
|
|
||||||
- Track which images were "shown" to the AI for frontend indicator
|
|
||||||
- Cost implications - vision API calls are more expensive
|
|
||||||
|
|
||||||
**Frontend Solution:**
|
|
||||||
Show visual indicator on workspace files in chat:
|
|
||||||
- If AI saw the image: normal display
|
|
||||||
- If AI didn't see it: overlay icon saying "AI can't see this image"
|
|
||||||
|
|
||||||
Requires response metadata indicating which `workspace://` refs were passed to the model.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Output Post-Processing Layer for run_block
|
|
||||||
|
|
||||||
**Problem:** Many blocks produce large outputs that:
|
|
||||||
- Consume massive context (100KB base64 image = ~133KB tokens)
|
|
||||||
- Can't fit in conversation
|
|
||||||
- Break things and cause high LLM costs
|
|
||||||
|
|
||||||
**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.
|
|
||||||
|
|
||||||
**Benefits:**
|
|
||||||
1. **Centralized** - one place to handle all output processing
|
|
||||||
2. **Future-proof** - new blocks automatically get output processing
|
|
||||||
3. **Keeps blocks pure** - they don't need to know about context constraints
|
|
||||||
4. **Handles all large outputs** - not just images
|
|
||||||
|
|
||||||
**Processing Rules:**
|
|
||||||
- Detect base64 data URIs → save to workspace, return `workspace://` reference
|
|
||||||
- Truncate very long strings (>N chars) with truncation note
|
|
||||||
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
|
|
||||||
- Handle nested large outputs in dicts recursively
|
|
||||||
- Cap total output size
|
|
||||||
|
|
||||||
**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`
|
|
||||||
|
|
||||||
**Example:**
|
|
||||||
```python
|
|
||||||
def _process_outputs_for_context(
|
|
||||||
outputs: dict[str, list[Any]],
|
|
||||||
workspace_manager: WorkspaceManager,
|
|
||||||
max_string_length: int = 10000,
|
|
||||||
max_array_preview: int = 5,
|
|
||||||
) -> dict[str, list[Any]]:
|
|
||||||
"""Process block outputs to prevent context bloat."""
|
|
||||||
processed = {}
|
|
||||||
for name, values in outputs.items():
|
|
||||||
processed[name] = [_process_value(v, workspace_manager) for v in values]
|
|
||||||
return processed
|
|
||||||
```
|
|
||||||
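
The sketch above calls `_process_value` without defining it; a hedged illustration of what such a helper could do, following the processing rules listed earlier (`save_data_uri` is a placeholder for storing a data URI in the workspace and returning a `workspace://` reference):

```python
# Illustrative companion to the sketch above; not part of the current codebase.
from typing import Any

MAX_STRING_LENGTH = 10_000
MAX_ARRAY_PREVIEW = 5


def _process_value(value: Any, save_data_uri) -> Any:
    """Shrink a single output value so it fits comfortably in LLM context."""
    if isinstance(value, str):
        if value.startswith("data:"):
            return save_data_uri(value)  # swap large base64 payloads for a reference
        if len(value) > MAX_STRING_LENGTH:
            return value[:MAX_STRING_LENGTH] + f"... [truncated, {len(value)} chars total]"
        return value
    if isinstance(value, list) and len(value) > MAX_ARRAY_PREVIEW:
        head = [_process_value(v, save_data_uri) for v in value[:MAX_ARRAY_PREVIEW]]
        return {
            "summary": f"Array with {len(value)} items, first {MAX_ARRAY_PREVIEW} shown",
            "items": head,
        }
    if isinstance(value, dict):
        return {k: _process_value(v, save_data_uri) for k, v in value.items()}
    return value
```
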
@@ -18,12 +18,6 @@ from .get_doc_page import GetDocPageTool
|
|||||||
from .run_agent import RunAgentTool
|
from .run_agent import RunAgentTool
|
||||||
from .run_block import RunBlockTool
|
from .run_block import RunBlockTool
|
||||||
from .search_docs import SearchDocsTool
|
from .search_docs import SearchDocsTool
|
||||||
from .workspace_files import (
|
|
||||||
DeleteWorkspaceFileTool,
|
|
||||||
ListWorkspaceFilesTool,
|
|
||||||
ReadWorkspaceFileTool,
|
|
||||||
WriteWorkspaceFileTool,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||||
@@ -43,11 +37,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
|||||||
"view_agent_output": AgentOutputTool(),
|
"view_agent_output": AgentOutputTool(),
|
||||||
"search_docs": SearchDocsTool(),
|
"search_docs": SearchDocsTool(),
|
||||||
"get_doc_page": GetDocPageTool(),
|
"get_doc_page": GetDocPageTool(),
|
||||||
# Workspace tools for CoPilot file operations
|
|
||||||
"list_workspace_files": ListWorkspaceFilesTool(),
|
|
||||||
"read_workspace_file": ReadWorkspaceFileTool(),
|
|
||||||
"write_workspace_file": WriteWorkspaceFileTool(),
|
|
||||||
"delete_workspace_file": DeleteWorkspaceFileTool(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Export individual tool instances for backwards compatibility
|
# Export individual tool instances for backwards compatibility
|
||||||
|
|||||||
@@ -28,12 +28,6 @@ class ResponseType(str, Enum):
|
|||||||
BLOCK_OUTPUT = "block_output"
|
BLOCK_OUTPUT = "block_output"
|
||||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||||
DOC_PAGE = "doc_page"
|
DOC_PAGE = "doc_page"
|
||||||
# Workspace response types
|
|
||||||
WORKSPACE_FILE_LIST = "workspace_file_list"
|
|
||||||
WORKSPACE_FILE_CONTENT = "workspace_file_content"
|
|
||||||
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
|
|
||||||
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
|
||||||
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
|
||||||
# Long-running operation types
|
# Long-running operation types
|
||||||
OPERATION_STARTED = "operation_started"
|
OPERATION_STARTED = "operation_started"
|
||||||
OPERATION_PENDING = "operation_pending"
|
OPERATION_PENDING = "operation_pending"
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Tool for executing blocks directly."""
|
"""Tool for executing blocks directly."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -9,7 +8,6 @@ from backend.api.features.chat.model import ChatSession
|
|||||||
from backend.data.block import get_block
|
from backend.data.block import get_block
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
@@ -225,48 +223,11 @@ class RunBlockTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get or create user's workspace for CoPilot file operations
|
# Fetch actual credentials and prepare kwargs for block execution
|
||||||
workspace = await get_or_create_workspace(user_id)
|
# Create execution context with defaults (blocks may require it)
|
||||||
|
|
||||||
# Generate synthetic IDs for CoPilot context
|
|
||||||
# Each chat session is treated as its own agent with one continuous run
|
|
||||||
# This means:
|
|
||||||
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
|
||||||
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
|
||||||
# - node_exec_id = unique per block execution
|
|
||||||
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
|
||||||
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
|
||||||
synthetic_node_id = f"copilot-node-{block_id}"
|
|
||||||
synthetic_node_exec_id = (
|
|
||||||
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create unified execution context with all required fields
|
|
||||||
execution_context = ExecutionContext(
|
|
||||||
# Execution identity
|
|
||||||
user_id=user_id,
|
|
||||||
graph_id=synthetic_graph_id,
|
|
||||||
graph_exec_id=synthetic_graph_exec_id,
|
|
||||||
graph_version=1, # Versions are 1-indexed
|
|
||||||
node_id=synthetic_node_id,
|
|
||||||
node_exec_id=synthetic_node_exec_id,
|
|
||||||
# Workspace with session scoping
|
|
||||||
workspace_id=workspace.id,
|
|
||||||
session_id=session.session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare kwargs for block execution
|
|
||||||
# Keep individual kwargs for backwards compatibility with existing blocks
|
|
||||||
exec_kwargs: dict[str, Any] = {
|
exec_kwargs: dict[str, Any] = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"execution_context": execution_context,
|
"execution_context": ExecutionContext(),
|
||||||
# Legacy: individual kwargs for blocks not yet using execution_context
|
|
||||||
"workspace_id": workspace.id,
|
|
||||||
"graph_exec_id": synthetic_graph_exec_id,
|
|
||||||
"node_exec_id": synthetic_node_exec_id,
|
|
||||||
"node_id": synthetic_node_id,
|
|
||||||
"graph_version": 1, # Versions are 1-indexed
|
|
||||||
"graph_id": synthetic_graph_id,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for field_name, cred_meta in matched_credentials.items():
|
for field_name, cred_meta in matched_credentials.items():
|
||||||
|
|||||||
@@ -1,620 +0,0 @@
|
|||||||
"""CoPilot tools for workspace file operations."""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.data.workspace import get_or_create_workspace
|
|
||||||
from backend.util.settings import Config
|
|
||||||
from backend.util.virus_scanner import scan_content_safe
|
|
||||||
from backend.util.workspace import WorkspaceManager
|
|
||||||
|
|
||||||
from .base import BaseTool
|
|
||||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceFileInfoData(BaseModel):
|
|
||||||
"""Data model for workspace file information (not a response itself)."""
|
|
||||||
|
|
||||||
file_id: str
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
mime_type: str
|
|
||||||
size_bytes: int
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceFileListResponse(ToolResponseBase):
|
|
||||||
"""Response containing list of workspace files."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
|
||||||
files: list[WorkspaceFileInfoData]
|
|
||||||
total_count: int
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceFileContentResponse(ToolResponseBase):
|
|
||||||
"""Response containing workspace file content (legacy, for small text files)."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
|
||||||
file_id: str
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
mime_type: str
|
|
||||||
content_base64: str
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
|
||||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
|
||||||
file_id: str
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
mime_type: str
|
|
||||||
size_bytes: int
|
|
||||||
download_url: str
|
|
||||||
preview: str | None = None # First 500 chars for text files
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceWriteResponse(ToolResponseBase):
|
|
||||||
"""Response after writing a file to workspace."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
|
||||||
file_id: str
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
size_bytes: int
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceDeleteResponse(ToolResponseBase):
|
|
||||||
"""Response after deleting a file from workspace."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
|
||||||
file_id: str
|
|
||||||
success: bool
|
|
||||||
|
|
||||||
|
|
||||||
class ListWorkspaceFilesTool(BaseTool):
|
|
||||||
"""Tool for listing files in user's workspace."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "list_workspace_files"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"List files in the user's workspace. "
|
|
||||||
"Returns file names, paths, sizes, and metadata. "
|
|
||||||
"Optionally filter by path prefix."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"path_prefix": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Optional path prefix to filter files "
|
|
||||||
"(e.g., '/documents/' to list only files in documents folder). "
|
|
||||||
"By default, only files from the current session are listed."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"limit": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Maximum number of files to return (default 50, max 100)",
|
|
||||||
"minimum": 1,
|
|
||||||
"maximum": 100,
|
|
||||||
},
|
|
||||||
"include_all_sessions": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": (
|
|
||||||
"If true, list files from all sessions. "
|
|
||||||
"Default is false (only current session's files)."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Authentication required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
|
||||||
limit = min(kwargs.get("limit", 50), 100)
|
|
||||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
|
||||||
|
|
||||||
try:
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
|
||||||
# Pass session_id for session-scoped file access
|
|
||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
|
||||||
|
|
||||||
files = await manager.list_files(
|
|
||||||
path=path_prefix,
|
|
||||||
limit=limit,
|
|
||||||
include_all_sessions=include_all_sessions,
|
|
||||||
)
|
|
||||||
total = await manager.get_file_count(
|
|
||||||
path=path_prefix,
|
|
||||||
include_all_sessions=include_all_sessions,
|
|
||||||
)
|
|
||||||
|
|
||||||
file_infos = [
|
|
||||||
WorkspaceFileInfoData(
|
|
||||||
file_id=f.id,
|
|
||||||
name=f.name,
|
|
||||||
path=f.path,
|
|
||||||
mime_type=f.mimeType,
|
|
||||||
size_bytes=f.sizeBytes,
|
|
||||||
)
|
|
||||||
for f in files
|
|
||||||
]
|
|
||||||
|
|
||||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
|
||||||
return WorkspaceFileListResponse(
|
|
||||||
files=file_infos,
|
|
||||||
total_count=total,
|
|
||||||
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to list workspace files: {str(e)}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReadWorkspaceFileTool(BaseTool):
|
|
||||||
"""Tool for reading file content from workspace."""
|
|
||||||
|
|
||||||
# Size threshold for returning full content vs metadata+URL
|
|
||||||
# Files larger than this return metadata with download URL to prevent context bloat
|
|
||||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
|
||||||
# Preview size for text files
|
|
||||||
PREVIEW_SIZE = 500
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "read_workspace_file"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Read a file from the user's workspace. "
|
|
||||||
"Specify either file_id or path to identify the file. "
|
|
||||||
"For small text files, returns content directly. "
|
|
||||||
"For large or binary files, returns metadata and a download URL. "
|
|
||||||
"Paths are scoped to the current session by default. "
|
|
||||||
"Use /sessions/<session_id>/... for cross-session access."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"file_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The file's unique ID (from list_workspace_files)",
|
|
||||||
},
|
|
||||||
"path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
|
||||||
"Scoped to current session by default."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"force_download_url": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": (
|
|
||||||
"If true, always return metadata+URL instead of inline content. "
|
|
||||||
"Default is false (auto-selects based on file size/type)."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": [], # At least one must be provided
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _is_text_mime_type(self, mime_type: str) -> bool:
|
|
||||||
"""Check if the MIME type is a text-based type."""
|
|
||||||
text_types = [
|
|
||||||
"text/",
|
|
||||||
"application/json",
|
|
||||||
"application/xml",
|
|
||||||
"application/javascript",
|
|
||||||
"application/x-python",
|
|
||||||
"application/x-sh",
|
|
||||||
]
|
|
||||||
return any(mime_type.startswith(t) for t in text_types)
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Authentication required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
file_id: Optional[str] = kwargs.get("file_id")
|
|
||||||
path: Optional[str] = kwargs.get("path")
|
|
||||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
|
||||||
|
|
||||||
if not file_id and not path:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide either file_id or path",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
|
||||||
# Pass session_id for session-scoped file access
|
|
||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
|
||||||
|
|
||||||
# Get file info
|
|
||||||
if file_id:
|
|
||||||
file_info = await manager.get_file_info(file_id)
|
|
||||||
if file_info is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File not found: {file_id}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
target_file_id = file_id
|
|
||||||
else:
|
|
||||||
# path is guaranteed to be non-None here due to the check above
|
|
||||||
assert path is not None
|
|
||||||
file_info = await manager.get_file_info_by_path(path)
|
|
||||||
if file_info is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File not found at path: {path}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
target_file_id = file_info.id
|
|
||||||
|
|
||||||
# Decide whether to return inline content or metadata+URL
|
|
||||||
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
|
|
||||||
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
|
||||||
|
|
||||||
# Return inline content for small text files (unless force_download_url)
|
|
||||||
if is_small_file and is_text_file and not force_download_url:
|
|
||||||
content = await manager.read_file_by_id(target_file_id)
|
|
||||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
|
||||||
|
|
||||||
return WorkspaceFileContentResponse(
|
|
||||||
file_id=file_info.id,
|
|
||||||
name=file_info.name,
|
|
||||||
path=file_info.path,
|
|
||||||
mime_type=file_info.mimeType,
|
|
||||||
content_base64=content_b64,
|
|
||||||
message=f"Successfully read file: {file_info.name}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return metadata + workspace:// reference for large or binary files
|
|
||||||
# This prevents context bloat (100KB file = ~133KB as base64)
|
|
||||||
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
|
||||||
download_url = f"workspace://{target_file_id}"
|
|
||||||
|
|
||||||
# Generate preview for text files
|
|
||||||
preview: str | None = None
|
|
||||||
if is_text_file:
|
|
||||||
try:
|
|
||||||
content = await manager.read_file_by_id(target_file_id)
|
|
||||||
preview_text = content[: self.PREVIEW_SIZE].decode(
|
|
||||||
"utf-8", errors="replace"
|
|
||||||
)
|
|
||||||
if len(content) > self.PREVIEW_SIZE:
|
|
||||||
preview_text += "..."
|
|
||||||
preview = preview_text
|
|
||||||
except Exception:
|
|
||||||
pass # Preview is optional
|
|
||||||
|
|
||||||
return WorkspaceFileMetadataResponse(
|
|
||||||
file_id=file_info.id,
|
|
||||||
name=file_info.name,
|
|
||||||
path=file_info.path,
|
|
||||||
mime_type=file_info.mimeType,
|
|
||||||
size_bytes=file_info.sizeBytes,
|
|
||||||
download_url=download_url,
|
|
||||||
preview=preview,
|
|
||||||
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to read workspace file: {str(e)}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WriteWorkspaceFileTool(BaseTool):
|
|
||||||
"""Tool for writing files to workspace."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "write_workspace_file"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Write or create a file in the user's workspace. "
|
|
||||||
"Provide the content as a base64-encoded string. "
|
|
||||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
|
||||||
"Files are saved to the current session's folder by default. "
|
|
||||||
"Use /sessions/<session_id>/... for cross-session access."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"filename": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Name for the file (e.g., 'report.pdf')",
|
|
||||||
},
|
|
||||||
"content_base64": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Base64-encoded file content",
|
|
||||||
},
|
|
||||||
"path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Optional virtual path where to save the file "
|
|
||||||
"(e.g., '/documents/report.pdf'). "
|
|
||||||
"Defaults to '/{filename}'. Scoped to current session."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"mime_type": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Optional MIME type of the file. "
|
|
||||||
"Auto-detected from filename if not provided."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"overwrite": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "Whether to overwrite if file exists at path (default: false)",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["filename", "content_base64"],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Authentication required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
filename: str = kwargs.get("filename", "")
|
|
||||||
content_b64: str = kwargs.get("content_base64", "")
|
|
||||||
path: Optional[str] = kwargs.get("path")
|
|
||||||
mime_type: Optional[str] = kwargs.get("mime_type")
|
|
||||||
overwrite: bool = kwargs.get("overwrite", False)
|
|
||||||
|
|
||||||
if not filename:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide a filename",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not content_b64:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide content_base64",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decode content
|
|
||||||
try:
|
|
||||||
content = base64.b64decode(content_b64)
|
|
||||||
except Exception:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Invalid base64-encoded content",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check size
|
|
||||||
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
|
||||||
if len(content) > max_file_size:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Virus scan
|
|
||||||
await scan_content_safe(content, filename=filename)
|
|
||||||
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
|
||||||
# Pass session_id for session-scoped file access
|
|
||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
|
||||||
|
|
||||||
file_record = await manager.write_file(
|
|
||||||
content=content,
|
|
||||||
filename=filename,
|
|
||||||
path=path,
|
|
||||||
mime_type=mime_type,
|
|
||||||
overwrite=overwrite,
|
|
||||||
)
|
|
||||||
|
|
||||||
return WorkspaceWriteResponse(
|
|
||||||
file_id=file_record.id,
|
|
||||||
name=file_record.name,
|
|
||||||
path=file_record.path,
|
|
||||||
size_bytes=file_record.sizeBytes,
|
|
||||||
message=f"Successfully wrote file: {file_record.name}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to write workspace file: {str(e)}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteWorkspaceFileTool(BaseTool):
|
|
||||||
"""Tool for deleting files from workspace."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "delete_workspace_file"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Delete a file from the user's workspace. "
|
|
||||||
"Specify either file_id or path to identify the file. "
|
|
||||||
"Paths are scoped to the current session by default. "
|
|
||||||
"Use /sessions/<session_id>/... for cross-session access."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"file_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The file's unique ID (from list_workspace_files)",
|
|
||||||
},
|
|
||||||
"path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
|
||||||
"Scoped to current session by default."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": [], # At least one must be provided
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Authentication required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
file_id: Optional[str] = kwargs.get("file_id")
|
|
||||||
path: Optional[str] = kwargs.get("path")
|
|
||||||
|
|
||||||
if not file_id and not path:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide either file_id or path",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
|
||||||
# Pass session_id for session-scoped file access
|
|
||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
|
||||||
|
|
||||||
# Determine the file_id to delete
|
|
||||||
target_file_id: str
|
|
||||||
if file_id:
|
|
||||||
target_file_id = file_id
|
|
||||||
else:
|
|
||||||
# path is guaranteed to be non-None here due to the check above
|
|
||||||
assert path is not None
|
|
||||||
file_info = await manager.get_file_info_by_path(path)
|
|
||||||
if file_info is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File not found at path: {path}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
target_file_id = file_info.id
|
|
||||||
|
|
||||||
success = await manager.delete_file(target_file_id)
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"File not found: {target_file_id}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return WorkspaceDeleteResponse(
|
|
||||||
file_id=target_file_id,
|
|
||||||
success=True,
|
|
||||||
message="File deleted successfully",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to delete workspace file: {str(e)}",
|
|
||||||
error=str(e),
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
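Note: the tools above accept JSON arguments matching their parameters schemas. A minimal sketch of a conforming call payload for write_workspace_file, built with the stdlib only (the dispatcher that routes it to _execute is assumed and not shown here):

# Illustrative only: arguments an assistant would supply for write_workspace_file.
import base64
import json

args = {
    "filename": "report.txt",
    "content_base64": base64.b64encode(b"hello world").decode(),
    "path": "/documents/report.txt",  # optional; defaults to '/{filename}'
    "overwrite": False,
}
print(json.dumps({"name": "write_workspace_file", "arguments": args}, indent=2))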
@@ -1 +0,0 @@
# Workspace API feature module
@@ -1,122 +0,0 @@
"""
Workspace API routes for managing user file storage.
"""

import logging
import re
from typing import Annotated
from urllib.parse import quote

import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response

from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage


def _sanitize_filename_for_header(filename: str) -> str:
    """
    Sanitize filename for Content-Disposition header to prevent header injection.

    Removes/replaces characters that could break the header or inject new headers.
    Uses RFC5987 encoding for non-ASCII characters.
    """
    # Remove CR, LF, and null bytes (header injection prevention)
    sanitized = re.sub(r"[\r\n\x00]", "", filename)
    # Escape quotes
    sanitized = sanitized.replace('"', '\\"')
    # For non-ASCII, use RFC5987 filename* parameter
    # Check if filename has non-ASCII characters
    try:
        sanitized.encode("ascii")
        return f'attachment; filename="{sanitized}"'
    except UnicodeEncodeError:
        # Use RFC5987 encoding for UTF-8 filenames
        encoded = quote(sanitized, safe="")
        return f"attachment; filename*=UTF-8''{encoded}"


logger = logging.getLogger(__name__)

router = fastapi.APIRouter(
    dependencies=[fastapi.Security(requires_user)],
)


def _create_streaming_response(content: bytes, file) -> Response:
    """Create a streaming response for file content."""
    return Response(
        content=content,
        media_type=file.mimeType,
        headers={
            "Content-Disposition": _sanitize_filename_for_header(file.name),
            "Content-Length": str(len(content)),
        },
    )


async def _create_file_download_response(file) -> Response:
    """
    Create a download response for a workspace file.

    Handles both local storage (direct streaming) and GCS (signed URL redirect
    with fallback to streaming).
    """
    storage = await get_workspace_storage()

    # For local storage, stream the file directly
    if file.storagePath.startswith("local://"):
        content = await storage.retrieve(file.storagePath)
        return _create_streaming_response(content, file)

    # For GCS, try to redirect to signed URL, fall back to streaming
    try:
        url = await storage.get_download_url(file.storagePath, expires_in=300)
        # If we got back an API path (fallback), stream directly instead
        if url.startswith("/api/"):
            content = await storage.retrieve(file.storagePath)
            return _create_streaming_response(content, file)
        return fastapi.responses.RedirectResponse(url=url, status_code=302)
    except Exception as e:
        # Log the signed URL failure with context
        logger.error(
            f"Failed to get signed URL for file {file.id} "
            f"(storagePath={file.storagePath}): {e}",
            exc_info=True,
        )
        # Fall back to streaming directly from GCS
        try:
            content = await storage.retrieve(file.storagePath)
            return _create_streaming_response(content, file)
        except Exception as fallback_error:
            logger.error(
                f"Fallback streaming also failed for file {file.id} "
                f"(storagePath={file.storagePath}): {fallback_error}",
                exc_info=True,
            )
            raise


@router.get(
    "/files/{file_id}/download",
    summary="Download file by ID",
)
async def download_file(
    user_id: Annotated[str, fastapi.Security(get_user_id)],
    file_id: str,
) -> Response:
    """
    Download a file by its ID.

    Returns the file content directly or redirects to a signed URL for GCS.
    """
    workspace = await get_workspace(user_id)
    if workspace is None:
        raise fastapi.HTTPException(status_code=404, detail="Workspace not found")

    file = await get_workspace_file(file_id, workspace.id)
    if file is None:
        raise fastapi.HTTPException(status_code=404, detail="File not found")

    return await _create_file_download_response(file)
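Note: a standalone sketch of the Content-Disposition behaviour implemented by _sanitize_filename_for_header above (same stdlib calls; the route module itself is deleted by this hunk):

import re
from urllib.parse import quote

def content_disposition(filename: str) -> str:
    # Strip header-injection characters, then branch on ASCII vs. non-ASCII names.
    sanitized = re.sub(r"[\r\n\x00]", "", filename).replace('"', '\\"')
    try:
        sanitized.encode("ascii")
        return f'attachment; filename="{sanitized}"'
    except UnicodeEncodeError:
        return f"attachment; filename*=UTF-8''{quote(sanitized, safe='')}"

print(content_disposition("report.pdf"))   # attachment; filename="report.pdf"
print(content_disposition("répórt.pdf"))   # attachment; filename*=UTF-8''r%C3%A9p%C3%B3rt.pdf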
@@ -32,7 +32,6 @@ import backend.api.features.postmark.postmark
 import backend.api.features.store.model
 import backend.api.features.store.routes
 import backend.api.features.v1
-import backend.api.features.workspace.routes as workspace_routes
 import backend.data.block
 import backend.data.db
 import backend.data.graph
@@ -53,7 +52,6 @@ from backend.util.exceptions import (
 )
 from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
 from backend.util.service import UnhealthyServiceError
-from backend.util.workspace_storage import shutdown_workspace_storage

 from .external.fastapi_app import external_api
 from .features.analytics import router as analytics_router
@@ -126,11 +124,6 @@ async def lifespan_context(app: fastapi.FastAPI):
     except Exception as e:
         logger.warning(f"Error shutting down cloud storage handler: {e}")

-    try:
-        await shutdown_workspace_storage()
-    except Exception as e:
-        logger.warning(f"Error shutting down workspace storage: {e}")
-
     await backend.data.db.disconnect()

@@ -322,11 +315,6 @@ app.include_router(
     tags=["v2", "chat"],
     prefix="/api/chat",
 )
-app.include_router(
-    workspace_routes.router,
-    tags=["workspace"],
-    prefix="/api/workspace",
-)
 app.include_router(
     backend.api.features.oauth.router,
     tags=["oauth"],
@@ -13,7 +13,6 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -118,13 +117,11 @@ class AIImageCustomizerBlock(Block):
                 "credentials": TEST_CREDENTIALS_INPUT,
             },
             test_output=[
-                # Output will be a workspace ref or data URI depending on context
-                ("image_url", lambda x: x.startswith(("workspace://", "data:"))),
+                ("image_url", "https://replicate.delivery/generated-image.jpg"),
             ],
             test_mock={
-                # Use data URI to avoid HTTP requests during tests
                 "run_model": lambda *args, **kwargs: MediaFileType(
-                    "data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q=="
+                    "https://replicate.delivery/generated-image.jpg"
                 ),
             },
             test_credentials=TEST_CREDENTIALS,
@@ -135,7 +132,8 @@ class AIImageCustomizerBlock(Block):
         input_data: Input,
         *,
         credentials: APIKeyCredentials,
-        execution_context: ExecutionContext,
+        graph_exec_id: str,
+        user_id: str,
         **kwargs,
     ) -> BlockOutput:
         try:
@@ -143,9 +141,10 @@ class AIImageCustomizerBlock(Block):
             processed_images = await asyncio.gather(
                 *(
                     store_media_file(
+                        graph_exec_id=graph_exec_id,
                         file=img,
-                        execution_context=execution_context,
-                        return_format="for_external_api",  # Get content for Replicate API
+                        user_id=user_id,
+                        return_content=True,
                     )
                     for img in input_data.images
                 )
@@ -159,14 +158,7 @@ class AIImageCustomizerBlock(Block):
                 aspect_ratio=input_data.aspect_ratio.value,
                 output_format=input_data.output_format.value,
             )
-
-            # Store the generated image to the user's workspace for persistence
-            stored_url = await store_media_file(
-                file=result,
-                execution_context=execution_context,
-                return_format="for_block_output",
-            )
-            yield "image_url", stored_url
+            yield "image_url", result
         except Exception as e:
             yield "error", str(e)

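Note: the same store_media_file call-shape change recurs in the block hunks below. A sketch of the before/after keyword usage, taken directly from the hunks in this diff (no new parameters introduced):

# Old (workspace-aware) call being removed:
#     await store_media_file(
#         file=img,
#         execution_context=execution_context,
#         return_format="for_external_api",
#     )
# New call being added:
#     await store_media_file(
#         graph_exec_id=graph_exec_id,
#         user_id=user_id,
#         file=img,
#         return_content=True,
#     )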
@@ -6,7 +6,6 @@ from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput

 from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
-from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -14,8 +13,6 @@ from backend.data.model import (
     SchemaField,
 )
 from backend.integrations.providers import ProviderName
-from backend.util.file import store_media_file
-from backend.util.type import MediaFileType


 class ImageSize(str, Enum):
@@ -168,13 +165,11 @@ class AIImageGeneratorBlock(Block):
             test_output=[
                 (
                     "image_url",
-                    # Test output is a data URI since we now store images
-                    lambda x: x.startswith("data:image/"),
+                    "https://replicate.delivery/generated-image.webp",
                 ),
             ],
             test_mock={
-                # Return a data URI directly so store_media_file doesn't need to download
-                "_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
+                "_run_client": lambda *args, **kwargs: "https://replicate.delivery/generated-image.webp"
             },
         )

@@ -323,24 +318,11 @@ class AIImageGeneratorBlock(Block):
         style_text = style_map.get(style, "")
         return f"{style_text} of" if style_text else ""

-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: APIKeyCredentials,
-        execution_context: ExecutionContext,
-        **kwargs,
-    ):
+    async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
         try:
             url = await self.generate_image(input_data, credentials)
             if url:
-                # Store the generated image to the user's workspace/execution folder
-                stored_url = await store_media_file(
-                    file=MediaFileType(url),
-                    execution_context=execution_context,
-                    return_format="for_block_output",
-                )
-                yield "image_url", stored_url
+                yield "image_url", url
             else:
                 yield "error", "Image generation returned an empty result."
         except Exception as e:
@@ -13,7 +13,6 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -22,9 +21,7 @@ from backend.data.model import (
 )
 from backend.integrations.providers import ProviderName
 from backend.util.exceptions import BlockExecutionError
-from backend.util.file import store_media_file
 from backend.util.request import Requests
-from backend.util.type import MediaFileType

 TEST_CREDENTIALS = APIKeyCredentials(
     id="01234567-89ab-cdef-0123-456789abcdef",
@@ -274,10 +271,7 @@ class AIShortformVideoCreatorBlock(Block):
                 "voice": Voice.LILY,
                 "video_style": VisualMediaType.STOCK_VIDEOS,
             },
-            test_output=(
-                "video_url",
-                lambda x: x.startswith(("workspace://", "data:")),
-            ),
+            test_output=("video_url", "https://example.com/video.mp4"),
             test_mock={
                 "create_webhook": lambda *args, **kwargs: (
                     "test_uuid",
@@ -286,21 +280,15 @@ class AIShortformVideoCreatorBlock(Block):
                 "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
                 "check_video_status": lambda *args, **kwargs: {
                     "status": "ready",
-                    "videoUrl": "data:video/mp4;base64,AAAA",
+                    "videoUrl": "https://example.com/video.mp4",
                 },
-                # Use data URI to avoid HTTP requests during tests
-                "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
+                "wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
             },
             test_credentials=TEST_CREDENTIALS,
         )

     async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: APIKeyCredentials,
-        execution_context: ExecutionContext,
-        **kwargs,
+        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
         # Create a new Webhook.site URL
         webhook_token, webhook_url = await self.create_webhook()
@@ -352,13 +340,7 @@ class AIShortformVideoCreatorBlock(Block):
         )
         video_url = await self.wait_for_video(credentials.api_key, pid)
         logger.debug(f"Video ready: {video_url}")
-        # Store the generated video to the user's workspace for persistence
-        stored_url = await store_media_file(
-            file=MediaFileType(video_url),
-            execution_context=execution_context,
-            return_format="for_block_output",
-        )
-        yield "video_url", stored_url
+        yield "video_url", video_url


 class AIAdMakerVideoCreatorBlock(Block):
@@ -465,10 +447,7 @@ class AIAdMakerVideoCreatorBlock(Block):
                     "https://cdn.revid.ai/uploads/1747076315114-image.png",
                 ],
             },
-            test_output=(
-                "video_url",
-                lambda x: x.startswith(("workspace://", "data:")),
-            ),
+            test_output=("video_url", "https://example.com/ad.mp4"),
             test_mock={
                 "create_webhook": lambda *args, **kwargs: (
                     "test_uuid",
@@ -477,21 +456,14 @@ class AIAdMakerVideoCreatorBlock(Block):
                 "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
                 "check_video_status": lambda *args, **kwargs: {
                     "status": "ready",
-                    "videoUrl": "data:video/mp4;base64,AAAA",
+                    "videoUrl": "https://example.com/ad.mp4",
                 },
-                "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
+                "wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
             },
             test_credentials=TEST_CREDENTIALS,
        )

-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: APIKeyCredentials,
-        execution_context: ExecutionContext,
-        **kwargs,
-    ):
+    async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
         webhook_token, webhook_url = await self.create_webhook()

         payload = {
@@ -559,13 +531,7 @@ class AIAdMakerVideoCreatorBlock(Block):
             raise RuntimeError("Failed to create video: No project ID returned")

         video_url = await self.wait_for_video(credentials.api_key, pid)
-        # Store the generated video to the user's workspace for persistence
-        stored_url = await store_media_file(
-            file=MediaFileType(video_url),
-            execution_context=execution_context,
-            return_format="for_block_output",
-        )
-        yield "video_url", stored_url
+        yield "video_url", video_url


 class AIScreenshotToVideoAdBlock(Block):
@@ -660,10 +626,7 @@ class AIScreenshotToVideoAdBlock(Block):
                 "script": "Amazing numbers!",
                 "screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
             },
-            test_output=(
-                "video_url",
-                lambda x: x.startswith(("workspace://", "data:")),
-            ),
+            test_output=("video_url", "https://example.com/screenshot.mp4"),
             test_mock={
                 "create_webhook": lambda *args, **kwargs: (
                     "test_uuid",
@@ -672,21 +635,14 @@ class AIScreenshotToVideoAdBlock(Block):
                 "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
                 "check_video_status": lambda *args, **kwargs: {
                     "status": "ready",
-                    "videoUrl": "data:video/mp4;base64,AAAA",
+                    "videoUrl": "https://example.com/screenshot.mp4",
                 },
-                "wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
+                "wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
             },
             test_credentials=TEST_CREDENTIALS,
         )

-    async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: APIKeyCredentials,
-        execution_context: ExecutionContext,
-        **kwargs,
-    ):
+    async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
         webhook_token, webhook_url = await self.create_webhook()

         payload = {
@@ -754,10 +710,4 @@ class AIScreenshotToVideoAdBlock(Block):
             raise RuntimeError("Failed to create video: No project ID returned")

         video_url = await self.wait_for_video(credentials.api_key, pid)
-        # Store the generated video to the user's workspace for persistence
-        stored_url = await store_media_file(
-            file=MediaFileType(video_url),
-            execution_context=execution_context,
-            return_format="for_block_output",
-        )
-        yield "video_url", stored_url
+        yield "video_url", video_url
@@ -6,7 +6,6 @@ if TYPE_CHECKING:

 from pydantic import SecretStr

-from backend.data.execution import ExecutionContext
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -18,8 +17,6 @@ from backend.sdk import (
     Requests,
     SchemaField,
 )
-from backend.util.file import store_media_file
-from backend.util.type import MediaFileType

 from ._config import bannerbear

@@ -138,17 +135,15 @@ class BannerbearTextOverlayBlock(Block):
             },
             test_output=[
                 ("success", True),
-                # Output will be a workspace ref or data URI depending on context
-                ("image_url", lambda x: x.startswith(("workspace://", "data:"))),
+                ("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
                 ("uid", "test-uid-123"),
                 ("status", "completed"),
             ],
             test_mock={
-                # Use data URI to avoid HTTP requests during tests
                 "_make_api_request": lambda *args, **kwargs: {
                     "uid": "test-uid-123",
                     "status": "completed",
-                    "image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAABAAEBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+v//Z",
+                    "image_url": "https://cdn.bannerbear.com/test-image.jpg",
                 }
             },
             test_credentials=TEST_CREDENTIALS,
@@ -182,12 +177,7 @@ class BannerbearTextOverlayBlock(Block):
         raise Exception(error_msg)

     async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: APIKeyCredentials,
-        execution_context: ExecutionContext,
-        **kwargs,
+        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
         # Build the modifications array
         modifications = []
@@ -244,18 +234,6 @@ class BannerbearTextOverlayBlock(Block):

             # Synchronous request - image should be ready
             yield "success", True
-
-            # Store the generated image to workspace for persistence
-            image_url = data.get("image_url", "")
-            if image_url:
-                stored_url = await store_media_file(
-                    file=MediaFileType(image_url),
-                    execution_context=execution_context,
-                    return_format="for_block_output",
-                )
-                yield "image_url", stored_url
-            else:
-                yield "image_url", ""
-
+            yield "image_url", data.get("image_url", "")
             yield "uid", data.get("uid", "")
             yield "status", data.get("status", "completed")
@@ -9,7 +9,6 @@ from backend.data.block import (
     BlockSchemaOutput,
     BlockType,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file
 from backend.util.type import MediaFileType, convert
@@ -18,10 +17,10 @@ from backend.util.type import MediaFileType, convert
 class FileStoreBlock(Block):
     class Input(BlockSchemaInput):
         file_in: MediaFileType = SchemaField(
-            description="The file to download and store. Can be a URL (https://...), data URI, or local path."
+            description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
         )
         base_64: bool = SchemaField(
-            description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
+            description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
             default=False,
             advanced=True,
             title="Produce Base64 Output",
@@ -29,18 +28,13 @@ class FileStoreBlock(Block):

     class Output(BlockSchemaOutput):
         file_out: MediaFileType = SchemaField(
-            description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
+            description="The relative path to the stored file in the temporary directory."
         )

     def __init__(self):
         super().__init__(
             id="cbb50872-625b-42f0-8203-a2ae78242d8a",
-            description=(
-                "Downloads and stores a file from a URL, data URI, or local path. "
-                "Use this to fetch images, documents, or other files for processing. "
-                "In CoPilot: saves to workspace (use list_workspace_files to see it). "
-                "In graphs: outputs a data URI to pass to other blocks."
-            ),
+            description="Stores the input file in the temporary directory.",
             categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
             input_schema=FileStoreBlock.Input,
             output_schema=FileStoreBlock.Output,
@@ -51,18 +45,15 @@ class FileStoreBlock(Block):
         self,
         input_data: Input,
         *,
-        execution_context: ExecutionContext,
+        graph_exec_id: str,
+        user_id: str,
         **kwargs,
     ) -> BlockOutput:
-        # Determine return format based on user preference
-        # for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
-        # for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
-        return_format = "for_external_api" if input_data.base_64 else "for_block_output"
-
         yield "file_out", await store_media_file(
+            graph_exec_id=graph_exec_id,
             file=input_data.file_in,
-            execution_context=execution_context,
-            return_format=return_format,
+            user_id=user_id,
+            return_content=input_data.base_64,
        )

@@ -15,7 +15,6 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import APIKeyCredentials, SchemaField
 from backend.util.file import store_media_file
 from backend.util.request import Requests
@@ -667,7 +666,8 @@ class SendDiscordFileBlock(Block):
         file: MediaFileType,
         filename: str,
         message_content: str,
-        execution_context: ExecutionContext,
+        graph_exec_id: str,
+        user_id: str,
     ) -> dict:
         intents = discord.Intents.default()
         intents.guilds = True
@@ -731,9 +731,10 @@ class SendDiscordFileBlock(Block):
             # Local file path - read from stored media file
             # This would be a path from a previous block's output
             stored_file = await store_media_file(
+                graph_exec_id=graph_exec_id,
                 file=file,
-                execution_context=execution_context,
-                return_format="for_external_api",  # Get content to send to Discord
+                user_id=user_id,
+                return_content=True,  # Get as data URI
             )
             # Now process as data URI
             header, encoded = stored_file.split(",", 1)
@@ -780,7 +781,8 @@ class SendDiscordFileBlock(Block):
         input_data: Input,
         *,
         credentials: APIKeyCredentials,
-        execution_context: ExecutionContext,
+        graph_exec_id: str,
+        user_id: str,
         **kwargs,
     ) -> BlockOutput:
         try:
@@ -791,7 +793,8 @@ class SendDiscordFileBlock(Block):
                 file=input_data.file,
                 filename=input_data.filename,
                 message_content=input_data.message_content,
-                execution_context=execution_context,
+                graph_exec_id=graph_exec_id,
+                user_id=user_id,
             )

             yield "status", result.get("status", "Unknown error")
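Note: a minimal, self-contained sketch of the data-URI round trip used by SendDiscordFileBlock above (stdlib only; the sample bytes and variable names are illustrative, not part of the block):

import base64

stored_file = "data:image/png;base64," + base64.b64encode(b"example-bytes").decode()
header, encoded = stored_file.split(",", 1)              # same split as in the block
raw_bytes = base64.b64decode(encoded)                    # payload to upload
mime_type = header.split(";")[0].removeprefix("data:")   # -> "image/png"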
@@ -17,11 +17,8 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
-from backend.util.file import store_media_file
 from backend.util.request import ClientResponseError, Requests
-from backend.util.type import MediaFileType

 logger = logging.getLogger(__name__)

@@ -67,13 +64,9 @@ class AIVideoGeneratorBlock(Block):
             "credentials": TEST_CREDENTIALS_INPUT,
         },
         test_credentials=TEST_CREDENTIALS,
-        test_output=[
-            # Output will be a workspace ref or data URI depending on context
-            ("video_url", lambda x: x.startswith(("workspace://", "data:"))),
-        ],
+        test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
         test_mock={
-            # Use data URI to avoid HTTP requests during tests
-            "generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
+            "generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
         },
     )

@@ -215,22 +208,11 @@ class AIVideoGeneratorBlock(Block):
             raise RuntimeError(f"API request failed: {str(e)}")

     async def run(
-        self,
-        input_data: Input,
-        *,
-        credentials: FalCredentials,
-        execution_context: ExecutionContext,
-        **kwargs,
+        self, input_data: Input, *, credentials: FalCredentials, **kwargs
     ) -> BlockOutput:
         try:
             video_url = await self.generate_video(input_data, credentials)
-            # Store the generated video to the user's workspace for persistence
-            stored_url = await store_media_file(
-                file=MediaFileType(video_url),
-                execution_context=execution_context,
-                return_format="for_block_output",
-            )
-            yield "video_url", stored_url
+            yield "video_url", video_url
         except Exception as e:
             error_message = str(e)
             yield "error", error_message
@@ -12,7 +12,6 @@ from backend.data.block import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -122,12 +121,10 @@ class AIImageEditorBlock(Block):
                 "credentials": TEST_CREDENTIALS_INPUT,
             },
             test_output=[
-                # Output will be a workspace ref or data URI depending on context
-                ("output_image", lambda x: x.startswith(("workspace://", "data:"))),
+                ("output_image", "https://replicate.com/output/edited-image.png"),
             ],
             test_mock={
-                # Use data URI to avoid HTTP requests during tests
-                "run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
+                "run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
             },
             test_credentials=TEST_CREDENTIALS,
         )
@@ -137,7 +134,8 @@ class AIImageEditorBlock(Block):
         input_data: Input,
         *,
         credentials: APIKeyCredentials,
-        execution_context: ExecutionContext,
+        graph_exec_id: str,
+        user_id: str,
         **kwargs,
     ) -> BlockOutput:
         result = await self.run_model(
@@ -146,25 +144,20 @@ class AIImageEditorBlock(Block):
             prompt=input_data.prompt,
             input_image_b64=(
                 await store_media_file(
+                    graph_exec_id=graph_exec_id,
                     file=input_data.input_image,
-                    execution_context=execution_context,
-                    return_format="for_external_api",  # Get content for Replicate API
+                    user_id=user_id,
+                    return_content=True,
                 )
                 if input_data.input_image
                 else None
             ),
             aspect_ratio=input_data.aspect_ratio.value,
             seed=input_data.seed,
-            user_id=execution_context.user_id or "",
-            graph_exec_id=execution_context.graph_exec_id or "",
+            user_id=user_id,
+            graph_exec_id=graph_exec_id,
         )
-        # Store the generated image to the user's workspace for persistence
-        stored_url = await store_media_file(
-            file=result,
-            execution_context=execution_context,
-            return_format="for_block_output",
-        )
-        yield "output_image", stored_url
+        yield "output_image", result

     async def run_model(
         self,
@@ -21,7 +21,6 @@ from backend.data.block import (
|
|||||||
BlockSchemaInput,
|
BlockSchemaInput,
|
||||||
BlockSchemaOutput,
|
BlockSchemaOutput,
|
||||||
)
|
)
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import Settings
|
||||||
@@ -96,7 +95,8 @@ def _make_mime_text(
|
|||||||
|
|
||||||
async def create_mime_message(
|
async def create_mime_message(
|
||||||
input_data,
|
input_data,
|
||||||
execution_context: ExecutionContext,
|
graph_exec_id: str,
|
||||||
|
user_id: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||||
|
|
||||||
@@ -117,12 +117,12 @@ async def create_mime_message(
|
|||||||
if input_data.attachments:
|
if input_data.attachments:
|
||||||
for attach in input_data.attachments:
|
for attach in input_data.attachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
file=attach,
|
file=attach,
|
||||||
execution_context=execution_context,
|
return_content=False,
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
)
|
||||||
assert execution_context.graph_exec_id # Validated by store_media_file
|
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||||
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
@@ -582,25 +582,27 @@ class GmailSendBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
execution_context: ExecutionContext,
|
graph_exec_id: str,
|
||||||
|
user_id: str,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._send_email(
|
result = await self._send_email(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
execution_context,
|
graph_exec_id,
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
yield "result", result
|
yield "result", result
|
||||||
|
|
||||||
async def _send_email(
|
async def _send_email(
|
||||||
self, service, input_data: Input, execution_context: ExecutionContext
|
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to or not input_data.subject or not input_data.body:
|
if not input_data.to or not input_data.subject or not input_data.body:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"At least one recipient, subject, and body are required for sending an email"
|
"At least one recipient, subject, and body are required for sending an email"
|
||||||
)
|
)
|
||||||
raw_message = await create_mime_message(input_data, execution_context)
|
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
||||||
sent_message = await asyncio.to_thread(
|
sent_message = await asyncio.to_thread(
|
||||||
lambda: service.users()
|
lambda: service.users()
|
||||||
.messages()
|
.messages()
|
||||||
@@ -690,28 +692,30 @@ class GmailCreateDraftBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
execution_context: ExecutionContext,
|
graph_exec_id: str,
|
||||||
|
user_id: str,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
result = await self._create_draft(
|
result = await self._create_draft(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
execution_context,
|
graph_exec_id,
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
yield "result", GmailDraftResult(
|
yield "result", GmailDraftResult(
|
||||||
id=result["id"], message_id=result["message"]["id"], status="draft_created"
|
id=result["id"], message_id=result["message"]["id"], status="draft_created"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _create_draft(
|
async def _create_draft(
|
||||||
self, service, input_data: Input, execution_context: ExecutionContext
|
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if not input_data.to or not input_data.subject:
|
if not input_data.to or not input_data.subject:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"At least one recipient and subject are required for creating a draft"
|
"At least one recipient and subject are required for creating a draft"
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_message = await create_mime_message(input_data, execution_context)
|
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
||||||
draft = await asyncio.to_thread(
|
draft = await asyncio.to_thread(
|
||||||
lambda: service.users()
|
lambda: service.users()
|
||||||
.drafts()
|
.drafts()
|
||||||
@@ -1096,7 +1100,7 @@ class GmailGetThreadBlock(GmailBase):
|
|||||||
|
|
||||||
|
|
||||||
async def _build_reply_message(
|
async def _build_reply_message(
|
||||||
service, input_data, execution_context: ExecutionContext
|
service, input_data, graph_exec_id: str, user_id: str
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Builds a reply MIME message for Gmail threads.
|
Builds a reply MIME message for Gmail threads.
|
||||||
@@ -1186,12 +1190,12 @@ async def _build_reply_message(
|
|||||||
# Handle attachments
|
# Handle attachments
|
||||||
for attach in input_data.attachments:
|
for attach in input_data.attachments:
|
||||||
local_path = await store_media_file(
|
local_path = await store_media_file(
|
||||||
|
user_id=user_id,
|
||||||
|
graph_exec_id=graph_exec_id,
|
||||||
file=attach,
|
file=attach,
|
||||||
execution_context=execution_context,
|
return_content=False,
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
)
|
||||||
assert execution_context.graph_exec_id # Validated by store_media_file
|
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||||
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
|
||||||
part = MIMEBase("application", "octet-stream")
|
part = MIMEBase("application", "octet-stream")
|
||||||
with open(abs_path, "rb") as f:
|
with open(abs_path, "rb") as f:
|
||||||
part.set_payload(f.read())
|
part.set_payload(f.read())
|
||||||
@@ -1307,14 +1311,16 @@ class GmailReplyBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
execution_context: ExecutionContext,
|
graph_exec_id: str,
|
||||||
|
user_id: str,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
message = await self._reply(
|
message = await self._reply(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
execution_context,
|
graph_exec_id,
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
yield "messageId", message["id"]
|
yield "messageId", message["id"]
|
||||||
yield "threadId", message.get("threadId", input_data.threadId)
|
yield "threadId", message.get("threadId", input_data.threadId)
|
||||||
@@ -1337,11 +1343,11 @@ class GmailReplyBlock(GmailBase):
|
|||||||
yield "email", email
|
yield "email", email
|
||||||
|
|
||||||
async def _reply(
|
async def _reply(
|
||||||
self, service, input_data: Input, execution_context: ExecutionContext
|
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Build the reply message using the shared helper
|
# Build the reply message using the shared helper
|
||||||
raw, thread_id = await _build_reply_message(
|
raw, thread_id = await _build_reply_message(
|
||||||
service, input_data, execution_context
|
service, input_data, graph_exec_id, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send the message
|
# Send the message
|
||||||
@@ -1435,14 +1441,16 @@ class GmailDraftReplyBlock(GmailBase):
|
|||||||
input_data: Input,
|
input_data: Input,
|
||||||
*,
|
*,
|
||||||
credentials: GoogleCredentials,
|
credentials: GoogleCredentials,
|
||||||
execution_context: ExecutionContext,
|
graph_exec_id: str,
|
||||||
|
user_id: str,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
service = self._build_service(credentials, **kwargs)
|
service = self._build_service(credentials, **kwargs)
|
||||||
draft = await self._create_draft_reply(
|
draft = await self._create_draft_reply(
|
||||||
service,
|
service,
|
||||||
input_data,
|
input_data,
|
||||||
execution_context,
|
graph_exec_id,
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
yield "draftId", draft["id"]
|
yield "draftId", draft["id"]
|
||||||
yield "messageId", draft["message"]["id"]
|
yield "messageId", draft["message"]["id"]
|
||||||
@@ -1450,11 +1458,11 @@ class GmailDraftReplyBlock(GmailBase):
 yield "status", "draft_created"

 async def _create_draft_reply(
-self, service, input_data: Input, execution_context: ExecutionContext
+self, service, input_data: Input, graph_exec_id: str, user_id: str
 ) -> dict:
 # Build the reply message using the shared helper
 raw, thread_id = await _build_reply_message(
-service, input_data, execution_context
+service, input_data, graph_exec_id, user_id
 )

 # Create draft with proper thread association

@@ -1621,21 +1629,23 @@ class GmailForwardBlock(GmailBase):
 input_data: Input,
 *,
 credentials: GoogleCredentials,
-execution_context: ExecutionContext,
+graph_exec_id: str,
+user_id: str,
 **kwargs,
 ) -> BlockOutput:
 service = self._build_service(credentials, **kwargs)
 result = await self._forward_message(
 service,
 input_data,
-execution_context,
+graph_exec_id,
+user_id,
 )
 yield "messageId", result["id"]
 yield "threadId", result.get("threadId", "")
 yield "status", "forwarded"

 async def _forward_message(
-self, service, input_data: Input, execution_context: ExecutionContext
+self, service, input_data: Input, graph_exec_id: str, user_id: str
 ) -> dict:
 if not input_data.to:
 raise ValueError("At least one recipient is required for forwarding")

@@ -1717,12 +1727,12 @@ To: {original_to}
 # Add any additional attachments
 for attach in input_data.additionalAttachments:
 local_path = await store_media_file(
+user_id=user_id,
+graph_exec_id=graph_exec_id,
 file=attach,
-execution_context=execution_context,
-return_format="for_local_processing",
+return_content=False,
 )
-assert execution_context.graph_exec_id  # Validated by store_media_file
-abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
+abs_path = get_exec_file_path(graph_exec_id, local_path)
 part = MIMEBase("application", "octet-stream")
 with open(abs_path, "rb") as f:
 part.set_payload(f.read())
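On the master side, the attachment loop above reduces to: resolve the incoming reference to a local file under the execution's temp folder, then read it into a MIME part. A small sketch of that pattern using the keyword arguments shown in the hunk; the helper names come from the diff, while the wrapping function and the `encoders`/`attach` calls are ordinary stdlib usage added for illustration:

```python
from email import encoders
from email.mime.base import MIMEBase

from backend.util.file import get_exec_file_path, store_media_file

async def attach_file(message, attach, *, user_id: str, graph_exec_id: str) -> None:
    # Persist the media reference (URL, data URI, or relative path) locally
    local_path = await store_media_file(
        user_id=user_id,
        graph_exec_id=graph_exec_id,
        file=attach,
        return_content=False,  # we want the relative path, not a data URI
    )
    abs_path = get_exec_file_path(graph_exec_id, local_path)

    part = MIMEBase("application", "octet-stream")
    with open(abs_path, "rb") as f:
        part.set_payload(f.read())
    encoders.encode_base64(part)
    message.attach(part)
```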
@@ -15,7 +15,6 @@ from backend.data.block import (
 BlockSchemaInput,
 BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import (
 CredentialsField,
 CredentialsMetaInput,

@@ -117,9 +116,10 @@ class SendWebRequestBlock(Block):

 @staticmethod
 async def _prepare_files(
-execution_context: ExecutionContext,
+graph_exec_id: str,
 files_name: str,
 files: list[MediaFileType],
+user_id: str,
 ) -> list[tuple[str, tuple[str, BytesIO, str]]]:
 """
 Prepare files for the request by storing them and reading their content.

@@ -127,16 +127,11 @@ class SendWebRequestBlock(Block):
 (files_name, (filename, BytesIO, mime_type))
 """
 files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
-graph_exec_id = execution_context.graph_exec_id
-if graph_exec_id is None:
-raise ValueError("graph_exec_id is required for file operations")

 for media in files:
 # Normalise to a list so we can repeat the same key
 rel_path = await store_media_file(
-file=media,
-execution_context=execution_context,
-return_format="for_local_processing",
+graph_exec_id, media, user_id, return_content=False
 )
 abs_path = get_exec_file_path(graph_exec_id, rel_path)
 async with aiofiles.open(abs_path, "rb") as f:
@@ -148,7 +143,7 @@ class SendWebRequestBlock(Block):
 return files_payload

 async def run(
-self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
+self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
 ) -> BlockOutput:
 # ─── Parse/normalise body ────────────────────────────────────
 body = input_data.body

@@ -179,7 +174,7 @@ class SendWebRequestBlock(Block):
 files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
 if use_files:
 files_payload = await self._prepare_files(
-execution_context, input_data.files_name, input_data.files
+graph_exec_id, input_data.files_name, input_data.files, user_id
 )

 # Enforce body format rules

@@ -243,8 +238,9 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
 self,
 input_data: Input,
 *,
-execution_context: ExecutionContext,
+graph_exec_id: str,
 credentials: HostScopedCredentials,
+user_id: str,
 **kwargs,
 ) -> BlockOutput:
 # Create SendWebRequestBlock.Input from our input (removing credentials field)

@@ -275,6 +271,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):

 # Use parent class run method
 async for output_name, output_data in super().run(
-base_input, execution_context=execution_context, **kwargs
+base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
 ):
 yield output_name, output_data
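The `_prepare_files` docstring pins down the payload shape: `(files_name, (filename, BytesIO, mime_type))` tuples. As a rough illustration (not taken from the diff), this is the shape an HTTP client such as httpx accepts for multipart uploads, repeating the same field name for each file as the "repeat the same key" comment describes:

```python
from io import BytesIO

import httpx

# Shape produced by _prepare_files: (field_name, (filename, BytesIO, mime_type))
files_payload = [
    ("files", ("report.csv", BytesIO(b"a,b\n1,2\n"), "text/csv")),
    ("files", ("photo.png", BytesIO(b"\x89PNG..."), "image/png")),
]

async def send(url: str) -> httpx.Response:
    async with httpx.AsyncClient() as client:
        # Each tuple becomes one part of the multipart/form-data body
        return await client.post(url, files=files_payload)
```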
@@ -12,7 +12,6 @@ from backend.data.block import (
 BlockSchemaInput,
 BlockType,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file
 from backend.util.mock import MockObject

@@ -463,21 +462,18 @@ class AgentFileInputBlock(AgentInputBlock):
 self,
 input_data: Input,
 *,
-execution_context: ExecutionContext,
+graph_exec_id: str,
+user_id: str,
 **kwargs,
 ) -> BlockOutput:
 if not input_data.value:
 return

-# Determine return format based on user preference
-# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
-# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
-return_format = "for_external_api" if input_data.base_64 else "for_block_output"

 yield "result", await store_media_file(
+graph_exec_id=graph_exec_id,
 file=input_data.value,
-execution_context=execution_context,
-return_format=return_format,
+user_id=user_id,
+return_content=input_data.base_64,
 )

@@ -1,6 +1,6 @@
 import os
 import tempfile
-from typing import Optional
+from typing import Literal, Optional

 from moviepy.audio.io.AudioFileClip import AudioFileClip
 from moviepy.video.fx.Loop import Loop

@@ -13,7 +13,6 @@ from backend.data.block import (
 BlockSchemaInput,
 BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util.file import MediaFileType, get_exec_file_path, store_media_file

@@ -47,19 +46,18 @@ class MediaDurationBlock(Block):
 self,
 input_data: Input,
 *,
-execution_context: ExecutionContext,
+graph_exec_id: str,
+user_id: str,
 **kwargs,
 ) -> BlockOutput:
 # 1) Store the input media locally
 local_media_path = await store_media_file(
+graph_exec_id=graph_exec_id,
 file=input_data.media_in,
-execution_context=execution_context,
-return_format="for_local_processing",
-)
-assert execution_context.graph_exec_id is not None
-media_abspath = get_exec_file_path(
-execution_context.graph_exec_id, local_media_path
+user_id=user_id,
+return_content=False,
 )
+media_abspath = get_exec_file_path(graph_exec_id, local_media_path)

 # 2) Load the clip
 if input_data.is_video:
@@ -90,6 +88,10 @@ class LoopVideoBlock(Block):
 default=None,
 ge=1,
 )
+output_return_type: Literal["file_path", "data_uri"] = SchemaField(
+description="How to return the output video. Either a relative path or base64 data URI.",
+default="file_path",
+)

 class Output(BlockSchemaOutput):
 video_out: str = SchemaField(

@@ -109,19 +111,17 @@ class LoopVideoBlock(Block):
 self,
 input_data: Input,
 *,
-execution_context: ExecutionContext,
+node_exec_id: str,
+graph_exec_id: str,
+user_id: str,
 **kwargs,
 ) -> BlockOutput:
-assert execution_context.graph_exec_id is not None
-assert execution_context.node_exec_id is not None
-graph_exec_id = execution_context.graph_exec_id
-node_exec_id = execution_context.node_exec_id

 # 1) Store the input video locally
 local_video_path = await store_media_file(
+graph_exec_id=graph_exec_id,
 file=input_data.video_in,
-execution_context=execution_context,
-return_format="for_local_processing",
+user_id=user_id,
+return_content=False,
 )
 input_abspath = get_exec_file_path(graph_exec_id, local_video_path)

@@ -149,11 +149,12 @@ class LoopVideoBlock(Block):
 looped_clip = looped_clip.with_audio(clip.audio)
 looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

-# Return output - for_block_output returns workspace:// if available, else data URI
+# Return as data URI
 video_out = await store_media_file(
+graph_exec_id=graph_exec_id,
 file=output_filename,
-execution_context=execution_context,
-return_format="for_block_output",
+user_id=user_id,
+return_content=input_data.output_return_type == "data_uri",
 )

 yield "video_out", video_out

@@ -176,6 +177,10 @@ class AddAudioToVideoBlock(Block):
 description="Volume scale for the newly attached audio track (1.0 = original).",
 default=1.0,
 )
+output_return_type: Literal["file_path", "data_uri"] = SchemaField(
+description="Return the final output as a relative path or base64 data URI.",
+default="file_path",
+)

 class Output(BlockSchemaOutput):
 video_out: MediaFileType = SchemaField(
@@ -195,24 +200,23 @@ class AddAudioToVideoBlock(Block):
 self,
 input_data: Input,
 *,
-execution_context: ExecutionContext,
+node_exec_id: str,
+graph_exec_id: str,
+user_id: str,
 **kwargs,
 ) -> BlockOutput:
-assert execution_context.graph_exec_id is not None
-assert execution_context.node_exec_id is not None
-graph_exec_id = execution_context.graph_exec_id
-node_exec_id = execution_context.node_exec_id

 # 1) Store the inputs locally
 local_video_path = await store_media_file(
+graph_exec_id=graph_exec_id,
 file=input_data.video_in,
-execution_context=execution_context,
-return_format="for_local_processing",
+user_id=user_id,
+return_content=False,
 )
 local_audio_path = await store_media_file(
+graph_exec_id=graph_exec_id,
 file=input_data.audio_in,
-execution_context=execution_context,
-return_format="for_local_processing",
+user_id=user_id,
+return_content=False,
 )

 abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)

@@ -236,11 +240,12 @@ class AddAudioToVideoBlock(Block):
 output_abspath = os.path.join(abs_temp_dir, output_filename)
 final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")

-# 5) Return output - for_block_output returns workspace:// if available, else data URI
+# 5) Return either path or data URI
 video_out = await store_media_file(
+graph_exec_id=graph_exec_id,
 file=output_filename,
-execution_context=execution_context,
-return_format="for_block_output",
+user_id=user_id,
+return_content=input_data.output_return_type == "data_uri",
 )

 yield "video_out", video_out
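The two video blocks above share one flow: materialise inputs into the execution's temp folder, edit them with moviepy against absolute paths, then hand the result back through store_media_file, with output_return_type deciding between a relative path and a data URI. A condensed, illustrative sketch of that flow under the master-side signature; the function name and output filename are placeholders, not part of the diff:

```python
import os
import tempfile

from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.util.file import get_exec_file_path, store_media_file

async def process_video(video_in, *, graph_exec_id: str, user_id: str, as_data_uri: bool):
    # 1) Materialise the input inside this execution's temp folder
    rel_in = await store_media_file(
        graph_exec_id=graph_exec_id, file=video_in, user_id=user_id, return_content=False
    )
    abs_in = get_exec_file_path(graph_exec_id, rel_in)

    # 2) Do the actual editing with moviepy against the absolute path
    out_name = "processed_output.mp4"  # placeholder name
    abs_out = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id, out_name)
    with VideoFileClip(abs_in) as clip:
        clip.write_videofile(abs_out, codec="libx264", audio_codec="aac")

    # 3) Re-register the output; return_content decides path vs data URI
    return await store_media_file(
        graph_exec_id=graph_exec_id, file=out_name, user_id=user_id, return_content=as_data_uri
    )
```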
@@ -11,7 +11,6 @@ from backend.data.block import (
 BlockSchemaInput,
 BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import (
 APIKeyCredentials,
 CredentialsField,

@@ -113,7 +112,8 @@ class ScreenshotWebPageBlock(Block):
 @staticmethod
 async def take_screenshot(
 credentials: APIKeyCredentials,
-execution_context: ExecutionContext,
+graph_exec_id: str,
+user_id: str,
 url: str,
 viewport_width: int,
 viewport_height: int,

@@ -155,11 +155,12 @@ class ScreenshotWebPageBlock(Block):

 return {
 "image": await store_media_file(
+graph_exec_id=graph_exec_id,
 file=MediaFileType(
 f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
 ),
-execution_context=execution_context,
-return_format="for_block_output",
+user_id=user_id,
+return_content=True,
 )
 }

@@ -168,13 +169,15 @@ class ScreenshotWebPageBlock(Block):
 input_data: Input,
 *,
 credentials: APIKeyCredentials,
-execution_context: ExecutionContext,
+graph_exec_id: str,
+user_id: str,
 **kwargs,
 ) -> BlockOutput:
 try:
 screenshot_data = await self.take_screenshot(
 credentials=credentials,
-execution_context=execution_context,
+graph_exec_id=graph_exec_id,
+user_id=user_id,
 url=input_data.url,
 viewport_width=input_data.viewport_width,
 viewport_height=input_data.viewport_height,
@@ -7,7 +7,6 @@ from backend.data.block import (
 BlockSchemaInput,
 BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import ContributorDetails, SchemaField
 from backend.util.file import get_exec_file_path, store_media_file
 from backend.util.type import MediaFileType

@@ -99,7 +98,7 @@ class ReadSpreadsheetBlock(Block):
 )

 async def run(
-self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
+self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
 ) -> BlockOutput:
 import csv
 from io import StringIO

@@ -107,16 +106,14 @@ class ReadSpreadsheetBlock(Block):
 # Determine data source - prefer file_input if provided, otherwise use contents
 if input_data.file_input:
 stored_file_path = await store_media_file(
+user_id=user_id,
+graph_exec_id=graph_exec_id,
 file=input_data.file_input,
-execution_context=execution_context,
-return_format="for_local_processing",
+return_content=False,
 )

 # Get full file path
-assert execution_context.graph_exec_id # Validated by store_media_file
-file_path = get_exec_file_path(
-execution_context.graph_exec_id, stored_file_path
-)
+file_path = get_exec_file_path(graph_exec_id, stored_file_path)
 if not Path(file_path).exists():
 raise ValueError(f"File does not exist: {file_path}")

@@ -10,7 +10,6 @@ from backend.data.block import (
 BlockSchemaInput,
 BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import (
 APIKeyCredentials,
 CredentialsField,

@@ -18,9 +17,7 @@ from backend.data.model import (
 SchemaField,
 )
 from backend.integrations.providers import ProviderName
-from backend.util.file import store_media_file
 from backend.util.request import Requests
-from backend.util.type import MediaFileType

 TEST_CREDENTIALS = APIKeyCredentials(
 id="01234567-89ab-cdef-0123-456789abcdef",

@@ -105,7 +102,7 @@ class CreateTalkingAvatarVideoBlock(Block):
 test_output=[
 (
 "video_url",
-lambda x: x.startswith(("workspace://", "data:")),
+"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
 ),
 ],
 test_mock={

@@ -113,10 +110,9 @@ class CreateTalkingAvatarVideoBlock(Block):
 "id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
 "status": "created",
 },
-# Use data URI to avoid HTTP requests during tests
 "get_clip_status": lambda *args, **kwargs: {
 "status": "done",
-"result_url": "data:video/mp4;base64,AAAA",
+"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
 },
 },
 test_credentials=TEST_CREDENTIALS,
@@ -142,12 +138,7 @@ class CreateTalkingAvatarVideoBlock(Block):
 return response.json()

 async def run(
-self,
-input_data: Input,
-*,
-credentials: APIKeyCredentials,
-execution_context: ExecutionContext,
-**kwargs,
+self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
 ) -> BlockOutput:
 # Create the clip
 payload = {

@@ -174,14 +165,7 @@ class CreateTalkingAvatarVideoBlock(Block):
 for _ in range(input_data.max_polling_attempts):
 status_response = await self.get_clip_status(credentials.api_key, clip_id)
 if status_response["status"] == "done":
-# Store the generated video to the user's workspace for persistence
-video_url = status_response["result_url"]
-stored_url = await store_media_file(
-file=MediaFileType(video_url),
-execution_context=execution_context,
-return_format="for_block_output",
-)
-yield "video_url", stored_url
+yield "video_url", status_response["result_url"]
 return
 elif status_response["status"] == "error":
 raise RuntimeError(
@@ -12,7 +12,6 @@ from backend.blocks.iteration import StepThroughItemsBlock
 from backend.blocks.llm import AITextSummarizerBlock
 from backend.blocks.text import ExtractTextInformationBlock
 from backend.blocks.xml_parser import XMLParserBlock
-from backend.data.execution import ExecutionContext
 from backend.util.file import store_media_file
 from backend.util.type import MediaFileType

@@ -234,12 +233,9 @@ class TestStoreMediaFileSecurity:

 with pytest.raises(ValueError, match="File too large"):
 await store_media_file(
+graph_exec_id="test",
 file=MediaFileType(large_data_uri),
-execution_context=ExecutionContext(
-user_id="test_user",
-graph_exec_id="test",
-),
-return_format="for_local_processing",
+user_id="test_user",
 )

 @patch("backend.util.file.Path")

@@ -274,12 +270,9 @@ class TestStoreMediaFileSecurity:
 # Should raise an error when directory size exceeds limit
 with pytest.raises(ValueError, match="Disk usage limit exceeded"):
 await store_media_file(
+graph_exec_id="test",
 file=MediaFileType(
 "data:text/plain;base64,dGVzdA=="
 ), # Small test file
-execution_context=ExecutionContext(
-user_id="test_user",
-graph_exec_id="test",
-),
-return_format="for_local_processing",
+user_id="test_user",
 )
@@ -11,22 +11,10 @@ from backend.blocks.http import (
 HttpMethod,
 SendAuthenticatedWebRequestBlock,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import HostScopedCredentials
 from backend.util.request import Response


-def make_test_context(
-graph_exec_id: str = "test-exec-id",
-user_id: str = "test-user-id",
-) -> ExecutionContext:
-"""Helper to create test ExecutionContext."""
-return ExecutionContext(
-user_id=user_id,
-graph_exec_id=graph_exec_id,
-)


 class TestHttpBlockWithHostScopedCredentials:
 """Test suite for HTTP block integration with HostScopedCredentials."""


@@ -117,7 +105,8 @@ class TestHttpBlockWithHostScopedCredentials:
 async for output_name, output_data in http_block.run(
 input_data,
 credentials=exact_match_credentials,
-execution_context=make_test_context(),
+graph_exec_id="test-exec-id",
+user_id="test-user-id",
 ):
 result.append((output_name, output_data))

@@ -172,7 +161,8 @@
 async for output_name, output_data in http_block.run(
 input_data,
 credentials=wildcard_credentials,
-execution_context=make_test_context(),
+graph_exec_id="test-exec-id",
+user_id="test-user-id",
 ):
 result.append((output_name, output_data))

@@ -218,7 +208,8 @@
 async for output_name, output_data in http_block.run(
 input_data,
 credentials=non_matching_credentials,
-execution_context=make_test_context(),
+graph_exec_id="test-exec-id",
+user_id="test-user-id",
 ):
 result.append((output_name, output_data))

@@ -267,7 +258,8 @@
 async for output_name, output_data in http_block.run(
 input_data,
 credentials=exact_match_credentials,
-execution_context=make_test_context(),
+graph_exec_id="test-exec-id",
+user_id="test-user-id",
 ):
 result.append((output_name, output_data))

@@ -326,7 +318,8 @@
 async for output_name, output_data in http_block.run(
 input_data,
 credentials=auto_discovered_creds, # Execution manager found these
-execution_context=make_test_context(),
+graph_exec_id="test-exec-id",
+user_id="test-user-id",
 ):
 result.append((output_name, output_data))

@@ -389,7 +382,8 @@
 async for output_name, output_data in http_block.run(
 input_data,
 credentials=multi_header_creds,
-execution_context=make_test_context(),
+graph_exec_id="test-exec-id",
+user_id="test-user-id",
 ):
 result.append((output_name, output_data))

@@ -477,7 +471,8 @@
 async for output_name, output_data in http_block.run(
 input_data,
 credentials=test_creds,
-execution_context=make_test_context(),
+graph_exec_id="test-exec-id",
+user_id="test-user-id",
 ):
 result.append((output_name, output_data))

@@ -11,7 +11,6 @@ from backend.data.block import (
 BlockSchemaInput,
 BlockSchemaOutput,
 )
-from backend.data.execution import ExecutionContext
 from backend.data.model import SchemaField
 from backend.util import json, text
 from backend.util.file import get_exec_file_path, store_media_file

@@ -445,21 +444,18 @@ class FileReadBlock(Block):
 )

 async def run(
-self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
+self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
 ) -> BlockOutput:
 # Store the media file properly (handles URLs, data URIs, etc.)
 stored_file_path = await store_media_file(
+user_id=user_id,
+graph_exec_id=graph_exec_id,
 file=input_data.file_input,
-execution_context=execution_context,
-return_format="for_local_processing",
+return_content=False,
 )

-# Get full file path (graph_exec_id validated by store_media_file above)
-if not execution_context.graph_exec_id:
-raise ValueError("execution_context.graph_exec_id is required")
-file_path = get_exec_file_path(
-execution_context.graph_exec_id, stored_file_path
-)
+# Get full file path
+file_path = get_exec_file_path(graph_exec_id, stored_file_path)

 if not Path(file_path).exists():
 raise ValueError(f"File does not exist: {file_path}")
@@ -83,29 +83,12 @@ class ExecutionContext(BaseModel):

 model_config = {"extra": "ignore"}

-# Execution identity
-user_id: Optional[str] = None
-graph_id: Optional[str] = None
-graph_exec_id: Optional[str] = None
-graph_version: Optional[int] = None
-node_id: Optional[str] = None
-node_exec_id: Optional[str] = None

-# Safety settings
 human_in_the_loop_safe_mode: bool = True
 sensitive_action_safe_mode: bool = False

-# User settings
 user_timezone: str = "UTC"

-# Execution hierarchy
 root_execution_id: Optional[str] = None
 parent_execution_id: Optional[str] = None

-# Workspace
-workspace_id: Optional[str] = None
-session_id: Optional[str] = None


 # -------------------------- Models -------------------------- #

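On the master side of this hunk, ExecutionContext keeps only safety flags, the user timezone, and the execution-hierarchy IDs; the identity fields move back into explicit run() kwargs. A sketch of the resulting model, inferred from the surviving lines and their defaults:

```python
from typing import Optional

from pydantic import BaseModel

class ExecutionContext(BaseModel):
    model_config = {"extra": "ignore"}

    # Safety settings
    human_in_the_loop_safe_mode: bool = True
    sensitive_action_safe_mode: bool = False

    # User settings
    user_timezone: str = "UTC"

    # Execution hierarchy
    root_execution_id: Optional[str] = None
    parent_execution_id: Optional[str] = None
```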
@@ -1,276 +0,0 @@
-"""
-Database CRUD operations for User Workspace.
-
-This module provides functions for managing user workspaces and workspace files.
-"""
-
-import logging
-from datetime import datetime, timezone
-from typing import Optional
-
-from prisma.models import UserWorkspace, UserWorkspaceFile
-from prisma.types import UserWorkspaceFileWhereInput
-
-from backend.util.json import SafeJson
-
-logger = logging.getLogger(__name__)
-
-
-async def get_or_create_workspace(user_id: str) -> UserWorkspace:
-"""
-Get user's workspace, creating one if it doesn't exist.
-
-Uses upsert to handle race conditions when multiple concurrent requests
-attempt to create a workspace for the same user.
-
-Args:
-user_id: The user's ID
-
-Returns:
-UserWorkspace instance
-"""
-workspace = await UserWorkspace.prisma().upsert(
-where={"userId": user_id},
-data={
-"create": {"userId": user_id},
-"update": {}, # No updates needed if exists
-},
-)
-
-return workspace
-
-
-async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
-"""
-Get user's workspace if it exists.
-
-Args:
-user_id: The user's ID
-
-Returns:
-UserWorkspace instance or None
-"""
-return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
-
-
-async def create_workspace_file(
-workspace_id: str,
-file_id: str,
-name: str,
-path: str,
-storage_path: str,
-mime_type: str,
-size_bytes: int,
-checksum: Optional[str] = None,
-metadata: Optional[dict] = None,
-) -> UserWorkspaceFile:
-"""
-Create a new workspace file record.
-
-Args:
-workspace_id: The workspace ID
-file_id: The file ID (same as used in storage path for consistency)
-name: User-visible filename
-path: Virtual path (e.g., "/documents/report.pdf")
-storage_path: Actual storage path (GCS or local)
-mime_type: MIME type of the file
-size_bytes: File size in bytes
-checksum: Optional SHA256 checksum
-metadata: Optional additional metadata
-
-Returns:
-Created UserWorkspaceFile instance
-"""
-# Normalize path to start with /
-if not path.startswith("/"):
-path = f"/{path}"
-
-file = await UserWorkspaceFile.prisma().create(
-data={
-"id": file_id,
-"workspaceId": workspace_id,
-"name": name,
-"path": path,
-"storagePath": storage_path,
-"mimeType": mime_type,
-"sizeBytes": size_bytes,
-"checksum": checksum,
-"metadata": SafeJson(metadata or {}),
-}
-)
-
-logger.info(
-f"Created workspace file {file.id} at path {path} "
-f"in workspace {workspace_id}"
-)
-return file
-
-
-async def get_workspace_file(
-file_id: str,
-workspace_id: Optional[str] = None,
-) -> Optional[UserWorkspaceFile]:
-"""
-Get a workspace file by ID.
-
-Args:
-file_id: The file ID
-workspace_id: Optional workspace ID for validation
-
-Returns:
-UserWorkspaceFile instance or None
-"""
-where_clause: dict = {"id": file_id, "isDeleted": False}
-if workspace_id:
-where_clause["workspaceId"] = workspace_id
-
-return await UserWorkspaceFile.prisma().find_first(where=where_clause)
-
-
-async def get_workspace_file_by_path(
-workspace_id: str,
-path: str,
-) -> Optional[UserWorkspaceFile]:
-"""
-Get a workspace file by its virtual path.
-
-Args:
-workspace_id: The workspace ID
-path: Virtual path
-
-Returns:
-UserWorkspaceFile instance or None
-"""
-# Normalize path
-if not path.startswith("/"):
-path = f"/{path}"
-
-return await UserWorkspaceFile.prisma().find_first(
-where={
-"workspaceId": workspace_id,
-"path": path,
-"isDeleted": False,
-}
-)
-
-
-async def list_workspace_files(
-workspace_id: str,
-path_prefix: Optional[str] = None,
-include_deleted: bool = False,
-limit: Optional[int] = None,
-offset: int = 0,
-) -> list[UserWorkspaceFile]:
-"""
-List files in a workspace.
-
-Args:
-workspace_id: The workspace ID
-path_prefix: Optional path prefix to filter (e.g., "/documents/")
-include_deleted: Whether to include soft-deleted files
-limit: Maximum number of files to return
-offset: Number of files to skip
-
-Returns:
-List of UserWorkspaceFile instances
-"""
-where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
-
-if not include_deleted:
-where_clause["isDeleted"] = False
-
-if path_prefix:
-# Normalize prefix
-if not path_prefix.startswith("/"):
-path_prefix = f"/{path_prefix}"
-where_clause["path"] = {"startswith": path_prefix}
-
-return await UserWorkspaceFile.prisma().find_many(
-where=where_clause,
-order={"createdAt": "desc"},
-take=limit,
-skip=offset,
-)
-
-
-async def count_workspace_files(
-workspace_id: str,
-path_prefix: Optional[str] = None,
-include_deleted: bool = False,
-) -> int:
-"""
-Count files in a workspace.
-
-Args:
-workspace_id: The workspace ID
-path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
-include_deleted: Whether to include soft-deleted files
-
-Returns:
-Number of files
-"""
-where_clause: dict = {"workspaceId": workspace_id}
-if not include_deleted:
-where_clause["isDeleted"] = False
-
-if path_prefix:
-# Normalize prefix
-if not path_prefix.startswith("/"):
-path_prefix = f"/{path_prefix}"
-where_clause["path"] = {"startswith": path_prefix}
-
-return await UserWorkspaceFile.prisma().count(where=where_clause)
-
-
-async def soft_delete_workspace_file(
-file_id: str,
-workspace_id: Optional[str] = None,
-) -> Optional[UserWorkspaceFile]:
-"""
-Soft-delete a workspace file.
-
-The path is modified to include a deletion timestamp to free up the original
-path for new files while preserving the record for potential recovery.
-
-Args:
-file_id: The file ID
-workspace_id: Optional workspace ID for validation
-
-Returns:
-Updated UserWorkspaceFile instance or None if not found
-"""
-# First verify the file exists and belongs to workspace
-file = await get_workspace_file(file_id, workspace_id)
-if file is None:
-return None
-
-deleted_at = datetime.now(timezone.utc)
-# Modify path to free up the unique constraint for new files at original path
-# Format: {original_path}__deleted__{timestamp}
-deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
-
-updated = await UserWorkspaceFile.prisma().update(
-where={"id": file_id},
-data={
-"isDeleted": True,
-"deletedAt": deleted_at,
-"path": deleted_path,
-},
-)
-
-logger.info(f"Soft-deleted workspace file {file_id}")
-return updated
-
-
-async def get_workspace_total_size(workspace_id: str) -> int:
-"""
-Get the total size of all files in a workspace.
-
-Args:
-workspace_id: The workspace ID
-
-Returns:
-Total size in bytes
-"""
-files = await list_workspace_files(workspace_id)
-return sum(file.sizeBytes for file in files)
@@ -236,14 +236,7 @@ async def execute_node(
 input_size = len(input_data_str)
 log_metadata.debug("Executed node with input", input=input_data_str)

-# Create node-specific execution context to avoid race conditions
-# (multiple nodes can execute concurrently and would otherwise mutate shared state)
-execution_context = execution_context.model_copy(
-update={"node_id": node_id, "node_exec_id": node_exec_id}
-)

 # Inject extra execution arguments for the blocks via kwargs
-# Keep individual kwargs for backwards compatibility with existing blocks
 extra_exec_kwargs: dict = {
 "graph_id": graph_id,
 "graph_version": graph_version,

@@ -892,19 +892,11 @@ async def add_graph_execution(
 settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)

 execution_context = ExecutionContext(
-# Execution identity
-user_id=user_id,
-graph_id=graph_id,
-graph_exec_id=graph_exec.id,
-graph_version=graph_exec.graph_version,
-# Safety settings
 human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
 sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
-# User settings
 user_timezone=(
 user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
 ),
-# Execution hierarchy
 root_execution_id=graph_exec.id,
 )

@@ -348,7 +348,6 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
 mock_graph_exec.id = "execution-id-123"
 mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
 mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
-mock_graph_exec.graph_version = graph_version
 mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()

 # Mock the queue and event bus

@@ -435,9 +434,6 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
 # Create a second mock execution for the sanity check
 mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
 mock_graph_exec_2.id = "execution-id-456"
-mock_graph_exec_2.node_executions = []
-mock_graph_exec_2.status = ExecutionStatus.QUEUED
-mock_graph_exec_2.graph_version = graph_version
 mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()

 # Reset mocks and set up for second call

@@ -618,7 +614,6 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
 mock_graph_exec.id = "execution-id-123"
 mock_graph_exec.node_executions = []
 mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
-mock_graph_exec.graph_version = graph_version

 # Track what's passed to to_graph_execution_entry
 captured_kwargs = {}
@@ -13,7 +13,6 @@ import aiohttp
 from gcloud.aio import storage as async_gcs_storage
 from google.cloud import storage as gcs_storage

-from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url
 from backend.util.settings import Config

 logger = logging.getLogger(__name__)

@@ -252,7 +251,7 @@ class CloudStorageHandler:
 f"in_task: {current_task is not None}"
 )

-# Parse bucket and blob name from path (path already has gcs:// prefix removed)
+# Parse bucket and blob name from path
 parts = path.split("/", 1)
 if len(parts) != 2:
 raise ValueError(f"Invalid GCS path: {path}")

@@ -262,19 +261,50 @@ class CloudStorageHandler:
 # Authorization check
 self._validate_file_access(blob_name, user_id, graph_exec_id)

-logger.info(
-f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
+# Use a fresh client for each download to avoid session issues
+# This is less efficient but more reliable with the executor's event loop
+logger.info("[CloudStorage] Creating fresh GCS client for download")

+# Create a new session specifically for this download
+session = aiohttp.ClientSession(
+connector=aiohttp.TCPConnector(limit=10, force_close=True)
 )

+async_client = None
 try:
-content = await download_with_fresh_session(bucket_name, blob_name)
+# Create a new GCS client with the fresh session
+async_client = async_gcs_storage.Storage(session=session)

+logger.info(
+f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
+)

+# Download content using the fresh client
+content = await async_client.download(bucket_name, blob_name)
 logger.info(
 f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
 )

+# Clean up
+await async_client.close()
+await session.close()

 return content
-except FileNotFoundError:
-raise
 except Exception as e:
+# Always try to clean up
+if async_client is not None:
+try:
+await async_client.close()
+except Exception as cleanup_error:
+logger.warning(
+f"[CloudStorage] Error closing GCS client: {cleanup_error}"
+)
+try:
+await session.close()
+except Exception as cleanup_error:
+logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}")

 # Log the specific error for debugging
 logger.error(
 f"[CloudStorage] GCS download failed - error: {str(e)}, "
@@ -289,6 +319,10 @@
 f"current_task: {current_task}, "
 f"bucket: {bucket_name}, blob: redacted for privacy"
 )

+# Convert gcloud-aio exceptions to standard ones
+if "404" in str(e) or "Not Found" in str(e):
+raise FileNotFoundError(f"File not found: gcs://{path}")
 raise

 def _validate_file_access(

@@ -411,7 +445,8 @@
 graph_exec_id: str | None = None,
 ) -> str:
 """Generate signed URL for GCS with authorization."""
-# Parse bucket and blob name from path (path already has gcs:// prefix removed)
+# Parse bucket and blob name from path
 parts = path.split("/", 1)
 if len(parts) != 2:
 raise ValueError(f"Invalid GCS path: {path}")

@@ -421,11 +456,21 @@
 # Authorization check
 self._validate_file_access(blob_name, user_id, graph_exec_id)

+# Use sync client for signed URLs since gcloud-aio doesn't support them
 sync_client = self._get_sync_gcs_client()
-return await generate_signed_url(
-sync_client, bucket_name, blob_name, expiration_hours * 3600
+bucket = sync_client.bucket(bucket_name)
+blob = bucket.blob(blob_name)

+# Generate signed URL asynchronously using sync client
+url = await asyncio.to_thread(
+blob.generate_signed_url,
+version="v4",
+expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours),
+method="GET",
 )

+return url

 async def delete_expired_files(self, provider: str = "gcs") -> int:
 """
 Delete files that have passed their expiration time.
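The download path restored above creates a fresh aiohttp session and gcloud-aio Storage client per call, translates 404-style errors, and always closes both. A condensed sketch of that pattern as a standalone routine, assuming the same client API used in the hunk:

```python
import aiohttp
from gcloud.aio import storage as async_gcs_storage

async def download_blob(bucket_name: str, blob_name: str) -> bytes:
    # One short-lived session per download avoids reusing a session that was
    # created on a different event loop inside the executor
    session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(limit=10, force_close=True)
    )
    client = async_gcs_storage.Storage(session=session)
    try:
        return await client.download(bucket_name, blob_name)
    except Exception as e:
        # gcloud-aio surfaces HTTP errors as generic exceptions; translate 404s
        if "404" in str(e) or "Not Found" in str(e):
            raise FileNotFoundError(f"File not found: gcs://{bucket_name}/{blob_name}")
        raise
    finally:
        await client.close()
        await session.close()
```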
@@ -5,26 +5,13 @@ import shutil
 import tempfile
 import uuid
 from pathlib import Path
-from typing import TYPE_CHECKING, Literal
 from urllib.parse import urlparse

 from backend.util.cloud_storage import get_cloud_storage_handler
 from backend.util.request import Requests
-from backend.util.settings import Config
 from backend.util.type import MediaFileType
 from backend.util.virus_scanner import scan_content_safe

-if TYPE_CHECKING:
-from backend.data.execution import ExecutionContext

-# Return format options for store_media_file
-# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
-# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
-# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
-MediaReturnFormat = Literal[
-"for_local_processing", "for_external_api", "for_block_output"
-]

 TEMP_DIR = Path(tempfile.gettempdir()).resolve()

 # Maximum filename length (conservative limit for most filesystems)
@@ -80,56 +67,42 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
 
 
 async def store_media_file(
+    graph_exec_id: str,
     file: MediaFileType,
-    execution_context: "ExecutionContext",
-    *,
-    return_format: MediaReturnFormat,
+    user_id: str,
+    return_content: bool = False,
 ) -> MediaFileType:
     """
-    Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path
-    relative to {temp}/exec_file/{exec_id}), placing or verifying it under:
-      {tempdir}/exec_file/{exec_id}/...
-
-    For each MediaFileType input:
-    - Data URI: decode and store locally
-    - URL: download and store locally
-    - workspace:// reference: read from workspace, store locally
-    - Local path: verify it exists in exec_file directory
-
-    Return format options:
-    - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
-    - "for_external_api": Returns data URI (base64) - use when sending to external APIs
-    - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
-
-    :param file: Data URI, URL, workspace://, or local (relative) path.
-    :param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id.
-    :param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output".
-    :return: The requested result based on return_format.
+    Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
+    placing or verifying it under:
+      {tempdir}/exec_file/{exec_id}/...
+
+    If 'return_content=True', return a data URI (data:<mime>;base64,<content>).
+    Otherwise, returns the file media path relative to the exec_id folder.
+
+    For each MediaFileType type:
+      - Data URI:
+        -> decode and store in a new random file in that folder
+      - URL:
+        -> download and store in that folder
+      - Local path:
+        -> interpret as relative to that folder; verify it exists
+           (no copying, as it's presumably already there).
+    We realpath-check so no symlink or '..' can escape the folder.
+
+    :param graph_exec_id: The unique ID of the graph execution.
+    :param file: Data URI, URL, or local (relative) path.
+    :param return_content: If True, return a data URI of the file content.
+                           If False, return the *relative* path inside the exec_id folder.
+    :return: The requested result: data URI or relative path of the media.
     """
-    # Extract values from execution_context
-    graph_exec_id = execution_context.graph_exec_id
-    user_id = execution_context.user_id
-
-    if not graph_exec_id:
-        raise ValueError("execution_context.graph_exec_id is required")
-    if not user_id:
-        raise ValueError("execution_context.user_id is required")
-
-    # Create workspace_manager if we have workspace_id (with session scoping)
-    # Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py)
-    from backend.util.workspace import WorkspaceManager
-
-    workspace_manager: WorkspaceManager | None = None
-    if execution_context.workspace_id:
-        workspace_manager = WorkspaceManager(
-            user_id, execution_context.workspace_id, execution_context.session_id
-        )
     # Build base path
     base_path = Path(get_exec_file_path(graph_exec_id, ""))
     base_path.mkdir(parents=True, exist_ok=True)
 
     # Security fix: Add disk space limits to prevent DoS
-    MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024
+    MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB per file
     MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024  # 1GB total per execution directory
 
     # Check total disk usage in base_path
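For context, the removed side of this hunk takes an `ExecutionContext` and a `return_format` selector instead of the positional IDs. A minimal sketch of how a block might call that variant (illustrative only; the parameter names come from the hunk above and are not guaranteed elsewhere):

```python
# Illustrative sketch - not part of the diff; assumes the removed
# `return_format` signature of store_media_file shown above.
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


async def handle_image(execution_context, image_url: str):
    # Local relative path, for processing with ffmpeg/PIL-style tools
    local_rel_path = await store_media_file(
        file=MediaFileType(image_url),
        execution_context=execution_context,
        return_format="for_local_processing",
    )
    # Base64 data URI, for sending the content to an external API
    data_uri = await store_media_file(
        file=MediaFileType(image_url),
        execution_context=execution_context,
        return_format="for_external_api",
    )
    return local_rel_path, data_uri
```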
@@ -169,57 +142,9 @@ async def store_media_file(
         """
         return str(absolute_path.relative_to(base))
 
-    # Get cloud storage handler for checking cloud paths
-    cloud_storage = await get_cloud_storage_handler()
-
-    # Track if the input came from workspace (don't re-save it)
-    is_from_workspace = file.startswith("workspace://")
-
-    # Check if this is a workspace file reference
-    if is_from_workspace:
-        if workspace_manager is None:
-            raise ValueError(
-                "Workspace file reference requires workspace context. "
-                "This file type is only available in CoPilot sessions."
-            )
-
-        # Parse workspace reference
-        # workspace://abc123 - by file ID
-        # workspace:///path/to/file.txt - by virtual path
-        file_ref = file[12:]  # Remove "workspace://"
-
-        if file_ref.startswith("/"):
-            # Path reference
-            workspace_content = await workspace_manager.read_file(file_ref)
-            file_info = await workspace_manager.get_file_info_by_path(file_ref)
-            filename = sanitize_filename(
-                file_info.name if file_info else f"{uuid.uuid4()}.bin"
-            )
-        else:
-            # ID reference
-            workspace_content = await workspace_manager.read_file_by_id(file_ref)
-            file_info = await workspace_manager.get_file_info(file_ref)
-            filename = sanitize_filename(
-                file_info.name if file_info else f"{uuid.uuid4()}.bin"
-            )
-
-        try:
-            target_path = _ensure_inside_base(base_path / filename, base_path)
-        except OSError as e:
-            raise ValueError(f"Invalid file path '{filename}': {e}") from e
-
-        # Check file size limit
-        if len(workspace_content) > MAX_FILE_SIZE_BYTES:
-            raise ValueError(
-                f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
-            )
-
-        # Virus scan the workspace content before writing locally
-        await scan_content_safe(workspace_content, filename=filename)
-        target_path.write_bytes(workspace_content)
-
     # Check if this is a cloud storage path
-    elif cloud_storage.is_cloud_path(file):
+    cloud_storage = await get_cloud_storage_handler()
+    if cloud_storage.is_cloud_path(file):
         # Download from cloud storage and store locally
         cloud_content = await cloud_storage.retrieve_file(
             file, user_id=user_id, graph_exec_id=graph_exec_id
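The removed branch distinguishes `workspace://<file-id>` from `workspace:///<virtual-path>` by whether the remainder after the scheme starts with a slash. A standalone sketch of that parsing rule (illustrative only; the helper name is hypothetical, the rule follows the removed code above):

```python
# Hypothetical helper mirroring the removed parsing logic above.
def parse_workspace_ref(ref: str) -> tuple[str, str]:
    """Return ("id", "abc123") or ("path", "/docs/report.pdf")."""
    if not ref.startswith("workspace://"):
        raise ValueError(f"Not a workspace reference: {ref}")
    remainder = ref[len("workspace://"):]  # same as file[12:] in the diff
    if remainder.startswith("/"):
        return ("path", remainder)   # workspace:///path/to/file.txt
    return ("id", remainder)         # workspace://abc123


assert parse_workspace_ref("workspace://abc123") == ("id", "abc123")
assert parse_workspace_ref("workspace:///a/b.txt") == ("path", "/a/b.txt")
```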
@@ -234,9 +159,9 @@ async def store_media_file(
             raise ValueError(f"Invalid file path '{filename}': {e}") from e
 
         # Check file size limit
-        if len(cloud_content) > MAX_FILE_SIZE_BYTES:
+        if len(cloud_content) > MAX_FILE_SIZE:
             raise ValueError(
-                f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
+                f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
             )
 
         # Virus scan the cloud content before writing locally
@@ -264,9 +189,9 @@ async def store_media_file(
         content = base64.b64decode(b64_content)
 
         # Check file size limit
-        if len(content) > MAX_FILE_SIZE_BYTES:
+        if len(content) > MAX_FILE_SIZE:
             raise ValueError(
-                f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
+                f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
             )
 
         # Virus scan the base64 content before writing
@@ -274,31 +199,23 @@ async def store_media_file(
         target_path.write_bytes(content)
 
     elif file.startswith(("http://", "https://")):
-        # URL - download first to get Content-Type header
-        resp = await Requests().get(file)
-
-        # Check file size limit
-        if len(resp.content) > MAX_FILE_SIZE_BYTES:
-            raise ValueError(
-                f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
-            )
-
-        # Extract filename from URL path
+        # URL
         parsed_url = urlparse(file)
         filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
 
-        # If filename lacks extension, add one from Content-Type header
-        if "." not in filename:
-            content_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
-            if content_type:
-                ext = _extension_from_mime(content_type)
-                filename = f"{filename}{ext}"
-
         try:
             target_path = _ensure_inside_base(base_path / filename, base_path)
         except OSError as e:
             raise ValueError(f"Invalid file path '{filename}': {e}") from e
 
+        # Download and save
+        resp = await Requests().get(file)
+
+        # Check file size limit
+        if len(resp.content) > MAX_FILE_SIZE:
+            raise ValueError(
+                f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
+            )
+
         # Virus scan the downloaded content before writing
         await scan_content_safe(resp.content, filename=filename)
         target_path.write_bytes(resp.content)
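The removed side downloads first so it can derive a file extension from the `Content-Type` header when the URL path has none. A rough standard-library equivalent (illustrative; `_extension_from_mime` in the diff may behave differently from `mimetypes.guess_extension`):

```python
import mimetypes
import uuid
from pathlib import Path
from urllib.parse import urlparse


def filename_for_download(url: str, content_type: str) -> str:
    # Take the name from the URL path, or fall back to a random one
    name = Path(urlparse(url).path).name or str(uuid.uuid4())
    if "." not in name:
        mime = content_type.split(";")[0].strip()
        ext = mimetypes.guess_extension(mime) or ""
        name = f"{name}{ext}"
    return name


# e.g. filename_for_download("https://example.com/avatar", "image/png; charset=binary")
# -> "avatar.png"
```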
@@ -313,43 +230,11 @@ async def store_media_file(
     if not target_path.is_file():
         raise ValueError(f"Local file does not exist: {target_path}")
 
-    # Return based on requested format
-    if return_format == "for_local_processing":
-        # Use when processing files locally with tools like ffmpeg, MoviePy, PIL
-        # Returns: relative path in exec_file directory (e.g., "image.png")
-        return MediaFileType(_strip_base_prefix(target_path, base_path))
-
-    elif return_format == "for_external_api":
-        # Use when sending content to external APIs that need base64
-        # Returns: data URI (e.g., "data:image/png;base64,iVBORw0...")
+    # Return result
+    if return_content:
         return MediaFileType(_file_to_data_uri(target_path))
-
-    elif return_format == "for_block_output":
-        # Use when returning output from a block to user/next block
-        # Returns: workspace:// ref (CoPilot) or data URI (graph execution)
-        if workspace_manager is None:
-            # No workspace available (graph execution without CoPilot)
-            # Fallback to data URI so the content can still be used/displayed
-            return MediaFileType(_file_to_data_uri(target_path))
-
-        # Don't re-save if input was already from workspace
-        if is_from_workspace:
-            # Return original workspace reference
-            return MediaFileType(file)
-
-        # Save new content to workspace
-        content = target_path.read_bytes()
-        filename = target_path.name
-
-        file_record = await workspace_manager.write_file(
-            content=content,
-            filename=filename,
-            overwrite=True,
-        )
-        return MediaFileType(f"workspace://{file_record.id}")
-
     else:
-        raise ValueError(f"Invalid return_format: {return_format}")
+        return MediaFileType(_strip_base_prefix(target_path, base_path))
 
 
 def get_dir_size(path: Path) -> int:
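The removed `for_block_output` branch prefers a `workspace://` reference when a workspace is available and falls back to a data URI otherwise. A condensed sketch of that decision (illustrative only; `workspace_manager`, `_file_to_data_uri`, and `is_from_workspace` are the names used in the hunk above):

```python
async def to_block_output(target_path, file, is_from_workspace, workspace_manager):
    # No workspace (plain graph execution): return an inline data URI
    if workspace_manager is None:
        return _file_to_data_uri(target_path)
    # Input already lived in the workspace: hand back the original reference
    if is_from_workspace:
        return file
    # Otherwise persist the new content and return its workspace reference
    record = await workspace_manager.write_file(
        content=target_path.read_bytes(),
        filename=target_path.name,
        overwrite=True,
    )
    return f"workspace://{record.id}"
```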
@@ -7,22 +7,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
-from backend.data.execution import ExecutionContext
 from backend.util.file import store_media_file
 from backend.util.type import MediaFileType
 
 
-def make_test_context(
-    graph_exec_id: str = "test-exec-123",
-    user_id: str = "test-user-123",
-) -> ExecutionContext:
-    """Helper to create test ExecutionContext."""
-    return ExecutionContext(
-        user_id=user_id,
-        graph_exec_id=graph_exec_id,
-    )
-
-
 class TestFileCloudIntegration:
     """Test cases for cloud storage integration in file utilities."""
 
@@ -82,9 +70,10 @@ class TestFileCloudIntegration:
             mock_path_class.side_effect = path_constructor
 
             result = await store_media_file(
-                file=MediaFileType(cloud_path),
-                execution_context=make_test_context(graph_exec_id=graph_exec_id),
-                return_format="for_local_processing",
+                graph_exec_id,
+                MediaFileType(cloud_path),
+                "test-user-123",
+                return_content=False,
             )
 
         # Verify cloud storage operations
|||||||
mock_path_obj.name = "image.png"
|
mock_path_obj.name = "image.png"
|
||||||
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
||||||
result = await store_media_file(
|
result = await store_media_file(
|
||||||
file=MediaFileType(cloud_path),
|
graph_exec_id,
|
||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
MediaFileType(cloud_path),
|
||||||
return_format="for_external_api",
|
"test-user-123",
|
||||||
|
return_content=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify result is a data URI
|
# Verify result is a data URI
|
||||||
@@ -208,9 +198,10 @@ class TestFileCloudIntegration:
|
|||||||
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
||||||
|
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
file=MediaFileType(data_uri),
|
graph_exec_id,
|
||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
MediaFileType(data_uri),
|
||||||
return_format="for_local_processing",
|
"test-user-123",
|
||||||
|
return_content=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify cloud handler was checked but not used for retrieval
|
# Verify cloud handler was checked but not used for retrieval
|
||||||
@@ -243,7 +234,5 @@ class TestFileCloudIntegration:
|
|||||||
FileNotFoundError, match="File not found in cloud storage"
|
FileNotFoundError, match="File not found in cloud storage"
|
||||||
):
|
):
|
||||||
await store_media_file(
|
await store_media_file(
|
||||||
file=MediaFileType(cloud_path),
|
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
|
||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
)
|
||||||
|
|||||||
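The test hunks above show the kept, positional calling convention. A minimal pytest-style sketch of the simplest happy path (illustrative only; the data URI and IDs are made-up fixtures, and a real run would still exercise the virus-scanning and path checks in `store_media_file`):

```python
# Illustrative sketch, not part of the diff.
import pytest

from backend.util.file import store_media_file
from backend.util.type import MediaFileType


@pytest.mark.asyncio
async def test_data_uri_round_trip():
    data_uri = "data:text/plain;base64,aGVsbG8="  # decodes to "hello"
    rel_path = await store_media_file(
        "exec-123", MediaFileType(data_uri), "user-123", return_content=False
    )
    assert isinstance(rel_path, str)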
@@ -1,108 +0,0 @@
|
|||||||
"""
|
|
||||||
Shared GCS utilities for workspace and cloud storage backends.
|
|
||||||
|
|
||||||
This module provides common functionality for working with Google Cloud Storage,
|
|
||||||
including path parsing, client management, and signed URL generation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from gcloud.aio import storage as async_gcs_storage
|
|
||||||
from google.cloud import storage as gcs_storage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_gcs_path(path: str) -> tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: GCS path string (e.g., "gcs://my-bucket/path/to/file")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (bucket_name, blob_name)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the path format is invalid
|
|
||||||
"""
|
|
||||||
if not path.startswith("gcs://"):
|
|
||||||
raise ValueError(f"Invalid GCS path: {path}")
|
|
||||||
|
|
||||||
path_without_prefix = path[6:] # Remove "gcs://"
|
|
||||||
parts = path_without_prefix.split("/", 1)
|
|
||||||
if len(parts) != 2:
|
|
||||||
raise ValueError(f"Invalid GCS path format: {path}")
|
|
||||||
|
|
||||||
return parts[0], parts[1]
|
|
||||||
|
|
||||||
|
|
||||||
async def download_with_fresh_session(bucket: str, blob: str) -> bytes:
|
|
||||||
"""
|
|
||||||
Download file content using a fresh session.
|
|
||||||
|
|
||||||
This approach avoids event loop issues that can occur when reusing
|
|
||||||
sessions across different async contexts (e.g., in executors).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bucket: GCS bucket name
|
|
||||||
blob: Blob path within the bucket
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
File content as bytes
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: If the file doesn't exist
|
|
||||||
"""
|
|
||||||
session = aiohttp.ClientSession(
|
|
||||||
connector=aiohttp.TCPConnector(limit=10, force_close=True)
|
|
||||||
)
|
|
||||||
client: async_gcs_storage.Storage | None = None
|
|
||||||
try:
|
|
||||||
client = async_gcs_storage.Storage(session=session)
|
|
||||||
content = await client.download(bucket, blob)
|
|
||||||
return content
|
|
||||||
except Exception as e:
|
|
||||||
if "404" in str(e) or "Not Found" in str(e):
|
|
||||||
raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
if client:
|
|
||||||
try:
|
|
||||||
await client.close()
|
|
||||||
except Exception:
|
|
||||||
pass # Best-effort cleanup
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_signed_url(
|
|
||||||
sync_client: gcs_storage.Client,
|
|
||||||
bucket_name: str,
|
|
||||||
blob_name: str,
|
|
||||||
expires_in: int,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Generate a signed URL for temporary access to a GCS file.
|
|
||||||
|
|
||||||
Uses asyncio.to_thread() to run the sync operation without blocking.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sync_client: Sync GCS client with service account credentials
|
|
||||||
bucket_name: GCS bucket name
|
|
||||||
blob_name: Blob path within the bucket
|
|
||||||
expires_in: URL expiration time in seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Signed URL string
|
|
||||||
"""
|
|
||||||
bucket = sync_client.bucket(bucket_name)
|
|
||||||
blob = bucket.blob(blob_name)
|
|
||||||
return await asyncio.to_thread(
|
|
||||||
blob.generate_signed_url,
|
|
||||||
version="v4",
|
|
||||||
expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in),
|
|
||||||
method="GET",
|
|
||||||
)
|
|
||||||
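The deleted `gcs_utils` module above exposes three helpers: path parsing, a fresh-session download, and signed-URL generation. A short usage sketch, assuming the signatures shown in the deleted file (this module exists only on the removed side of the compare):

```python
from google.cloud import storage as gcs_storage

from backend.util.gcs_utils import (
    download_with_fresh_session,
    generate_signed_url,
    parse_gcs_path,
)


async def fetch_and_share(path: str) -> tuple[bytes, str]:
    bucket, blob = parse_gcs_path(path)  # "gcs://bkt/a/b" -> ("bkt", "a/b")
    content = await download_with_fresh_session(bucket, blob)
    url = await generate_signed_url(
        gcs_storage.Client(), bucket, blob, expires_in=3600
    )
    return content, url
```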
@@ -263,12 +263,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         description="The name of the Google Cloud Storage bucket for media files",
     )
 
-    workspace_storage_dir: str = Field(
-        default="",
-        description="Local directory for workspace file storage when GCS is not configured. "
-        "If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.",
-    )
-
     reddit_user_agent: str = Field(
         default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
         description="The user agent for the Reddit API",
@@ -395,13 +389,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         description="Maximum file size in MB for file uploads (1-1024 MB)",
     )
 
-    max_file_size_mb: int = Field(
-        default=100,
-        ge=1,
-        le=1024,
-        description="Maximum file size in MB for workspace files (1-1024 MB)",
-    )
-
     # AutoMod configuration
     automod_enabled: bool = Field(
         default=False,
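The two removed fields are plain pydantic settings; on the side of the diff where they exist, reading them is just attribute access (illustrative sketch):

```python
from backend.util.settings import Config

config = Config()
max_bytes = config.max_file_size_mb * 1024 * 1024   # default 100 MB -> 104_857_600 bytes
storage_dir = config.workspace_storage_dir or None  # empty string means "use the default"
```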
@@ -140,29 +140,14 @@ async def execute_block_test(block: Block):
             setattr(block, mock_name, mock_obj)
 
     # Populate credentials argument(s)
-    # Generate IDs for execution context
-    graph_id = str(uuid.uuid4())
-    node_id = str(uuid.uuid4())
-    graph_exec_id = str(uuid.uuid4())
-    node_exec_id = str(uuid.uuid4())
-    user_id = str(uuid.uuid4())
-    graph_version = 1  # Default version for tests
-
     extra_exec_kwargs: dict = {
-        "graph_id": graph_id,
-        "node_id": node_id,
-        "graph_exec_id": graph_exec_id,
-        "node_exec_id": node_exec_id,
-        "user_id": user_id,
-        "graph_version": graph_version,
-        "execution_context": ExecutionContext(
-            user_id=user_id,
-            graph_id=graph_id,
-            graph_exec_id=graph_exec_id,
-            graph_version=graph_version,
-            node_id=node_id,
-            node_exec_id=node_exec_id,
-        ),
+        "graph_id": str(uuid.uuid4()),
+        "node_id": str(uuid.uuid4()),
+        "graph_exec_id": str(uuid.uuid4()),
+        "node_exec_id": str(uuid.uuid4()),
+        "user_id": str(uuid.uuid4()),
+        "graph_version": 1,  # Default version for tests
+        "execution_context": ExecutionContext(),
     }
     input_model = cast(type[BlockSchema], block.input_schema)
 
@@ -1,419 +0,0 @@
|
|||||||
"""
|
|
||||||
WorkspaceManager for managing user workspace file operations.
|
|
||||||
|
|
||||||
This module provides a high-level interface for workspace file operations,
|
|
||||||
combining the storage backend and database layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import mimetypes
|
|
||||||
import uuid
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from prisma.errors import UniqueViolationError
|
|
||||||
from prisma.models import UserWorkspaceFile
|
|
||||||
|
|
||||||
from backend.data.workspace import (
|
|
||||||
count_workspace_files,
|
|
||||||
create_workspace_file,
|
|
||||||
get_workspace_file,
|
|
||||||
get_workspace_file_by_path,
|
|
||||||
list_workspace_files,
|
|
||||||
soft_delete_workspace_file,
|
|
||||||
)
|
|
||||||
from backend.util.settings import Config
|
|
||||||
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceManager:
|
|
||||||
"""
|
|
||||||
Manages workspace file operations.
|
|
||||||
|
|
||||||
Combines storage backend operations with database record management.
|
|
||||||
Supports session-scoped file segmentation where files are stored in
|
|
||||||
session-specific virtual paths: /sessions/{session_id}/{filename}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, user_id: str, workspace_id: str, session_id: Optional[str] = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize WorkspaceManager.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The user's ID
|
|
||||||
workspace_id: The workspace ID
|
|
||||||
session_id: Optional session ID for session-scoped file access
|
|
||||||
"""
|
|
||||||
self.user_id = user_id
|
|
||||||
self.workspace_id = workspace_id
|
|
||||||
self.session_id = session_id
|
|
||||||
# Session path prefix for file isolation
|
|
||||||
self.session_path = f"/sessions/{session_id}" if session_id else ""
|
|
||||||
|
|
||||||
def _resolve_path(self, path: str) -> str:
|
|
||||||
"""
|
|
||||||
Resolve a path, defaulting to session folder if session_id is set.
|
|
||||||
|
|
||||||
Cross-session access is allowed by explicitly using /sessions/other-session-id/...
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Resolved path with session prefix if applicable
|
|
||||||
"""
|
|
||||||
# If path explicitly references a session folder, use it as-is
|
|
||||||
if path.startswith("/sessions/"):
|
|
||||||
return path
|
|
||||||
|
|
||||||
# If we have a session context, prepend session path
|
|
||||||
if self.session_path:
|
|
||||||
# Normalize the path
|
|
||||||
if not path.startswith("/"):
|
|
||||||
path = f"/{path}"
|
|
||||||
return f"{self.session_path}{path}"
|
|
||||||
|
|
||||||
# No session context, use path as-is
|
|
||||||
return path if path.startswith("/") else f"/{path}"
|
|
||||||
|
|
||||||
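To make the session-scoping rule above concrete, here is how `_resolve_path` maps a few inputs when `session_id="abc123"` (illustrative note on the deleted file; the expected values follow directly from the code and docstring shown above):

```python
wm = WorkspaceManager(user_id="u1", workspace_id="w1", session_id="abc123")

wm._resolve_path("report.pdf")             # -> "/sessions/abc123/report.pdf"
wm._resolve_path("/docs/report.pdf")       # -> "/sessions/abc123/docs/report.pdf"
wm._resolve_path("/sessions/other/x.txt")  # -> "/sessions/other/x.txt" (explicit cross-session)

wm_no_session = WorkspaceManager("u1", "w1")
wm_no_session._resolve_path("docs/a.txt")  # -> "/docs/a.txt"
```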
def _get_effective_path(
|
|
||||||
self, path: Optional[str], include_all_sessions: bool
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get effective path for list/count operations based on session context.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Optional path prefix to filter
|
|
||||||
include_all_sessions: If True, don't apply session scoping
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Effective path prefix for database query
|
|
||||||
"""
|
|
||||||
if include_all_sessions:
|
|
||||||
# Normalize path to ensure leading slash (stored paths are normalized)
|
|
||||||
if path is not None and not path.startswith("/"):
|
|
||||||
return f"/{path}"
|
|
||||||
return path
|
|
||||||
elif path is not None:
|
|
||||||
# Resolve the provided path with session scoping
|
|
||||||
return self._resolve_path(path)
|
|
||||||
elif self.session_path:
|
|
||||||
# Default to session folder with trailing slash to prevent prefix collisions
|
|
||||||
# e.g., "/sessions/abc" should not match "/sessions/abc123"
|
|
||||||
return self.session_path.rstrip("/") + "/"
|
|
||||||
else:
|
|
||||||
# No session context, use path as-is
|
|
||||||
return path
|
|
||||||
|
|
||||||
async def read_file(self, path: str) -> bytes:
|
|
||||||
"""
|
|
||||||
Read file from workspace by virtual path.
|
|
||||||
|
|
||||||
When session_id is set, paths are resolved relative to the session folder
|
|
||||||
unless they explicitly reference /sessions/...
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Virtual path (e.g., "/documents/report.pdf")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
File content as bytes
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: If file doesn't exist
|
|
||||||
"""
|
|
||||||
resolved_path = self._resolve_path(path)
|
|
||||||
file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
|
||||||
if file is None:
|
|
||||||
raise FileNotFoundError(f"File not found at path: {resolved_path}")
|
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
return await storage.retrieve(file.storagePath)
|
|
||||||
|
|
||||||
async def read_file_by_id(self, file_id: str) -> bytes:
|
|
||||||
"""
|
|
||||||
Read file from workspace by file ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_id: The file's ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
File content as bytes
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: If file doesn't exist
|
|
||||||
"""
|
|
||||||
file = await get_workspace_file(file_id, self.workspace_id)
|
|
||||||
if file is None:
|
|
||||||
raise FileNotFoundError(f"File not found: {file_id}")
|
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
return await storage.retrieve(file.storagePath)
|
|
||||||
|
|
||||||
async def write_file(
|
|
||||||
self,
|
|
||||||
content: bytes,
|
|
||||||
filename: str,
|
|
||||||
path: Optional[str] = None,
|
|
||||||
mime_type: Optional[str] = None,
|
|
||||||
overwrite: bool = False,
|
|
||||||
) -> UserWorkspaceFile:
|
|
||||||
"""
|
|
||||||
Write file to workspace.
|
|
||||||
|
|
||||||
When session_id is set, files are written to /sessions/{session_id}/...
|
|
||||||
by default. Use explicit /sessions/... paths for cross-session access.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: File content as bytes
|
|
||||||
filename: Filename for the file
|
|
||||||
path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
|
|
||||||
mime_type: MIME type (auto-detected if not provided)
|
|
||||||
overwrite: Whether to overwrite existing file at path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created UserWorkspaceFile instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If file exceeds size limit or path already exists
|
|
||||||
"""
|
|
||||||
# Enforce file size limit
|
|
||||||
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
|
||||||
if len(content) > max_file_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"File too large: {len(content)} bytes exceeds "
|
|
||||||
f"{Config().max_file_size_mb}MB limit"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine path with session scoping
|
|
||||||
if path is None:
|
|
||||||
path = f"/{filename}"
|
|
||||||
elif not path.startswith("/"):
|
|
||||||
path = f"/{path}"
|
|
||||||
|
|
||||||
# Resolve path with session prefix
|
|
||||||
path = self._resolve_path(path)
|
|
||||||
|
|
||||||
# Check if file exists at path (only error for non-overwrite case)
|
|
||||||
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
|
|
||||||
# This ensures the new file is written to storage BEFORE the old one is deleted,
|
|
||||||
# preventing data loss if the new write fails
|
|
||||||
if not overwrite:
|
|
||||||
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
|
||||||
if existing is not None:
|
|
||||||
raise ValueError(f"File already exists at path: {path}")
|
|
||||||
|
|
||||||
# Auto-detect MIME type if not provided
|
|
||||||
if mime_type is None:
|
|
||||||
mime_type, _ = mimetypes.guess_type(filename)
|
|
||||||
mime_type = mime_type or "application/octet-stream"
|
|
||||||
|
|
||||||
# Compute checksum
|
|
||||||
checksum = compute_file_checksum(content)
|
|
||||||
|
|
||||||
# Generate unique file ID for storage
|
|
||||||
file_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
# Store file in storage backend
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
storage_path = await storage.store(
|
|
||||||
workspace_id=self.workspace_id,
|
|
||||||
file_id=file_id,
|
|
||||||
filename=filename,
|
|
||||||
content=content,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create database record - handle race condition where another request
|
|
||||||
# created a file at the same path between our check and create
|
|
||||||
try:
|
|
||||||
file = await create_workspace_file(
|
|
||||||
workspace_id=self.workspace_id,
|
|
||||||
file_id=file_id,
|
|
||||||
name=filename,
|
|
||||||
path=path,
|
|
||||||
storage_path=storage_path,
|
|
||||||
mime_type=mime_type,
|
|
||||||
size_bytes=len(content),
|
|
||||||
checksum=checksum,
|
|
||||||
)
|
|
||||||
except UniqueViolationError:
|
|
||||||
# Race condition: another request created a file at this path
|
|
||||||
if overwrite:
|
|
||||||
# Re-fetch and delete the conflicting file, then retry
|
|
||||||
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
|
||||||
if existing:
|
|
||||||
await self.delete_file(existing.id)
|
|
||||||
# Retry the create - if this also fails, clean up storage file
|
|
||||||
try:
|
|
||||||
file = await create_workspace_file(
|
|
||||||
workspace_id=self.workspace_id,
|
|
||||||
file_id=file_id,
|
|
||||||
name=filename,
|
|
||||||
path=path,
|
|
||||||
storage_path=storage_path,
|
|
||||||
mime_type=mime_type,
|
|
||||||
size_bytes=len(content),
|
|
||||||
checksum=checksum,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
# Clean up orphaned storage file on retry failure
|
|
||||||
try:
|
|
||||||
await storage.delete(storage_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
# Clean up the orphaned storage file before raising
|
|
||||||
try:
|
|
||||||
await storage.delete(storage_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
|
||||||
raise ValueError(f"File already exists at path: {path}")
|
|
||||||
except Exception:
|
|
||||||
# Any other database error (connection, validation, etc.) - clean up storage
|
|
||||||
try:
|
|
||||||
await storage.delete(storage_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} "
|
|
||||||
f"at path {path}, size={len(content)} bytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
return file
|
|
||||||
|
|
||||||
async def list_files(
|
|
||||||
self,
|
|
||||||
path: Optional[str] = None,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
offset: int = 0,
|
|
||||||
include_all_sessions: bool = False,
|
|
||||||
) -> list[UserWorkspaceFile]:
|
|
||||||
"""
|
|
||||||
List files in workspace.
|
|
||||||
|
|
||||||
When session_id is set and include_all_sessions is False (default),
|
|
||||||
only files in the current session's folder are listed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Optional path prefix to filter (e.g., "/documents/")
|
|
||||||
limit: Maximum number of files to return
|
|
||||||
offset: Number of files to skip
|
|
||||||
include_all_sessions: If True, list files from all sessions.
|
|
||||||
If False (default), only list current session's files.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of UserWorkspaceFile instances
|
|
||||||
"""
|
|
||||||
effective_path = self._get_effective_path(path, include_all_sessions)
|
|
||||||
|
|
||||||
return await list_workspace_files(
|
|
||||||
workspace_id=self.workspace_id,
|
|
||||||
path_prefix=effective_path,
|
|
||||||
limit=limit,
|
|
||||||
offset=offset,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def delete_file(self, file_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Delete a file (soft-delete).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_id: The file's ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deleted, False if not found
|
|
||||||
"""
|
|
||||||
file = await get_workspace_file(file_id, self.workspace_id)
|
|
||||||
if file is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Delete from storage
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
try:
|
|
||||||
await storage.delete(file.storagePath)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to delete file from storage: {e}")
|
|
||||||
# Continue with database soft-delete even if storage delete fails
|
|
||||||
|
|
||||||
# Soft-delete database record
|
|
||||||
result = await soft_delete_workspace_file(file_id, self.workspace_id)
|
|
||||||
return result is not None
|
|
||||||
|
|
||||||
async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
|
|
||||||
"""
|
|
||||||
Get download URL for a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_id: The file's ID
|
|
||||||
expires_in: URL expiration in seconds (default 1 hour)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Download URL (signed URL for GCS, API endpoint for local)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: If file doesn't exist
|
|
||||||
"""
|
|
||||||
file = await get_workspace_file(file_id, self.workspace_id)
|
|
||||||
if file is None:
|
|
||||||
raise FileNotFoundError(f"File not found: {file_id}")
|
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
return await storage.get_download_url(file.storagePath, expires_in)
|
|
||||||
|
|
||||||
async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
|
|
||||||
"""
|
|
||||||
Get file metadata.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_id: The file's ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UserWorkspaceFile instance or None
|
|
||||||
"""
|
|
||||||
return await get_workspace_file(file_id, self.workspace_id)
|
|
||||||
|
|
||||||
async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
|
|
||||||
"""
|
|
||||||
Get file metadata by path.
|
|
||||||
|
|
||||||
When session_id is set, paths are resolved relative to the session folder
|
|
||||||
unless they explicitly reference /sessions/...
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Virtual path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UserWorkspaceFile instance or None
|
|
||||||
"""
|
|
||||||
resolved_path = self._resolve_path(path)
|
|
||||||
return await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
|
||||||
|
|
||||||
async def get_file_count(
|
|
||||||
self,
|
|
||||||
path: Optional[str] = None,
|
|
||||||
include_all_sessions: bool = False,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Get number of files in workspace.
|
|
||||||
|
|
||||||
When session_id is set and include_all_sessions is False (default),
|
|
||||||
only counts files in the current session's folder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Optional path prefix to filter (e.g., "/documents/")
|
|
||||||
include_all_sessions: If True, count all files in workspace.
|
|
||||||
If False (default), only count current session's files.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of files
|
|
||||||
"""
|
|
||||||
effective_path = self._get_effective_path(path, include_all_sessions)
|
|
||||||
|
|
||||||
return await count_workspace_files(
|
|
||||||
self.workspace_id, path_prefix=effective_path
|
|
||||||
)
|
|
||||||
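End to end, the deleted manager combines the storage backend with the database layer. A hedged usage sketch, assuming the method signatures shown in the deleted file (the module exists only on the removed side of this compare):

```python
from backend.util.workspace import WorkspaceManager


async def demo(user_id: str, workspace_id: str, session_id: str) -> None:
    wm = WorkspaceManager(user_id, workspace_id, session_id)

    record = await wm.write_file(
        content=b"hello", filename="notes.txt", overwrite=True
    )                                           # stored at /sessions/{session_id}/notes.txt
    data = await wm.read_file("/notes.txt")     # session-scoped read -> b"hello"
    files = await wm.list_files()               # only this session's files by default
    url = await wm.get_download_url(record.id)  # signed URL (GCS) or API path (local)
    await wm.delete_file(record.id)             # soft-delete + best-effort storage delete
```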
@@ -1,398 +0,0 @@
|
|||||||
"""
|
|
||||||
Workspace storage backend abstraction for supporting both cloud and local deployments.
|
|
||||||
|
|
||||||
This module provides a unified interface for storing workspace files, with implementations
|
|
||||||
for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import aiofiles
|
|
||||||
import aiohttp
|
|
||||||
from gcloud.aio import storage as async_gcs_storage
|
|
||||||
from google.cloud import storage as gcs_storage
|
|
||||||
|
|
||||||
from backend.util.data import get_data_path
|
|
||||||
from backend.util.gcs_utils import (
|
|
||||||
download_with_fresh_session,
|
|
||||||
generate_signed_url,
|
|
||||||
parse_gcs_path,
|
|
||||||
)
|
|
||||||
from backend.util.settings import Config
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceStorageBackend(ABC):
|
|
||||||
"""Abstract interface for workspace file storage."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def store(
|
|
||||||
self,
|
|
||||||
workspace_id: str,
|
|
||||||
file_id: str,
|
|
||||||
filename: str,
|
|
||||||
content: bytes,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Store file content, return storage path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: The workspace ID
|
|
||||||
file_id: Unique file ID for storage
|
|
||||||
filename: Original filename
|
|
||||||
content: File content as bytes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Storage path string (cloud path or local path)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def retrieve(self, storage_path: str) -> bytes:
|
|
||||||
"""
|
|
||||||
Retrieve file content from storage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_path: The storage path returned from store()
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
File content as bytes
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def delete(self, storage_path: str) -> None:
|
|
||||||
"""
|
|
||||||
Delete file from storage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_path: The storage path to delete
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
|
||||||
"""
|
|
||||||
Get URL for downloading the file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_path: The storage path
|
|
||||||
expires_in: URL expiration time in seconds (default 1 hour)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Download URL (signed URL for GCS, direct API path for local)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class GCSWorkspaceStorage(WorkspaceStorageBackend):
|
|
||||||
"""Google Cloud Storage implementation for workspace storage."""
|
|
||||||
|
|
||||||
def __init__(self, bucket_name: str):
|
|
||||||
self.bucket_name = bucket_name
|
|
||||||
self._async_client: Optional[async_gcs_storage.Storage] = None
|
|
||||||
self._sync_client: Optional[gcs_storage.Client] = None
|
|
||||||
self._session: Optional[aiohttp.ClientSession] = None
|
|
||||||
|
|
||||||
async def _get_async_client(self) -> async_gcs_storage.Storage:
|
|
||||||
"""Get or create async GCS client."""
|
|
||||||
if self._async_client is None:
|
|
||||||
self._session = aiohttp.ClientSession(
|
|
||||||
connector=aiohttp.TCPConnector(limit=100, force_close=False)
|
|
||||||
)
|
|
||||||
self._async_client = async_gcs_storage.Storage(session=self._session)
|
|
||||||
return self._async_client
|
|
||||||
|
|
||||||
def _get_sync_client(self) -> gcs_storage.Client:
|
|
||||||
"""Get or create sync GCS client (for signed URLs)."""
|
|
||||||
if self._sync_client is None:
|
|
||||||
self._sync_client = gcs_storage.Client()
|
|
||||||
return self._sync_client
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close all client connections."""
|
|
||||||
if self._async_client is not None:
|
|
||||||
try:
|
|
||||||
await self._async_client.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error closing GCS client: {e}")
|
|
||||||
self._async_client = None
|
|
||||||
|
|
||||||
if self._session is not None:
|
|
||||||
try:
|
|
||||||
await self._session.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error closing session: {e}")
|
|
||||||
self._session = None
|
|
||||||
|
|
||||||
def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str:
|
|
||||||
"""Build the blob path for workspace files."""
|
|
||||||
return f"workspaces/{workspace_id}/{file_id}/{filename}"
|
|
||||||
|
|
||||||
async def store(
|
|
||||||
self,
|
|
||||||
workspace_id: str,
|
|
||||||
file_id: str,
|
|
||||||
filename: str,
|
|
||||||
content: bytes,
|
|
||||||
) -> str:
|
|
||||||
"""Store file in GCS."""
|
|
||||||
client = await self._get_async_client()
|
|
||||||
blob_name = self._build_blob_name(workspace_id, file_id, filename)
|
|
||||||
|
|
||||||
# Upload with metadata
|
|
||||||
upload_time = datetime.now(timezone.utc)
|
|
||||||
await client.upload(
|
|
||||||
self.bucket_name,
|
|
||||||
blob_name,
|
|
||||||
content,
|
|
||||||
metadata={
|
|
||||||
"uploaded_at": upload_time.isoformat(),
|
|
||||||
"workspace_id": workspace_id,
|
|
||||||
"file_id": file_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return f"gcs://{self.bucket_name}/{blob_name}"
|
|
||||||
|
|
||||||
async def retrieve(self, storage_path: str) -> bytes:
|
|
||||||
"""Retrieve file from GCS."""
|
|
||||||
bucket_name, blob_name = parse_gcs_path(storage_path)
|
|
||||||
return await download_with_fresh_session(bucket_name, blob_name)
|
|
||||||
|
|
||||||
async def delete(self, storage_path: str) -> None:
|
|
||||||
"""Delete file from GCS."""
|
|
||||||
bucket_name, blob_name = parse_gcs_path(storage_path)
|
|
||||||
client = await self._get_async_client()
|
|
||||||
|
|
||||||
try:
|
|
||||||
await client.delete(bucket_name, blob_name)
|
|
||||||
except Exception as e:
|
|
||||||
if "404" not in str(e) and "Not Found" not in str(e):
|
|
||||||
raise
|
|
||||||
# File already deleted, that's fine
|
|
||||||
|
|
||||||
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
|
||||||
"""
|
|
||||||
Generate download URL for GCS file.
|
|
||||||
|
|
||||||
Attempts to generate a signed URL if running with service account credentials.
|
|
||||||
Falls back to an API proxy endpoint if signed URL generation fails
|
|
||||||
(e.g., when running locally with user OAuth credentials).
|
|
||||||
"""
|
|
||||||
bucket_name, blob_name = parse_gcs_path(storage_path)
|
|
||||||
|
|
||||||
# Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename}
|
|
||||||
blob_parts = blob_name.split("/")
|
|
||||||
file_id = blob_parts[2] if len(blob_parts) >= 3 else None
|
|
||||||
|
|
||||||
# Try to generate signed URL (requires service account credentials)
|
|
||||||
try:
|
|
||||||
sync_client = self._get_sync_client()
|
|
||||||
return await generate_signed_url(
|
|
||||||
sync_client, bucket_name, blob_name, expires_in
|
|
||||||
)
|
|
||||||
except AttributeError as e:
|
|
||||||
# Signed URL generation requires service account with private key.
|
|
||||||
# When running with user OAuth credentials, fall back to API proxy.
|
|
||||||
if "private key" in str(e) and file_id:
|
|
||||||
logger.debug(
|
|
||||||
"Cannot generate signed URL (no service account credentials), "
|
|
||||||
"falling back to API proxy endpoint"
|
|
||||||
)
|
|
||||||
return f"/api/workspace/files/{file_id}/download"
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class LocalWorkspaceStorage(WorkspaceStorageBackend):
|
|
||||||
"""Local filesystem implementation for workspace storage (self-hosted deployments)."""
|
|
||||||
|
|
||||||
def __init__(self, base_dir: Optional[str] = None):
|
|
||||||
"""
|
|
||||||
Initialize local storage backend.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
base_dir: Base directory for workspace storage.
|
|
||||||
If None, defaults to {app_data}/workspaces
|
|
||||||
"""
|
|
||||||
if base_dir:
|
|
||||||
self.base_dir = Path(base_dir)
|
|
||||||
else:
|
|
||||||
self.base_dir = Path(get_data_path()) / "workspaces"
|
|
||||||
|
|
||||||
# Ensure base directory exists
|
|
||||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path:
|
|
||||||
"""Build the local file path with path traversal protection."""
|
|
||||||
# Import here to avoid circular import
|
|
||||||
# (file.py imports workspace.py which imports workspace_storage.py)
|
|
||||||
from backend.util.file import sanitize_filename
|
|
||||||
|
|
||||||
# Sanitize filename to prevent path traversal (removes / and \ among others)
|
|
||||||
safe_filename = sanitize_filename(filename)
|
|
||||||
file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve()
|
|
||||||
|
|
||||||
# Verify the resolved path is still under base_dir
|
|
||||||
if not file_path.is_relative_to(self.base_dir.resolve()):
|
|
||||||
raise ValueError("Invalid filename: path traversal detected")
|
|
||||||
|
|
||||||
return file_path
|
|
||||||
|
|
||||||
def _parse_storage_path(self, storage_path: str) -> Path:
|
|
||||||
"""Parse local storage path to filesystem path."""
|
|
||||||
if storage_path.startswith("local://"):
|
|
||||||
relative_path = storage_path[8:] # Remove "local://"
|
|
||||||
else:
|
|
||||||
relative_path = storage_path
|
|
||||||
|
|
||||||
full_path = (self.base_dir / relative_path).resolve()
|
|
||||||
|
|
||||||
# Security check: ensure path is under base_dir
|
|
||||||
# Use is_relative_to() for robust path containment check
|
|
||||||
# (handles case-insensitive filesystems and edge cases)
|
|
||||||
if not full_path.is_relative_to(self.base_dir.resolve()):
|
|
||||||
raise ValueError("Invalid storage path: path traversal detected")
|
|
||||||
|
|
||||||
return full_path
|
|
||||||
|
|
||||||
async def store(
|
|
||||||
self,
|
|
||||||
workspace_id: str,
|
|
||||||
file_id: str,
|
|
||||||
filename: str,
|
|
||||||
content: bytes,
|
|
||||||
) -> str:
|
|
||||||
"""Store file locally."""
|
|
||||||
file_path = self._build_file_path(workspace_id, file_id, filename)
|
|
||||||
|
|
||||||
# Create parent directories
|
|
||||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Write file asynchronously
|
|
||||||
async with aiofiles.open(file_path, "wb") as f:
|
|
||||||
await f.write(content)
|
|
||||||
|
|
||||||
# Return relative path as storage path
|
|
||||||
relative_path = file_path.relative_to(self.base_dir)
|
|
||||||
return f"local://{relative_path}"
|
|
||||||
|
|
||||||
async def retrieve(self, storage_path: str) -> bytes:
|
|
||||||
"""Retrieve file from local storage."""
|
|
||||||
file_path = self._parse_storage_path(storage_path)
|
|
||||||
|
|
||||||
if not file_path.exists():
|
|
||||||
raise FileNotFoundError(f"File not found: {storage_path}")
|
|
||||||
|
|
||||||
async with aiofiles.open(file_path, "rb") as f:
|
|
||||||
return await f.read()
|
|
||||||
|
|
||||||
async def delete(self, storage_path: str) -> None:
|
|
||||||
"""Delete file from local storage."""
|
|
||||||
file_path = self._parse_storage_path(storage_path)
|
|
||||||
|
|
||||||
if file_path.exists():
|
|
||||||
# Remove file
|
|
||||||
file_path.unlink()
|
|
||||||
|
|
||||||
# Clean up empty parent directories
|
|
||||||
parent = file_path.parent
|
|
||||||
while parent != self.base_dir:
|
|
||||||
try:
|
|
||||||
if parent.exists() and not any(parent.iterdir()):
|
|
||||||
parent.rmdir()
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
except OSError:
|
|
||||||
break
|
|
||||||
parent = parent.parent
|
|
||||||
|
|
||||||
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
|
||||||
"""
|
|
||||||
Get download URL for local file.
|
|
||||||
|
|
||||||
For local storage, this returns an API endpoint path.
|
|
||||||
The actual serving is handled by the API layer.
|
|
||||||
"""
|
|
||||||
# Parse the storage path to get the components
|
|
||||||
if storage_path.startswith("local://"):
|
|
||||||
relative_path = storage_path[8:]
|
|
||||||
else:
|
|
||||||
relative_path = storage_path
|
|
||||||
|
|
||||||
# Return the API endpoint for downloading
|
|
||||||
# The file_id is extracted from the path: {workspace_id}/{file_id}/{filename}
|
|
||||||
parts = relative_path.split("/")
|
|
||||||
if len(parts) >= 2:
|
|
||||||
file_id = parts[1] # Second component is file_id
|
|
||||||
return f"/api/workspace/files/{file_id}/download"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid storage path format: {storage_path}")
|
|
||||||
|
|
||||||
|
|
||||||
# Global storage backend instance
|
|
||||||
_workspace_storage: Optional[WorkspaceStorageBackend] = None
|
|
||||||
_storage_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
async def get_workspace_storage() -> WorkspaceStorageBackend:
|
|
||||||
"""
|
|
||||||
Get the workspace storage backend instance.
|
|
||||||
|
|
||||||
Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
|
|
||||||
"""
|
|
||||||
global _workspace_storage
|
|
||||||
|
|
||||||
if _workspace_storage is None:
|
|
||||||
async with _storage_lock:
|
|
||||||
if _workspace_storage is None:
|
|
||||||
config = Config()
|
|
||||||
|
|
||||||
if config.media_gcs_bucket_name:
|
|
||||||
logger.info(
|
|
||||||
f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
|
|
||||||
)
|
|
||||||
_workspace_storage = GCSWorkspaceStorage(
|
|
||||||
config.media_gcs_bucket_name
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
storage_dir = (
|
|
||||||
config.workspace_storage_dir
|
|
||||||
if config.workspace_storage_dir
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Using local workspace storage: {storage_dir or 'default'}"
|
|
||||||
)
|
|
||||||
_workspace_storage = LocalWorkspaceStorage(storage_dir)
|
|
||||||
|
|
||||||
return _workspace_storage
|
|
||||||
|
|
||||||
|
|
||||||
async def shutdown_workspace_storage() -> None:
|
|
||||||
"""
|
|
||||||
Properly shutdown the global workspace storage backend.
|
|
||||||
|
|
||||||
Closes aiohttp sessions and other resources for GCS backend.
|
|
||||||
Should be called during application shutdown.
|
|
||||||
"""
|
|
||||||
global _workspace_storage
|
|
||||||
|
|
||||||
if _workspace_storage is not None:
|
|
||||||
async with _storage_lock:
|
|
||||||
if _workspace_storage is not None:
|
|
||||||
if isinstance(_workspace_storage, GCSWorkspaceStorage):
|
|
||||||
await _workspace_storage.close()
|
|
||||||
_workspace_storage = None
|
|
||||||
|
|
||||||
|
|
||||||
def compute_file_checksum(content: bytes) -> str:
|
|
||||||
"""Compute SHA256 checksum of file content."""
|
|
||||||
return hashlib.sha256(content).hexdigest()
|
|
||||||
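Backend selection in the deleted module is driven purely by `media_gcs_bucket_name`: a configured bucket selects GCS, otherwise files go to the local filesystem. A sketch of the local round trip using the interfaces shown above (illustrative only):

```python
from backend.util.workspace_storage import (
    compute_file_checksum,
    get_workspace_storage,
)


async def store_and_check(workspace_id: str, file_id: str) -> bool:
    storage = await get_workspace_storage()  # GCS if a bucket is configured, else local
    content = b"example payload"
    storage_path = await storage.store(workspace_id, file_id, "example.bin", content)
    round_tripped = await storage.retrieve(storage_path)
    return compute_file_checksum(round_tripped) == compute_file_checksum(content)
```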
@@ -1,52 +0,0 @@
|
|||||||
-- CreateEnum
|
|
||||||
CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT');
|
|
||||||
|
|
||||||
-- CreateTable
|
|
||||||
CREATE TABLE "UserWorkspace" (
|
|
||||||
"id" TEXT NOT NULL,
|
|
||||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
|
||||||
"userId" TEXT NOT NULL,
|
|
||||||
|
|
||||||
CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id")
|
|
||||||
);
|
|
||||||
|
|
||||||
-- CreateTable
|
|
||||||
CREATE TABLE "UserWorkspaceFile" (
|
|
||||||
"id" TEXT NOT NULL,
|
|
||||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
|
||||||
"workspaceId" TEXT NOT NULL,
|
|
||||||
"name" TEXT NOT NULL,
|
|
||||||
"path" TEXT NOT NULL,
|
|
||||||
"storagePath" TEXT NOT NULL,
|
|
||||||
"mimeType" TEXT NOT NULL,
|
|
||||||
"sizeBytes" BIGINT NOT NULL,
|
|
||||||
"checksum" TEXT,
|
|
||||||
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
|
|
||||||
"deletedAt" TIMESTAMP(3),
|
|
||||||
"source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD',
|
|
||||||
"sourceExecId" TEXT,
|
|
||||||
"sourceSessionId" TEXT,
|
|
||||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
|
||||||
|
|
||||||
CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id")
|
|
||||||
);
|
|
||||||
|
|
||||||
-- CreateIndex
|
|
||||||
CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId");
|
|
||||||
|
|
||||||
-- CreateIndex
|
|
||||||
CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId");
|
|
||||||
|
|
||||||
-- CreateIndex
|
|
||||||
CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted");
|
|
||||||
|
|
||||||
-- CreateIndex
|
|
||||||
CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path");
|
|
||||||
|
|
||||||
-- AddForeignKey
|
|
||||||
ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
|
||||||
|
|
||||||
-- AddForeignKey
|
|
||||||
ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
/*
|
|
||||||
Warnings:
|
|
||||||
|
|
||||||
- You are about to drop the column `source` on the `UserWorkspaceFile` table. All the data in the column will be lost.
|
|
||||||
- You are about to drop the column `sourceExecId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
|
|
||||||
- You are about to drop the column `sourceSessionId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
-- AlterTable
|
|
||||||
ALTER TABLE "UserWorkspaceFile" DROP COLUMN "source",
|
|
||||||
DROP COLUMN "sourceExecId",
|
|
||||||
DROP COLUMN "sourceSessionId";
|
|
||||||
|
|
||||||
-- DropEnum
|
|
||||||
DROP TYPE "WorkspaceFileSource";
|
|
||||||
@@ -63,7 +63,6 @@ model User {
   IntegrationWebhooks IntegrationWebhook[]
   NotificationBatches UserNotificationBatch[]
   PendingHumanReviews PendingHumanReview[]
-  Workspace           UserWorkspace?
 
   // OAuth Provider relations
   OAuthApplications   OAuthApplication[]
@@ -138,53 +137,6 @@ model CoPilotUnderstanding {
   @@index([userId])
 }

-////////////////////////////////////////////////////////////
-////////////////////////////////////////////////////////////
-//////////////// USER WORKSPACE TABLES /////////////////
-////////////////////////////////////////////////////////////
-////////////////////////////////////////////////////////////
-
-// User's persistent file storage workspace
-model UserWorkspace {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-  updatedAt DateTime @updatedAt
-
-  userId String @unique
-  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
-
-  Files UserWorkspaceFile[]
-
-  @@index([userId])
-}
-
-// Individual files in a user's workspace
-model UserWorkspaceFile {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-  updatedAt DateTime @updatedAt
-
-  workspaceId String
-  Workspace   UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
-
-  // File metadata
-  name        String // User-visible filename
-  path        String // Virtual path (e.g., "/documents/report.pdf")
-  storagePath String // Actual GCS or local storage path
-  mimeType    String
-  sizeBytes   BigInt
-  checksum    String? // SHA256 for integrity
-
-  // File state
-  isDeleted Boolean   @default(false)
-  deletedAt DateTime?
-
-  metadata Json @default("{}")
-
-  @@unique([workspaceId, path])
-  @@index([workspaceId, isDeleted])
-}
-
 model BuilderSearchHistory {
   id        String   @id @default(uuid())
   createdAt DateTime @default(now())
@@ -34,6 +34,3 @@ NEXT_PUBLIC_PREVIEW_STEALING_DEV=
 # PostHog Analytics
 NEXT_PUBLIC_POSTHOG_KEY=
 NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com
-
-# OpenAI (for voice transcription)
-OPENAI_API_KEY=
@@ -1,76 +0,0 @@
-# CLAUDE.md - Frontend
-
-This file provides guidance to Claude Code when working with the frontend.
-
-## Essential Commands
-
-```bash
-# Install dependencies
-pnpm i
-
-# Generate API client from OpenAPI spec
-pnpm generate:api
-
-# Start development server
-pnpm dev
-
-# Run E2E tests
-pnpm test
-
-# Run Storybook for component development
-pnpm storybook
-
-# Build production
-pnpm build
-
-# Format and lint
-pnpm format
-
-# Type checking
-pnpm types
-```
-
-### Code Style
-
-- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI`
-- Use function declarations (not arrow functions) for components/handlers
-
-## Architecture
-
-- **Framework**: Next.js 15 App Router (client-first approach)
-- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
-- **State Management**: React Query for server state, co-located UI state in components/hooks
-- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
-- **Workflow Builder**: Visual graph editor using @xyflow/react
-- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
-- **Icons**: Phosphor Icons only
-- **Feature Flags**: LaunchDarkly integration
-- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
-- **Testing**: Playwright for E2E, Storybook for component development
-
-## Environment Configuration
-
-`.env.default` (defaults) → `.env` (user overrides)
-
-## Feature Development
-
-See @CONTRIBUTING.md for complete patterns. Quick reference:
-
-1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
-   - Extract component logic into custom hooks grouped by concern, not by component. Each hook should represent a cohesive domain of functionality (e.g., useSearch, useFilters, usePagination) rather than bundling all state into one useComponentState hook.
-   - Put each hook in its own `.ts` file
-   - Put sub-components in local `components/` folder
-   - Component props should be `type Props = { ... }` (not exported) unless it needs to be used outside the component
-2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
-   - Use design system components from `src/components/` (atoms, molecules, organisms)
-   - Never use `src/components/__legacy__/*`
-3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
-   - Regenerate with `pnpm generate:api`
-   - Pattern: `use{Method}{Version}{OperationName}`
-4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
-5. **Testing**: Add Storybook stories for new components, Playwright for E2E
-6. **Code conventions**:
-   - Use function declarations (not arrow functions) for components/handlers
-   - Do not use `useCallback` or `useMemo` unless asked to optimise a given function
-   - Do not type hook returns, let Typescript infer as much as possible
-   - Never type with `any` unless a variable/attribute can ACTUALLY be of any type
@@ -1,10 +1,9 @@
 "use client";

+import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
+import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers";
 import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
 import { useRouter } from "next/navigation";
 import { useEffect } from "react";
-import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
-import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
-import { getHomepageRoute } from "@/lib/constants";

 export default function OnboardingPage() {
   const router = useRouter();
@@ -13,12 +12,10 @@ export default function OnboardingPage() {
   async function redirectToStep() {
     try {
       // Check if onboarding is enabled (also gets chat flag for redirect)
-      const { shouldShowOnboarding, isChatEnabled } =
-        await getOnboardingStatus();
-      const homepageRoute = getHomepageRoute(isChatEnabled);
+      const { shouldShowOnboarding } = await getOnboardingStatus();

       if (!shouldShowOnboarding) {
-        router.replace(homepageRoute);
+        router.replace("/");
         return;
       }

@@ -26,7 +23,7 @@ export default function OnboardingPage() {

       // Handle completed onboarding
       if (onboarding.completedSteps.includes("GET_RESULTS")) {
-        router.replace(homepageRoute);
+        router.replace("/");
         return;
       }

@@ -1,9 +1,8 @@
-import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
-import { getHomepageRoute } from "@/lib/constants";
-import BackendAPI from "@/lib/autogpt-server-api";
-import { NextResponse } from "next/server";
-import { revalidatePath } from "next/cache";
 import { getOnboardingStatus } from "@/app/api/helpers";
+import BackendAPI from "@/lib/autogpt-server-api";
+import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
+import { revalidatePath } from "next/cache";
+import { NextResponse } from "next/server";

 // Handle the callback to complete the user session login
 export async function GET(request: Request) {
@@ -27,13 +26,12 @@ export async function GET(request: Request) {
       await api.createUser();

       // Get onboarding status from backend (includes chat flag evaluated for this user)
-      const { shouldShowOnboarding, isChatEnabled } =
-        await getOnboardingStatus();
+      const { shouldShowOnboarding } = await getOnboardingStatus();
       if (shouldShowOnboarding) {
         next = "/onboarding";
         revalidatePath("/onboarding", "layout");
       } else {
-        next = getHomepageRoute(isChatEnabled);
+        next = "/";
         revalidatePath(next, "layout");
       }
     } catch (createUserError) {
@@ -1,6 +1,13 @@
-import type { ReactNode } from "react";
+"use client";
+
+import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
+import { Flag } from "@/services/feature-flags/use-get-flag";
+import { type ReactNode } from "react";
 import { CopilotShell } from "./components/CopilotShell/CopilotShell";

 export default function CopilotLayout({ children }: { children: ReactNode }) {
-  return <CopilotShell>{children}</CopilotShell>;
+  return (
+    <FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
+      <CopilotShell>{children}</CopilotShell>
+    </FeatureFlagPage>
+  );
 }
@@ -14,14 +14,8 @@ export default function CopilotPage() {
   const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
   const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
   const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
-  const {
-    greetingName,
-    quickActions,
-    isLoading,
-    hasSession,
-    initialPrompt,
-    isReady,
-  } = state;
+  const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
+    state;
   const {
     handleQuickAction,
     startChatWithPrompt,
@@ -29,8 +23,6 @@ export default function CopilotPage() {
     handleStreamingChange,
   } = handlers;

-  if (!isReady) return null;
-
   if (hasSession) {
     return (
       <div className="flex h-full flex-col">
@@ -3,18 +3,11 @@ import {
   postV2CreateSession,
 } from "@/app/api/__generated__/endpoints/chat/chat";
 import { useToast } from "@/components/molecules/Toast/use-toast";
-import { getHomepageRoute } from "@/lib/constants";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
-import {
-  Flag,
-  type FlagValues,
-  useGetFlag,
-} from "@/services/feature-flags/use-get-flag";
 import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
 import * as Sentry from "@sentry/nextjs";
 import { useQueryClient } from "@tanstack/react-query";
-import { useFlags } from "launchdarkly-react-client-sdk";
 import { useRouter } from "next/navigation";
 import { useEffect } from "react";
 import { useCopilotStore } from "./copilot-page-store";
@@ -33,22 +26,6 @@ export function useCopilotPage() {
   const isCreating = useCopilotStore((s) => s.isCreatingSession);
   const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);

-  // Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
-  useEffect(() => {
-    if (isLoggedIn) {
-      completeStep("VISIT_COPILOT");
-    }
-  }, [completeStep, isLoggedIn]);
-
-  const isChatEnabled = useGetFlag(Flag.CHAT);
-  const flags = useFlags<FlagValues>();
-  const homepageRoute = getHomepageRoute(isChatEnabled);
-  const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
-  const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
-  const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
-  const isFlagReady =
-    !isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
-
   const greetingName = getGreetingName(user);
   const quickActions = getQuickActions();

@@ -58,11 +35,8 @@ export function useCopilotPage() {
     : undefined;

   useEffect(() => {
-    if (!isFlagReady) return;
-    if (isChatEnabled === false) {
-      router.replace(homepageRoute);
-    }
-  }, [homepageRoute, isChatEnabled, isFlagReady, router]);
+    if (isLoggedIn) completeStep("VISIT_COPILOT");
+  }, [completeStep, isLoggedIn]);

   async function startChatWithPrompt(prompt: string) {
     if (!prompt?.trim()) return;
@@ -116,7 +90,6 @@ export function useCopilotPage() {
       isLoading: isUserLoading,
       hasSession,
       initialPrompt,
-      isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
     },
     handlers: {
       handleQuickAction,
@@ -1,8 +1,6 @@
 "use client";

 import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
-import { getHomepageRoute } from "@/lib/constants";
-import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { useSearchParams } from "next/navigation";
 import { Suspense } from "react";
 import { getErrorDetails } from "./helpers";
@@ -11,8 +9,6 @@ function ErrorPageContent() {
   const searchParams = useSearchParams();
   const errorMessage = searchParams.get("message");
   const errorDetails = getErrorDetails(errorMessage);
-  const isChatEnabled = useGetFlag(Flag.CHAT);
-  const homepageRoute = getHomepageRoute(isChatEnabled);

   function handleRetry() {
     // Auth-related errors should redirect to login
@@ -30,7 +26,7 @@ function ErrorPageContent() {
       }, 2000);
     } else {
       // For server/network errors, go to home
-      window.location.href = homepageRoute;
+      window.location.href = "/";
     }
   }

@@ -1,6 +1,5 @@
 "use server";

-import { getHomepageRoute } from "@/lib/constants";
 import BackendAPI from "@/lib/autogpt-server-api";
 import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
 import { loginFormSchema } from "@/types/auth";
@@ -38,10 +37,8 @@ export async function login(email: string, password: string) {
     await api.createUser();

     // Get onboarding status from backend (includes chat flag evaluated for this user)
-    const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
-    const next = shouldShowOnboarding
-      ? "/onboarding"
-      : getHomepageRoute(isChatEnabled);
+    const { shouldShowOnboarding } = await getOnboardingStatus();
+    const next = shouldShowOnboarding ? "/onboarding" : "/";

     return {
       success: true,
@@ -1,8 +1,6 @@
 import { useToast } from "@/components/molecules/Toast/use-toast";
-import { getHomepageRoute } from "@/lib/constants";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { environment } from "@/services/environment";
-import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { loginFormSchema, LoginProvider } from "@/types/auth";
 import { zodResolver } from "@hookform/resolvers/zod";
 import { useRouter, useSearchParams } from "next/navigation";
@@ -22,17 +20,15 @@ export function useLoginPage() {
   const [isGoogleLoading, setIsGoogleLoading] = useState(false);
   const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
   const isCloudEnv = environment.isCloud();
-  const isChatEnabled = useGetFlag(Flag.CHAT);
-  const homepageRoute = getHomepageRoute(isChatEnabled);

   // Get redirect destination from 'next' query parameter
   const nextUrl = searchParams.get("next");

   useEffect(() => {
     if (isLoggedIn && !isLoggingIn) {
-      router.push(nextUrl || homepageRoute);
+      router.push(nextUrl || "/");
     }
-  }, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]);
+  }, [isLoggedIn, isLoggingIn, nextUrl, router]);

   const form = useForm<z.infer<typeof loginFormSchema>>({
     resolver: zodResolver(loginFormSchema),
@@ -98,7 +94,7 @@ export function useLoginPage() {
       }

       // Prefer URL's next parameter, then use backend-determined route
-      router.replace(nextUrl || result.next || homepageRoute);
+      router.replace(nextUrl || result.next || "/");
     } catch (error) {
       toast({
         title:
@@ -1,6 +1,5 @@
 "use server";

-import { getHomepageRoute } from "@/lib/constants";
 import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
 import { signupFormSchema } from "@/types/auth";
 import * as Sentry from "@sentry/nextjs";
@@ -59,10 +58,8 @@ export async function signup(
     }

     // Get onboarding status from backend (includes chat flag evaluated for this user)
-    const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
-    const next = shouldShowOnboarding
-      ? "/onboarding"
-      : getHomepageRoute(isChatEnabled);
+    const { shouldShowOnboarding } = await getOnboardingStatus();
+    const next = shouldShowOnboarding ? "/onboarding" : "/";

     return { success: true, next };
   } catch (err) {
@@ -1,8 +1,6 @@
 import { useToast } from "@/components/molecules/Toast/use-toast";
-import { getHomepageRoute } from "@/lib/constants";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { environment } from "@/services/environment";
-import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { LoginProvider, signupFormSchema } from "@/types/auth";
 import { zodResolver } from "@hookform/resolvers/zod";
 import { useRouter, useSearchParams } from "next/navigation";
@@ -22,17 +20,15 @@ export function useSignupPage() {
   const [isGoogleLoading, setIsGoogleLoading] = useState(false);
   const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
   const isCloudEnv = environment.isCloud();
-  const isChatEnabled = useGetFlag(Flag.CHAT);
-  const homepageRoute = getHomepageRoute(isChatEnabled);

   // Get redirect destination from 'next' query parameter
   const nextUrl = searchParams.get("next");

   useEffect(() => {
     if (isLoggedIn && !isSigningUp) {
-      router.push(nextUrl || homepageRoute);
+      router.push(nextUrl || "/");
     }
-  }, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]);
+  }, [isLoggedIn, isSigningUp, nextUrl, router]);

   const form = useForm<z.infer<typeof signupFormSchema>>({
     resolver: zodResolver(signupFormSchema),
@@ -133,7 +129,7 @@ export function useSignupPage() {
       }

       // Prefer the URL's next parameter, then result.next (for onboarding), then default
-      const redirectTo = nextUrl || result.next || homepageRoute;
+      const redirectTo = nextUrl || result.next || "/";
       router.replace(redirectTo);
     } catch (error) {
       setIsLoading(false);
@@ -181,6 +181,5 @@ export async function getOnboardingStatus() {
   const isCompleted = onboarding.completedSteps.includes("CONGRATS");
   return {
     shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
-    isChatEnabled: status.is_chat_enabled,
   };
 }
@@ -5912,40 +5912,6 @@
         }
       }
     },
-    "/api/workspace/files/{file_id}/download": {
-      "get": {
-        "tags": ["workspace"],
-        "summary": "Download file by ID",
-        "description": "Download a file by its ID.\n\nReturns the file content directly or redirects to a signed URL for GCS.",
-        "operationId": "getWorkspaceDownload file by id",
-        "security": [{ "HTTPBearerJWT": [] }],
-        "parameters": [
-          {
-            "name": "file_id",
-            "in": "path",
-            "required": true,
-            "schema": { "type": "string", "title": "File Id" }
-          }
-        ],
-        "responses": {
-          "200": {
-            "description": "Successful Response",
-            "content": { "application/json": { "schema": {} } }
-          },
-          "401": {
-            "$ref": "#/components/responses/HTTP401NotAuthenticatedError"
-          },
-          "422": {
-            "description": "Validation Error",
-            "content": {
-              "application/json": {
-                "schema": { "$ref": "#/components/schemas/HTTPValidationError" }
-              }
-            }
-          }
-        }
-      }
-    },
     "/health": {
       "get": {
         "tags": ["health"],
@@ -1,6 +1,5 @@
 import {
   ApiError,
-  getServerAuthToken,
   makeAuthenticatedFileUpload,
   makeAuthenticatedRequest,
 } from "@/lib/autogpt-server-api/helpers";
@@ -16,69 +15,6 @@ function buildBackendUrl(path: string[], queryString: string): string {
   return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
 }

-/**
- * Check if this is a workspace file download request that needs binary response handling.
- */
-function isWorkspaceDownloadRequest(path: string[]): boolean {
-  // Match pattern: api/workspace/files/{id}/download (5 segments)
-  return (
-    path.length == 5 &&
-    path[0] === "api" &&
-    path[1] === "workspace" &&
-    path[2] === "files" &&
-    path[path.length - 1] === "download"
-  );
-}
-
-/**
- * Handle workspace file download requests with proper binary response streaming.
- */
-async function handleWorkspaceDownload(
-  req: NextRequest,
-  backendUrl: string,
-): Promise<NextResponse> {
-  const token = await getServerAuthToken();
-
-  const headers: Record<string, string> = {};
-  if (token && token !== "no-token-found") {
-    headers["Authorization"] = `Bearer ${token}`;
-  }
-
-  const response = await fetch(backendUrl, {
-    method: "GET",
-    headers,
-    redirect: "follow", // Follow redirects to signed URLs
-  });
-
-  if (!response.ok) {
-    return NextResponse.json(
-      { error: `Failed to download file: ${response.statusText}` },
-      { status: response.status },
-    );
-  }
-
-  // Get the content type from the backend response
-  const contentType =
-    response.headers.get("Content-Type") || "application/octet-stream";
-  const contentDisposition = response.headers.get("Content-Disposition");
-
-  // Stream the response body
-  const responseHeaders: Record<string, string> = {
-    "Content-Type": contentType,
-  };
-
-  if (contentDisposition) {
-    responseHeaders["Content-Disposition"] = contentDisposition;
-  }
-
-  // Return the binary content
-  const arrayBuffer = await response.arrayBuffer();
-  return new NextResponse(arrayBuffer, {
-    status: 200,
-    headers: responseHeaders,
-  });
-}
-
 async function handleJsonRequest(
   req: NextRequest,
   method: string,
@@ -244,11 +180,6 @@ async function handler(
   };

   try {
-    // Handle workspace file downloads separately (binary response)
-    if (method === "GET" && isWorkspaceDownloadRequest(path)) {
-      return await handleWorkspaceDownload(req, backendUrl);
-    }
-
     if (method === "GET" || method === "DELETE") {
       responseBody = await handleGetDeleteRequest(method, backendUrl, req);
     } else if (contentType?.includes("application/json")) {
@@ -1,77 +0,0 @@
-import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
-import { NextRequest, NextResponse } from "next/server";
-
-const WHISPER_API_URL = "https://api.openai.com/v1/audio/transcriptions";
-const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB - Whisper's limit
-
-function getExtensionFromMimeType(mimeType: string): string {
-  const subtype = mimeType.split("/")[1]?.split(";")[0];
-  return subtype || "webm";
-}
-
-export async function POST(request: NextRequest) {
-  const token = await getServerAuthToken();
-
-  if (!token || token === "no-token-found") {
-    return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
-  }
-
-  const apiKey = process.env.OPENAI_API_KEY;
-
-  if (!apiKey) {
-    return NextResponse.json(
-      { error: "OpenAI API key not configured" },
-      { status: 401 },
-    );
-  }
-
-  try {
-    const formData = await request.formData();
-    const audioFile = formData.get("audio");
-
-    if (!audioFile || !(audioFile instanceof Blob)) {
-      return NextResponse.json(
-        { error: "No audio file provided" },
-        { status: 400 },
-      );
-    }
-
-    if (audioFile.size > MAX_FILE_SIZE) {
-      return NextResponse.json(
-        { error: "File too large. Maximum size is 25MB." },
-        { status: 413 },
-      );
-    }
-
-    const ext = getExtensionFromMimeType(audioFile.type);
-    const whisperFormData = new FormData();
-    whisperFormData.append("file", audioFile, `recording.${ext}`);
-    whisperFormData.append("model", "whisper-1");
-
-    const response = await fetch(WHISPER_API_URL, {
-      method: "POST",
-      headers: {
-        Authorization: `Bearer ${apiKey}`,
-      },
-      body: whisperFormData,
-    });
-
-    if (!response.ok) {
-      const errorData = await response.json().catch(() => ({}));
-      console.error("Whisper API error:", errorData);
-      return NextResponse.json(
-        { error: errorData.error?.message || "Transcription failed" },
-        { status: response.status },
-      );
-    }
-
-    const result = await response.json();
-    return NextResponse.json({ text: result.text });
-  } catch (error) {
-    console.error("Transcription error:", error);
-    return NextResponse.json(
-      { error: "Failed to process audio" },
-      { status: 500 },
-    );
-  }
-}
@@ -1,27 +1,14 @@
 "use client";

-import { getHomepageRoute } from "@/lib/constants";
-import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
-import { useRouter } from "next/navigation";
-import { useEffect } from "react";
+import { FeatureFlagRedirect } from "@/services/feature-flags/FeatureFlagRedirect";
+import { Flag } from "@/services/feature-flags/use-get-flag";

 export default function Page() {
-  const isChatEnabled = useGetFlag(Flag.CHAT);
-  const router = useRouter();
-  const homepageRoute = getHomepageRoute(isChatEnabled);
-  const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
-  const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
-  const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
-  const isFlagReady =
-    !isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";
-
-  useEffect(
-    function redirectToHomepage() {
-      if (!isFlagReady) return;
-      router.replace(homepageRoute);
-    },
-    [homepageRoute, isFlagReady, router],
+  return (
+    <FeatureFlagRedirect
+      flag={Flag.CHAT}
+      whenEnabled="/copilot"
+      whenDisabled="/library"
+    />
   );
-
-  return null;
 }
@@ -1,14 +1,7 @@
 import { Button } from "@/components/atoms/Button/Button";
 import { cn } from "@/lib/utils";
-import {
-  ArrowUpIcon,
-  CircleNotchIcon,
-  MicrophoneIcon,
-  StopIcon,
-} from "@phosphor-icons/react";
-import { RecordingIndicator } from "./components/RecordingIndicator";
+import { ArrowUpIcon, StopIcon } from "@phosphor-icons/react";
 import { useChatInput } from "./useChatInput";
-import { useVoiceRecording } from "./useVoiceRecording";

 export interface Props {
   onSend: (message: string) => void;
@@ -28,36 +21,13 @@ export function ChatInput({
   className,
 }: Props) {
   const inputId = "chat-input";
-  const {
-    value,
-    setValue,
-    handleKeyDown: baseHandleKeyDown,
-    handleSubmit,
-    handleChange,
-    hasMultipleLines,
-  } = useChatInput({
-    onSend,
-    disabled: disabled || isStreaming,
-    maxRows: 4,
-    inputId,
-  });
-
-  const {
-    isRecording,
-    isTranscribing,
-    elapsedTime,
-    toggleRecording,
-    handleKeyDown,
-    showMicButton,
-    isInputDisabled,
-    audioStream,
-  } = useVoiceRecording({
-    setValue,
-    disabled: disabled || isStreaming,
-    isStreaming,
-    value,
-    baseHandleKeyDown,
-  });
+  const { value, handleKeyDown, handleSubmit, handleChange, hasMultipleLines } =
+    useChatInput({
+      onSend,
+      disabled: disabled || isStreaming,
+      maxRows: 4,
+      inputId,
+    });

   return (
     <form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
@@ -65,11 +35,8 @@ export function ChatInput({
       <div
         id={`${inputId}-wrapper`}
         className={cn(
-          "relative overflow-hidden border bg-white shadow-sm",
-          "focus-within:ring-1",
-          isRecording
-            ? "border-red-400 focus-within:border-red-400 focus-within:ring-red-400"
-            : "border-neutral-200 focus-within:border-zinc-400 focus-within:ring-zinc-400",
+          "relative overflow-hidden border border-neutral-200 bg-white shadow-sm",
+          "focus-within:border-zinc-400 focus-within:ring-1 focus-within:ring-zinc-400",
           hasMultipleLines ? "rounded-xlarge" : "rounded-full",
         )}
       >
@@ -79,94 +46,48 @@ export function ChatInput({
           value={value}
           onChange={handleChange}
           onKeyDown={handleKeyDown}
-          placeholder={
-            isTranscribing
-              ? "Transcribing..."
-              : isRecording
-                ? ""
-                : placeholder
-          }
-          disabled={isInputDisabled}
+          placeholder={placeholder}
+          disabled={disabled || isStreaming}
           rows={1}
           className={cn(
             "w-full resize-none overflow-y-auto border-0 bg-transparent text-[1rem] leading-6 text-black",
             "placeholder:text-zinc-400",
             "focus:outline-none focus:ring-0",
             "disabled:text-zinc-500",
-            hasMultipleLines
-              ? "pb-6 pl-4 pr-4 pt-2"
-              : showMicButton
-                ? "pb-4 pl-14 pr-14 pt-4"
-                : "pb-4 pl-4 pr-14 pt-4",
+            hasMultipleLines ? "pb-6 pl-4 pr-4 pt-2" : "pb-4 pl-4 pr-14 pt-4",
           )}
         />
-        {isRecording && !value && (
-          <div className="pointer-events-none absolute inset-0 flex items-center justify-center">
-            <RecordingIndicator
-              elapsedTime={elapsedTime}
-              audioStream={audioStream}
-            />
-          </div>
-        )}
       </div>
       <span id="chat-input-hint" className="sr-only">
-        Press Enter to send, Shift+Enter for new line, Space to record voice
+        Press Enter to send, Shift+Enter for new line
       </span>

-      {showMicButton && (
-        <div className="absolute bottom-[7px] left-2 flex items-center gap-1">
-          <Button
-            type="button"
-            variant="icon"
-            size="icon"
-            aria-label={isRecording ? "Stop recording" : "Start recording"}
-            onClick={toggleRecording}
-            disabled={disabled || isTranscribing}
-            className={cn(
-              isRecording
-                ? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
-                : isTranscribing
-                  ? "border-zinc-300 bg-zinc-100 text-zinc-400"
-                  : "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
-            )}
-          >
-            {isTranscribing ? (
-              <CircleNotchIcon className="h-4 w-4 animate-spin" />
-            ) : (
-              <MicrophoneIcon className="h-4 w-4" weight="bold" />
-            )}
-          </Button>
-        </div>
+      {isStreaming ? (
+        <Button
+          type="button"
+          variant="icon"
+          size="icon"
+          aria-label="Stop generating"
+          onClick={onStop}
+          className="absolute bottom-[7px] right-2 border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
+        >
+          <StopIcon className="h-4 w-4" weight="bold" />
+        </Button>
+      ) : (
+        <Button
+          type="submit"
+          variant="icon"
+          size="icon"
+          aria-label="Send message"
+          className={cn(
+            "absolute bottom-[7px] right-2 border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
+            (disabled || !value.trim()) && "opacity-20",
+          )}
+          disabled={disabled || !value.trim()}
+        >
+          <ArrowUpIcon className="h-4 w-4" weight="bold" />
+        </Button>
       )}
-
-      <div className="absolute bottom-[7px] right-2 flex items-center gap-1">
-        {isStreaming ? (
-          <Button
-            type="button"
-            variant="icon"
-            size="icon"
-            aria-label="Stop generating"
-            onClick={onStop}
-            className="border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
-          >
-            <StopIcon className="h-4 w-4" weight="bold" />
-          </Button>
-        ) : (
-          <Button
-            type="submit"
-            variant="icon"
-            size="icon"
-            aria-label="Send message"
-            className={cn(
-              "border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
-              (disabled || !value.trim() || isRecording) && "opacity-20",
-            )}
-            disabled={disabled || !value.trim() || isRecording}
-          >
-            <ArrowUpIcon className="h-4 w-4" weight="bold" />
-          </Button>
-        )}
-      </div>
     </div>
   </form>
 );
@@ -1,142 +0,0 @@
-"use client";
-
-import { useEffect, useRef, useState } from "react";
-
-interface Props {
-  stream: MediaStream | null;
-  barCount?: number;
-  barWidth?: number;
-  barGap?: number;
-  barColor?: string;
-  minBarHeight?: number;
-  maxBarHeight?: number;
-}
-
-export function AudioWaveform({
-  stream,
-  barCount = 24,
-  barWidth = 3,
-  barGap = 2,
-  barColor = "#ef4444", // red-500
-  minBarHeight = 4,
-  maxBarHeight = 32,
-}: Props) {
-  const [bars, setBars] = useState<number[]>(() =>
-    Array(barCount).fill(minBarHeight),
-  );
-  const analyserRef = useRef<AnalyserNode | null>(null);
-  const audioContextRef = useRef<AudioContext | null>(null);
-  const sourceRef = useRef<MediaStreamAudioSourceNode | null>(null);
-  const animationRef = useRef<number | null>(null);
-
-  useEffect(() => {
-    if (!stream) {
-      setBars(Array(barCount).fill(minBarHeight));
-      return;
-    }
-
-    // Create audio context and analyser
-    const audioContext = new AudioContext();
-    const analyser = audioContext.createAnalyser();
-    analyser.fftSize = 512;
-    analyser.smoothingTimeConstant = 0.8;
-
-    // Connect the stream to the analyser
-    const source = audioContext.createMediaStreamSource(stream);
-    source.connect(analyser);
-
-    audioContextRef.current = audioContext;
-    analyserRef.current = analyser;
-    sourceRef.current = source;
-
-    const timeData = new Uint8Array(analyser.frequencyBinCount);
-
-    const updateBars = () => {
-      if (!analyserRef.current) return;
-
-      analyserRef.current.getByteTimeDomainData(timeData);
-
-      // Distribute time-domain data across bars
-      // This shows waveform amplitude, making all bars respond to audio
-      const newBars: number[] = [];
-      const samplesPerBar = timeData.length / barCount;
-
-      for (let i = 0; i < barCount; i++) {
-        // Sample waveform data for this bar
-        let maxAmplitude = 0;
-        const startIdx = Math.floor(i * samplesPerBar);
-        const endIdx = Math.floor((i + 1) * samplesPerBar);
-
-        for (let j = startIdx; j < endIdx && j < timeData.length; j++) {
-          // Convert to amplitude (distance from center 128)
-          const amplitude = Math.abs(timeData[j] - 128);
-          maxAmplitude = Math.max(maxAmplitude, amplitude);
-        }
-
-        // Map amplitude (0-128) to bar height
-        const normalized = (maxAmplitude / 128) * 255;
-        const height =
-          minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
-        newBars.push(height);
-      }
-
-      setBars(newBars);
-      animationRef.current = requestAnimationFrame(updateBars);
-    };
-
-    updateBars();
-
-    return () => {
-      if (animationRef.current) {
-        cancelAnimationFrame(animationRef.current);
-      }
-      if (sourceRef.current) {
-        sourceRef.current.disconnect();
-      }
-      if (audioContextRef.current) {
-        audioContextRef.current.close();
-      }
-      analyserRef.current = null;
-      audioContextRef.current = null;
-      sourceRef.current = null;
-    };
-  }, [stream, barCount, minBarHeight, maxBarHeight]);
-
-  const totalWidth = barCount * barWidth + (barCount - 1) * barGap;
-
-  return (
-    <div
-      className="flex items-center justify-center"
-      style={{
-        width: totalWidth,
-        height: maxBarHeight,
-        gap: barGap,
-      }}
-    >
-      {bars.map((height, i) => {
-        const barHeight = Math.max(minBarHeight, height);
-        return (
-          <div
-            key={i}
-            className="relative"
-            style={{
-              width: barWidth,
-              height: maxBarHeight,
-            }}
-          >
-            <div
-              className="absolute left-0 rounded-full transition-[height] duration-75"
-              style={{
-                width: barWidth,
-                height: barHeight,
-                top: "50%",
-                transform: "translateY(-50%)",
-                backgroundColor: barColor,
-              }}
-            />
-          </div>
-        );
-      })}
-    </div>
-  );
-}
@@ -1,26 +0,0 @@
-import { formatElapsedTime } from "../helpers";
-import { AudioWaveform } from "./AudioWaveform";
-
-type Props = {
-  elapsedTime: number;
-  audioStream: MediaStream | null;
-};
-
-export function RecordingIndicator({ elapsedTime, audioStream }: Props) {
-  return (
-    <div className="flex items-center gap-3">
-      <AudioWaveform
-        stream={audioStream}
-        barCount={20}
-        barWidth={3}
-        barGap={2}
-        barColor="#ef4444"
-        minBarHeight={4}
-        maxBarHeight={24}
-      />
-      <span className="min-w-[3ch] text-sm font-medium text-red-500">
-        {formatElapsedTime(elapsedTime)}
-      </span>
-    </div>
-  );
-}
@@ -1,6 +0,0 @@
-export function formatElapsedTime(ms: number): string {
-  const seconds = Math.floor(ms / 1000);
-  const minutes = Math.floor(seconds / 60);
-  const remainingSeconds = seconds % 60;
-  return `${minutes}:${remainingSeconds.toString().padStart(2, "0")}`;
-}
@@ -6,7 +6,7 @@ import {
   useState,
 } from "react";

-interface Args {
+interface UseChatInputArgs {
   onSend: (message: string) => void;
   disabled?: boolean;
   maxRows?: number;
@@ -18,7 +18,7 @@ export function useChatInput({
   disabled = false,
   maxRows = 5,
   inputId = "chat-input",
-}: Args) {
+}: UseChatInputArgs) {
   const [value, setValue] = useState("");
   const [hasMultipleLines, setHasMultipleLines] = useState(false);

@@ -1,240 +0,0 @@
-import { useToast } from "@/components/molecules/Toast/use-toast";
-import React, {
-  KeyboardEvent,
-  useCallback,
-  useEffect,
-  useRef,
-  useState,
-} from "react";
-
-const MAX_RECORDING_DURATION = 2 * 60 * 1000; // 2 minutes in ms
-
-interface Args {
-  setValue: React.Dispatch<React.SetStateAction<string>>;
-  disabled?: boolean;
-  isStreaming?: boolean;
-  value: string;
-  baseHandleKeyDown: (event: KeyboardEvent<HTMLTextAreaElement>) => void;
-}
-
-export function useVoiceRecording({
-  setValue,
-  disabled = false,
-  isStreaming = false,
-  value,
-  baseHandleKeyDown,
-}: Args) {
-  const [isRecording, setIsRecording] = useState(false);
-  const [isTranscribing, setIsTranscribing] = useState(false);
-  const [error, setError] = useState<string | null>(null);
-  const [elapsedTime, setElapsedTime] = useState(0);
-
-  const mediaRecorderRef = useRef<MediaRecorder | null>(null);
-  const chunksRef = useRef<Blob[]>([]);
-  const timerRef = useRef<NodeJS.Timeout | null>(null);
-  const startTimeRef = useRef<number>(0);
-  const streamRef = useRef<MediaStream | null>(null);
-  const isRecordingRef = useRef(false);
-
-  const isSupported =
-    typeof window !== "undefined" &&
-    !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
-
-  const clearTimer = useCallback(() => {
-    if (timerRef.current) {
-      clearInterval(timerRef.current);
-      timerRef.current = null;
-    }
-  }, []);
-
-  const cleanup = useCallback(() => {
-    clearTimer();
-    if (streamRef.current) {
-      streamRef.current.getTracks().forEach((track) => track.stop());
-      streamRef.current = null;
-    }
-    mediaRecorderRef.current = null;
-    chunksRef.current = [];
-    setElapsedTime(0);
-  }, [clearTimer]);
-
-  const handleTranscription = useCallback(
-    (text: string) => {
-      setValue((prev) => {
-        const trimmedPrev = prev.trim();
-        if (trimmedPrev) {
-          return `${trimmedPrev} ${text}`;
-        }
-        return text;
-      });
-    },
-    [setValue],
-  );
-
-  const transcribeAudio = useCallback(
-    async (audioBlob: Blob) => {
-      setIsTranscribing(true);
-      setError(null);
-
-      try {
-        const formData = new FormData();
-        formData.append("audio", audioBlob);
-
-        const response = await fetch("/api/transcribe", {
-          method: "POST",
-          body: formData,
-        });
-
-        if (!response.ok) {
-          const data = await response.json().catch(() => ({}));
-          throw new Error(data.error || "Transcription failed");
-        }
-
-        const data = await response.json();
-        if (data.text) {
-          handleTranscription(data.text);
-        }
-      } catch (err) {
-        const message =
-          err instanceof Error ? err.message : "Transcription failed";
-        setError(message);
-        console.error("Transcription error:", err);
-      } finally {
-        setIsTranscribing(false);
-      }
-    },
-    [handleTranscription],
-  );
-
-  const stopRecording = useCallback(() => {
-    if (mediaRecorderRef.current && isRecordingRef.current) {
-      mediaRecorderRef.current.stop();
-      isRecordingRef.current = false;
-      setIsRecording(false);
-      clearTimer();
-    }
-  }, [clearTimer]);
-
-  const startRecording = useCallback(async () => {
-    if (disabled || isRecordingRef.current || isTranscribing) return;
-
-    setError(null);
-    chunksRef.current = [];
-
-    try {
-      const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
-      streamRef.current = stream;
-
-      const mediaRecorder = new MediaRecorder(stream, {
-        mimeType: MediaRecorder.isTypeSupported("audio/webm")
-          ? "audio/webm"
-          : "audio/mp4",
-      });
-
-      mediaRecorderRef.current = mediaRecorder;
-
-      mediaRecorder.ondataavailable = (event) => {
-        if (event.data.size > 0) {
-          chunksRef.current.push(event.data);
-        }
-      };
-
-      mediaRecorder.onstop = async () => {
-        const audioBlob = new Blob(chunksRef.current, {
-          type: mediaRecorder.mimeType,
-        });
-
-        // Cleanup stream
-        if (streamRef.current) {
-          streamRef.current.getTracks().forEach((track) => track.stop());
-          streamRef.current = null;
-        }
-
-        if (audioBlob.size > 0) {
-          await transcribeAudio(audioBlob);
-        }
-      };
-
-      mediaRecorder.start(1000); // Collect data every second
-      isRecordingRef.current = true;
-      setIsRecording(true);
-      startTimeRef.current = Date.now();
-
-      // Start elapsed time timer
-      timerRef.current = setInterval(() => {
-        const elapsed = Date.now() - startTimeRef.current;
-        setElapsedTime(elapsed);
-
-        // Auto-stop at max duration
-        if (elapsed >= MAX_RECORDING_DURATION) {
-          stopRecording();
-        }
-      }, 100);
-    } catch (err) {
-      console.error("Failed to start recording:", err);
-      if (err instanceof DOMException && err.name === "NotAllowedError") {
-        setError("Microphone permission denied");
-      } else {
-        setError("Failed to access microphone");
-      }
-      cleanup();
-    }
-  }, [disabled, isTranscribing, stopRecording, transcribeAudio, cleanup]);
-
-  const toggleRecording = useCallback(() => {
-    if (isRecording) {
-      stopRecording();
-    } else {
-      startRecording();
-    }
-  }, [isRecording, startRecording, stopRecording]);
-
-  const { toast } = useToast();
-
-  useEffect(() => {
-    if (error) {
-      toast({
-        title: "Voice recording failed",
-        description: error,
-        variant: "destructive",
-      });
-    }
-  }, [error, toast]);
-
-  const handleKeyDown = useCallback(
-    (event: KeyboardEvent<HTMLTextAreaElement>) => {
-      if (event.key === " " && !value.trim() && !isTranscribing) {
-        event.preventDefault();
-        toggleRecording();
-        return;
-      }
-      baseHandleKeyDown(event);
-    },
-    [value, isTranscribing, toggleRecording, baseHandleKeyDown],
-  );
-
-  const showMicButton = isSupported && !isStreaming;
-  const isInputDisabled = disabled || isStreaming || isTranscribing;
-
-  // Cleanup on unmount
-  useEffect(() => {
-    return () => {
-      cleanup();
-    };
-  }, [cleanup]);
-
-  return {
-    isRecording,
-    isTranscribing,
-    error,
-    elapsedTime,
-    startRecording,
-    stopRecording,
-    toggleRecording,
-    isSupported,
-    handleKeyDown,
-    showMicButton,
-    isInputDisabled,
-    audioStream: streamRef.current,
-  };
-}
@@ -1,8 +1,6 @@
 "use client";

-import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
 import { cn } from "@/lib/utils";
-import { EyeSlash } from "@phosphor-icons/react";
 import React from "react";
 import ReactMarkdown from "react-markdown";
 import remarkGfm from "remark-gfm";
@@ -31,88 +29,12 @@ interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
   type?: string;
 }

-/**
- * Converts a workspace:// URL to a proxy URL that routes through Next.js to the backend.
- * workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
- *
- * Uses the generated API URL helper and routes through the Next.js proxy
- * which handles authentication and proper backend routing.
- */
-/**
- * URL transformer for ReactMarkdown.
- * Converts workspace:// URLs to proxy URLs that route through Next.js to the backend.
- * workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
- *
- * This is needed because ReactMarkdown sanitizes URLs and only allows
- * http, https, mailto, and tel protocols by default.
- */
-function resolveWorkspaceUrl(src: string): string {
-  if (src.startsWith("workspace://")) {
-    const fileId = src.replace("workspace://", "");
-    // Use the generated API URL helper to get the correct path
-    const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
-    // Route through the Next.js proxy (same pattern as customMutator for client-side)
-    return `/api/proxy${apiPath}`;
-  }
-  return src;
-}
-
-/**
- * Check if the image URL is a workspace file (AI cannot see these yet).
- * After URL transformation, workspace files have URLs like /api/proxy/api/workspace/files/...
- */
-function isWorkspaceImage(src: string | undefined): boolean {
-  return src?.includes("/workspace/files/") ?? false;
-}
-
-/**
- * Custom image component that shows an indicator when the AI cannot see the image.
- * Note: src is already transformed by urlTransform, so workspace:// is now /api/workspace/...
- */
-function MarkdownImage(props: Record<string, unknown>) {
-  const src = props.src as string | undefined;
-  const alt = props.alt as string | undefined;
-
-  const aiCannotSee = isWorkspaceImage(src);
-
-  // If no src, show a placeholder
-  if (!src) {
-    return (
-      <span className="my-2 inline-block rounded border border-amber-200 bg-amber-50 px-2 py-1 text-sm text-amber-700">
-        [Image: {alt || "missing src"}]
-      </span>
-    );
-  }
-
-  return (
-    <span className="relative my-2 inline-block">
-      {/* eslint-disable-next-line @next/next/no-img-element */}
-      <img
-        src={src}
-        alt={alt || "Image"}
-        className="h-auto max-w-full rounded-md border border-zinc-200"
-        loading="lazy"
-      />
-      {aiCannotSee && (
-        <span
-          className="absolute bottom-2 right-2 flex items-center gap-1 rounded bg-black/70 px-2 py-1 text-xs text-white"
-          title="The AI cannot see this image"
-        >
-          <EyeSlash size={14} />
-          <span>AI cannot see this image</span>
-        </span>
-      )}
-    </span>
-  );
-}
-
 export function MarkdownContent({ content, className }: MarkdownContentProps) {
   return (
     <div className={cn("markdown-content", className)}>
       <ReactMarkdown
         skipHtml={true}
         remarkPlugins={[remarkGfm]}
-        urlTransform={resolveWorkspaceUrl}
         components={{
           code: ({ children, className, ...props }: CodeProps) => {
             const isInline = !className?.includes("language-");
@@ -284,9 +206,6 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
               {children}
             </td>
           ),
-          img: ({ src, alt, ...props }) => (
-            <MarkdownImage src={src} alt={alt} {...props} />
-          ),
         }}
       >
         {content}
@@ -37,87 +37,6 @@ export function getErrorMessage(result: unknown): string {
   return "An error occurred";
 }

-/**
- * Check if a value is a workspace file reference.
- */
-function isWorkspaceRef(value: unknown): value is string {
-  return typeof value === "string" && value.startsWith("workspace://");
-}
-
-/**
- * Check if a workspace reference appears to be an image based on common patterns.
- * Since workspace refs don't have extensions, we check the context or assume image
- * for certain block types.
- *
- * TODO: Replace keyword matching with MIME type encoded in workspace ref.
- * e.g., workspace://abc123#image/png or workspace://abc123#video/mp4
- * This would let frontend render correctly without fragile keyword matching.
- */
-function isLikelyImageRef(value: string, outputKey?: string): boolean {
-  if (!isWorkspaceRef(value)) return false;
-
-  // Check output key name for video-related hints (these are NOT images)
-  const videoKeywords = ["video", "mp4", "mov", "avi", "webm", "movie", "clip"];
-  if (outputKey) {
-    const lowerKey = outputKey.toLowerCase();
-    if (videoKeywords.some((kw) => lowerKey.includes(kw))) {
-      return false;
-    }
-  }
-
-  // Check output key name for image-related hints
-  const imageKeywords = [
-    "image",
-    "img",
-    "photo",
-    "picture",
-    "thumbnail",
-    "avatar",
-    "icon",
-    "screenshot",
-  ];
-  if (outputKey) {
-    const lowerKey = outputKey.toLowerCase();
-    if (imageKeywords.some((kw) => lowerKey.includes(kw))) {
-      return true;
-    }
-  }
-
-  // Default to treating workspace refs as potential images
-  // since that's the most common case for generated content
-  return true;
-}
-
-/**
- * Format a single output value, converting workspace refs to markdown images.
- */
-function formatOutputValue(value: unknown, outputKey?: string): string {
-  if (isWorkspaceRef(value) && isLikelyImageRef(value, outputKey)) {
-    // Format as markdown image
-    return ``;
-  }
-
-  if (typeof value === "string") {
-    // Check for data URIs (images)
-    if (value.startsWith("data:image/")) {
-      return ``;
-    }
-    return value;
-  }
-
-  if (Array.isArray(value)) {
-    return value
-      .map((item, idx) => formatOutputValue(item, `${outputKey}_${idx}`))
-      .join("\n\n");
-  }
-
-  if (typeof value === "object" && value !== null) {
-    return JSON.stringify(value, null, 2);
-  }
-
-  return String(value);
-}
-
 function getToolCompletionPhrase(toolName: string): string {
   const toolCompletionPhrases: Record<string, string> = {
     add_understanding: "Updated your business information",
@@ -208,26 +127,10 @@ export function formatToolResponse(result: unknown, toolName: string): string {

     case "block_output":
       const blockName = (response.block_name as string) || "Block";
-      const outputs = response.outputs as Record<string, unknown[]> | undefined;
+      const outputs = response.outputs as Record<string, unknown> | undefined;
       if (outputs && Object.keys(outputs).length > 0) {
-        const formattedOutputs: string[] = [];
-        for (const [key, values] of Object.entries(outputs)) {
-          if (!Array.isArray(values) || values.length === 0) continue;
-
-          // Format each value in the output array
-          for (const value of values) {
-            const formatted = formatOutputValue(value, key);
-            if (formatted) {
-              formattedOutputs.push(formatted);
-            }
-          }
-        }
-
-        if (formattedOutputs.length > 0) {
-          return `${blockName} executed successfully.\n\n${formattedOutputs.join("\n\n")}`;
-        }
-        return `${blockName} executed successfully.`;
+        const outputKeys = Object.keys(outputs);
+        return `${blockName} executed successfully. Outputs: ${outputKeys.join(", ")}`;
       }
       return `${blockName} executed successfully.`;

@@ -1,7 +1,6 @@
 "use client";

 import { IconLaptop } from "@/components/__legacy__/ui/icons";
-import { getHomepageRoute } from "@/lib/constants";
 import { cn } from "@/lib/utils";
 import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { ListChecksIcon } from "@phosphor-icons/react/dist/ssr";
@@ -24,11 +23,11 @@ interface Props {
 export function NavbarLink({ name, href }: Props) {
   const pathname = usePathname();
   const isChatEnabled = useGetFlag(Flag.CHAT);
-  const homepageRoute = getHomepageRoute(isChatEnabled);
+  const expectedHomeRoute = isChatEnabled ? "/copilot" : "/library";

   const isActive =
-    href === homepageRoute
-      ? pathname === "/" || pathname.startsWith(homepageRoute)
+    href === expectedHomeRoute
+      ? pathname === "/" || pathname.startsWith(expectedHomeRoute)
       : pathname.includes(href);

   return (
@@ -66,7 +66,7 @@ export default function useAgentGraph(
   >(null);
   const [xyNodes, setXYNodes] = useState<CustomNode[]>([]);
   const [xyEdges, setXYEdges] = useState<CustomEdge[]>([]);
-  const betaBlocks = useGetFlag(Flag.BETA_BLOCKS);
+  const betaBlocks = useGetFlag(Flag.BETA_BLOCKS) as string[];

   // Filter blocks based on beta flags
   const availableBlocks = useMemo(() => {
@@ -516,7 +516,7 @@ export type GraphValidationErrorResponse = {

 /* *** LIBRARY *** */

-/* Mirror of backend/api/features/library/model.py:LibraryAgent */
+/* Mirror of backend/server/v2/library/model.py:LibraryAgent */
 export type LibraryAgent = {
   id: LibraryAgentID;
   graph_id: GraphID;
@@ -616,7 +616,7 @@ export enum LibraryAgentSortEnum {

 /* *** CREDENTIALS *** */

-/* Mirror of backend/api/features/integrations/router.py:CredentialsMetaResponse */
+/* Mirror of backend/server/integrations/router.py:CredentialsMetaResponse */
 export type CredentialsMetaResponse = {
   id: string;
   provider: CredentialsProviderName;
@@ -628,13 +628,13 @@ export type CredentialsMetaResponse = {
   is_system?: boolean;
 };

-/* Mirror of backend/api/features/integrations/router.py:CredentialsDeletionResponse */
+/* Mirror of backend/server/integrations/router.py:CredentialsDeletionResponse */
 export type CredentialsDeleteResponse = {
   deleted: true;
   revoked: boolean | null;
 };

-/* Mirror of backend/api/features/integrations/router.py:CredentialsDeletionNeedsConfirmationResponse */
+/* Mirror of backend/server/integrations/router.py:CredentialsDeletionNeedsConfirmationResponse */
 export type CredentialsDeleteNeedConfirmationResponse = {
   deleted: false;
   need_confirmation: true;
@@ -888,7 +888,7 @@ export type Schedule = {

 export type ScheduleID = Brand<string, "ScheduleID">;

-/* Mirror of backend/api/features/v1.py:ScheduleCreationRequest */
+/* Mirror of backend/server/routers/v1.py:ScheduleCreationRequest */
 export type ScheduleCreatable = {
   graph_id: GraphID;
   graph_version: number;
@@ -11,10 +11,3 @@ export const API_KEY_HEADER_NAME = "X-API-Key";

 // Layout
 export const NAVBAR_HEIGHT_PX = 60;
-
-// Routes
-export function getHomepageRoute(isChatEnabled?: boolean | null): string {
-  if (isChatEnabled === true) return "/copilot";
-  if (isChatEnabled === false) return "/library";
-  return "/";
-}
@@ -1,4 +1,3 @@
-import { getHomepageRoute } from "@/lib/constants";
 import { environment } from "@/services/environment";
 import { Key, storage } from "@/services/storage/local-storage";
 import { type CookieOptions } from "@supabase/ssr";
@@ -71,7 +70,7 @@ export function getRedirectPath(
   }

   if (isAdminPage(path) && userRole !== "admin") {
-    return getHomepageRoute();
+    return "/";
   }

   return null;
@@ -1,4 +1,3 @@
-import { getHomepageRoute } from "@/lib/constants";
 import { environment } from "@/services/environment";
 import { createServerClient } from "@supabase/ssr";
 import { NextResponse, type NextRequest } from "next/server";
@@ -67,7 +66,7 @@ export async function updateSession(request: NextRequest) {

   // 2. Check if user is authenticated but lacks admin role when accessing admin pages
   if (user && userRole !== "admin" && isAdminPage(pathname)) {
-    url.pathname = getHomepageRoute();
+    url.pathname = "/";
     return NextResponse.redirect(url);
   }

@@ -23,9 +23,7 @@ import {
   WebSocketNotification,
 } from "@/lib/autogpt-server-api";
 import { useBackendAPI } from "@/lib/autogpt-server-api/context";
-import { getHomepageRoute } from "@/lib/constants";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
-import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import Link from "next/link";
 import { usePathname, useRouter } from "next/navigation";
 import {
@@ -104,8 +102,6 @@ export default function OnboardingProvider({
   const pathname = usePathname();
   const router = useRouter();
   const { isLoggedIn } = useSupabase();
-  const isChatEnabled = useGetFlag(Flag.CHAT);
-  const homepageRoute = getHomepageRoute(isChatEnabled);

   useOnboardingTimezoneDetection();

@@ -150,7 +146,7 @@ export default function OnboardingProvider({
       if (isOnOnboardingRoute) {
         const enabled = await resolveResponse(getV1IsOnboardingEnabled());
         if (!enabled) {
-          router.push(homepageRoute);
+          router.push("/");
           return;
         }
       }
@@ -162,7 +158,7 @@ export default function OnboardingProvider({
         isOnOnboardingRoute &&
         shouldRedirectFromOnboarding(onboarding.completedSteps, pathname)
       ) {
-        router.push(homepageRoute);
+        router.push("/");
       }
     } catch (error) {
       console.error("Failed to initialize onboarding:", error);
@@ -177,7 +173,7 @@ export default function OnboardingProvider({
     }

     initializeOnboarding();
-  }, [api, homepageRoute, isOnOnboardingRoute, router, isLoggedIn, pathname]);
+  }, [api, isOnOnboardingRoute, router, isLoggedIn, pathname]);

   const handleOnboardingNotification = useCallback(
     (notification: WebSocketNotification) => {
@@ -83,6 +83,10 @@ function getPostHogCredentials() {
   };
 }

+function getLaunchDarklyClientId() {
+  return process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
+}
+
 function isProductionBuild() {
   return process.env.NODE_ENV === "production";
 }
@@ -120,7 +124,10 @@ function isVercelPreview() {
 }

 function areFeatureFlagsEnabled() {
-  return process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "enabled";
+  return (
+    process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true" &&
+    Boolean(process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID)
+  );
 }

 function isPostHogEnabled() {
@@ -143,6 +150,7 @@ export const environment = {
   getSupabaseAnonKey,
   getPreviewStealingDev,
   getPostHogCredentials,
+  getLaunchDarklyClientId,
   // Assertions
   isServerSide,
   isClientSide,
@@ -0,0 +1,59 @@
+"use client";
+
+import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
+import { useLDClient } from "launchdarkly-react-client-sdk";
+import { useRouter } from "next/navigation";
+import { ReactNode, useEffect, useState } from "react";
+import { environment } from "../environment";
+import { Flag, useGetFlag } from "./use-get-flag";
+
+interface FeatureFlagRedirectProps {
+  flag: Flag;
+  whenDisabled: string;
+  children: ReactNode;
+}
+
+export function FeatureFlagPage({
+  flag,
+  whenDisabled,
+  children,
+}: FeatureFlagRedirectProps) {
+  const [isLoading, setIsLoading] = useState(true);
+  const router = useRouter();
+  const flagValue = useGetFlag(flag);
+  const ldClient = useLDClient();
+  const ldEnabled = environment.areFeatureFlagsEnabled();
+  const ldReady = Boolean(ldClient);
+  const flagEnabled = Boolean(flagValue);
+
+  useEffect(() => {
+    const initialize = async () => {
+      if (!ldEnabled) {
+        router.replace(whenDisabled);
+        setIsLoading(false);
+        return;
+      }
+
+      // Wait for LaunchDarkly to initialize when enabled to prevent race conditions
+      if (ldEnabled && !ldReady) return;
+
+      try {
+        await ldClient?.waitForInitialization();
+        if (!flagEnabled) router.replace(whenDisabled);
+      } catch (error) {
+        console.error(error);
+        router.replace(whenDisabled);
+      } finally {
+        setIsLoading(false);
+      }
+    };
+
+    initialize();
+  }, [ldReady, flagEnabled]);
+
+  return isLoading || !flagEnabled ? (
+    <LoadingSpinner size="large" cover />
+  ) : (
+    <>{children}</>
+  );
+}
@@ -0,0 +1,51 @@
+"use client";
+
+import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
+import { useLDClient } from "launchdarkly-react-client-sdk";
+import { useRouter } from "next/navigation";
+import { useEffect } from "react";
+import { environment } from "../environment";
+import { Flag, useGetFlag } from "./use-get-flag";
+
+interface FeatureFlagRedirectProps {
+  flag: Flag;
+  whenEnabled: string;
+  whenDisabled: string;
+}
+
+export function FeatureFlagRedirect({
+  flag,
+  whenEnabled,
+  whenDisabled,
+}: FeatureFlagRedirectProps) {
+  const router = useRouter();
+  const flagValue = useGetFlag(flag);
+  const ldEnabled = environment.areFeatureFlagsEnabled();
+  const ldClient = useLDClient();
+  const ldReady = Boolean(ldClient);
+  const flagEnabled = Boolean(flagValue);
+
+  useEffect(() => {
+    const initialize = async () => {
+      if (!ldEnabled) {
+        router.replace(whenDisabled);
+        return;
+      }
+
+      // Wait for LaunchDarkly to initialize when enabled to prevent race conditions
+      if (ldEnabled && !ldReady) return;
+
+      try {
+        await ldClient?.waitForInitialization();
+        router.replace(flagEnabled ? whenEnabled : whenDisabled);
+      } catch (error) {
+        console.error(error);
+        router.replace(whenDisabled);
+      }
+    };
+
+    initialize();
+  }, [ldReady, flagEnabled]);
+
+  return <LoadingSpinner size="large" cover />;
+}
@@ -7,14 +7,12 @@ import type { ReactNode } from "react";
 import { useMemo } from "react";
 import { environment } from "../environment";

-const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
-const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
 const LAUNCHDARKLY_INIT_TIMEOUT_MS = 5000;

 export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
   const { user, isUserLoading } = useSupabase();
-  const isCloud = environment.isCloud();
-  const isLaunchDarklyConfigured = isCloud && envEnabled && clientId;
+  const envEnabled = environment.areFeatureFlagsEnabled();
+  const clientId = environment.getLaunchDarklyClientId();

   const context = useMemo(() => {
     if (isUserLoading || !user) {
@@ -36,7 +34,7 @@ export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
     };
   }, [user, isUserLoading]);

-  if (!isLaunchDarklyConfigured) {
+  if (!envEnabled) {
     return <>{children}</>;
   }

@@ -44,7 +42,7 @@ export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
     <LDProvider
       // Add this key prop. It will be 'anonymous' when logged out,
       key={context.key}
-      clientSideID={clientId}
+      clientSideID={clientId ?? ""}
       context={context}
       timeout={LAUNCHDARKLY_INIT_TIMEOUT_MS}
       reactOptions={{ useCamelCaseFlagKeys: false }}
@@ -1,6 +1,7 @@
 "use client";

 import { DEFAULT_SEARCH_TERMS } from "@/app/(platform)/marketplace/components/HeroSection/helpers";
+import { environment } from "@/services/environment";
 import { useFlags } from "launchdarkly-react-client-sdk";

 export enum Flag {
@@ -18,24 +19,9 @@ export enum Flag {
   CHAT = "chat",
 }

-export type FlagValues = {
-  [Flag.BETA_BLOCKS]: string[];
-  [Flag.NEW_BLOCK_MENU]: boolean;
-  [Flag.NEW_AGENT_RUNS]: boolean;
-  [Flag.GRAPH_SEARCH]: boolean;
-  [Flag.ENABLE_ENHANCED_OUTPUT_HANDLING]: boolean;
-  [Flag.NEW_FLOW_EDITOR]: boolean;
-  [Flag.BUILDER_VIEW_SWITCH]: boolean;
-  [Flag.SHARE_EXECUTION_RESULTS]: boolean;
-  [Flag.AGENT_FAVORITING]: boolean;
-  [Flag.MARKETPLACE_SEARCH_TERMS]: string[];
-  [Flag.ENABLE_PLATFORM_PAYMENT]: boolean;
-  [Flag.CHAT]: boolean;
-};
-
 const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";

-const mockFlags = {
+const defaultFlags = {
   [Flag.BETA_BLOCKS]: [],
   [Flag.NEW_BLOCK_MENU]: false,
   [Flag.NEW_AGENT_RUNS]: false,
@@ -50,17 +36,16 @@ const mockFlags = {
   [Flag.CHAT]: false,
 };

-export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null {
+type FlagValues = typeof defaultFlags;
+
+export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] {
   const currentFlags = useFlags<FlagValues>();
   const flagValue = currentFlags[flag];
+  const areFlagsEnabled = environment.areFeatureFlagsEnabled();

-  const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
-  const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
-  const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
-
-  if (!isLaunchDarklyConfigured || isPwMockEnabled) {
-    return mockFlags[flag];
+  if (!areFlagsEnabled || isPwMockEnabled) {
+    return defaultFlags[flag];
   }

-  return flagValue ?? mockFlags[flag];
+  return flagValue ?? defaultFlags[flag];
 }
classic/frontend/.gitignore
@@ -8,6 +8,7 @@
 .buildlog/
 .history
 .svn/
+.next/
 migrate_working_dir/

 # IntelliJ related
@@ -53,7 +53,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
 | [Block Installation](block-integrations/basic.md#block-installation) | Given a code string, this block allows the verification and installation of a block code into the system |
 | [Concatenate Lists](block-integrations/basic.md#concatenate-lists) | Concatenates multiple lists into a single list |
 | [Dictionary Is Empty](block-integrations/basic.md#dictionary-is-empty) | Checks if a dictionary is empty |
-| [File Store](block-integrations/basic.md#file-store) | Downloads and stores a file from a URL, data URI, or local path |
+| [File Store](block-integrations/basic.md#file-store) | Stores the input file in the temporary directory |
 | [Find In Dictionary](block-integrations/basic.md#find-in-dictionary) | A block that looks up a value in a dictionary, list, or object by key or index and returns the corresponding value |
 | [Find In List](block-integrations/basic.md#find-in-list) | Finds the index of the value in the list |
 | [Get All Memories](block-integrations/basic.md#get-all-memories) | Retrieve all memories from Mem0 with optional conversation filtering |
@@ -709,7 +709,7 @@ This is useful for conditional logic where you need to verify if data was return
 ## File Store

 ### What it is
-Downloads and stores a file from a URL, data URI, or local path. Use this to fetch images, documents, or other files for processing. In CoPilot: saves to workspace (use list_workspace_files to see it). In graphs: outputs a data URI to pass to other blocks.
+Stores the input file in the temporary directory.

 ### How it works
 <!-- MANUAL: how_it_works -->
@@ -722,15 +722,15 @@ The block outputs a file path that other blocks can use to access the stored fil

 | Input | Description | Type | Required |
 |-------|-------------|------|----------|
-| file_in | The file to download and store. Can be a URL (https://...), data URI, or local path. | str (file) | Yes |
-| base_64 | Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks). | bool | No |
+| file_in | The file to store in the temporary directory, it can be a URL, data URI, or local path. | str (file) | Yes |
+| base_64 | Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks). | bool | No |

 ### Outputs

 | Output | Description | Type |
 |--------|-------------|------|
 | error | Error message if the operation failed | str |
-| file_out | Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks. | str (file) |
+| file_out | The relative path to the stored file in the temporary directory. | str (file) |

 ### Possible use case
 <!-- MANUAL: use_case -->
@@ -12,7 +12,7 @@ Block to attach an audio file to a video file using moviepy.
 <!-- MANUAL: how_it_works -->
 This block combines a video file with an audio file using the moviepy library. The audio track is attached to the video, optionally with volume adjustment via the volume parameter (1.0 = original volume).

-Input files can be URLs, data URIs, or local paths. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions.
+Input files can be URLs, data URIs, or local paths. The output can be returned as either a file path or base64 data URI.
 <!-- END MANUAL -->

 ### Inputs
@@ -22,6 +22,7 @@ Input files can be URLs, data URIs, or local paths. The output format is automat
 | video_in | Video input (URL, data URI, or local path). | str (file) | Yes |
 | audio_in | Audio input (URL, data URI, or local path). | str (file) | Yes |
 | volume | Volume scale for the newly attached audio track (1.0 = original). | float | No |
+| output_return_type | Return the final output as a relative path or base64 data URI. | "file_path" \| "data_uri" | No |

 ### Outputs

@@ -50,7 +51,7 @@ Block to loop a video to a given duration or number of repeats.
 <!-- MANUAL: how_it_works -->
 This block extends a video by repeating it to reach a target duration or number of loops. Set duration to specify the total length in seconds, or use n_loops to repeat the video a specific number of times.

-The looped video is seamlessly concatenated. The output format is automatically determined: `workspace://` URLs in CoPilot, data URIs in graph executions.
+The looped video is seamlessly concatenated and can be output as a file path or base64 data URI.
 <!-- END MANUAL -->

 ### Inputs
@@ -60,6 +61,7 @@ The looped video is seamlessly concatenated. The output format is automatically
 | video_in | The input video (can be a URL, data URI, or local path). | str (file) | Yes |
 | duration | Target duration (in seconds) to loop the video to. If omitted, defaults to no looping. | float | No |
 | n_loops | Number of times to repeat the video. If omitted, defaults to 1 (no repeat). | int | No |
+| output_return_type | How to return the output video. Either a relative path or base64 data URI. | "file_path" \| "data_uri" | No |

 ### Outputs

@@ -277,50 +277,6 @@ async def run(
     token = credentials.api_key.get_secret_value()
 ```

-### Handling Files
-
-When your block works with files (images, videos, documents), use `store_media_file()`:
-
-```python
-from backend.data.execution import ExecutionContext
-from backend.util.file import store_media_file
-from backend.util.type import MediaFileType
-
-async def run(
-    self,
-    input_data: Input,
-    *,
-    execution_context: ExecutionContext,
-    **kwargs,
-):
-    # PROCESSING: Need local file path for tools like ffmpeg, MoviePy, PIL
-    local_path = await store_media_file(
-        file=input_data.video,
-        execution_context=execution_context,
-        return_format="for_local_processing",
-    )
-
-    # EXTERNAL API: Need base64 content for APIs like Replicate, OpenAI
-    image_b64 = await store_media_file(
-        file=input_data.image,
-        execution_context=execution_context,
-        return_format="for_external_api",
-    )
-
-    # OUTPUT: Return to user/next block (auto-adapts to context)
-    result = await store_media_file(
-        file=generated_url,
-        execution_context=execution_context,
-        return_format="for_block_output",  # workspace:// in CoPilot, data URI in graphs
-    )
-    yield "image_url", result
-```
-
-**Return format options:**
-- `"for_local_processing"` - Local file path for processing tools
-- `"for_external_api"` - Data URI for external APIs needing base64
-- `"for_block_output"` - **Always use for outputs** - automatically picks best format
-
 ## Testing Your Block

 ```bash
@@ -25,7 +25,7 @@ This document focuses on the **API Integration OAuth flow** used for connecting
 ### 2. Backend API Trust Boundary
 - **Location**: Server-side FastAPI application
 - **Components**:
-  - Integration router (`/backend/backend/api/features/integrations/router.py`)
+  - Integration router (`/backend/backend/server/integrations/router.py`)
   - OAuth handlers (`/backend/backend/integrations/oauth/`)
   - Credentials store (`/backend/backend/integrations/credentials_store.py`)
 - **Trust Level**: Trusted - server-controlled environment
@@ -111,71 +111,6 @@ Follow these steps to create and test a new block:
 - `graph_exec_id`: The ID of the execution of the agent. This changes every time the agent has a new "run"
 - `node_exec_id`: The ID of the execution of the node. This changes every time the node is executed
 - `node_id`: The ID of the node that is being executed. It changes every version of the graph, but not every time the node is executed.
-- `execution_context`: An `ExecutionContext` object containing user_id, graph_exec_id, workspace_id, and session_id. Required for file handling.
-
-### Handling Files in Blocks
-
-When your block needs to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. This function handles downloading, validation, virus scanning, and storage.
-
-**Import:**
-```python
-from backend.data.execution import ExecutionContext
-from backend.util.file import store_media_file
-from backend.util.type import MediaFileType
-```
-
-**The `return_format` parameter determines what you get back:**
-
-| Format | Use When | Returns |
-|--------|----------|---------|
-| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
-| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
-| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
-
-**Examples:**
-
-```python
-async def run(
-    self,
-    input_data: Input,
-    *,
-    execution_context: ExecutionContext,
-    **kwargs,
-) -> BlockOutput:
-    # PROCESSING: Need to work with file locally (ffmpeg, MoviePy, PIL)
-    local_path = await store_media_file(
-        file=input_data.video,
-        execution_context=execution_context,
-        return_format="for_local_processing",
-    )
-    # local_path = "video.mp4" - use with Path, ffmpeg, subprocess, etc.
-    full_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
-
-    # EXTERNAL API: Need to send content to an API like Replicate
-    image_b64 = await store_media_file(
-        file=input_data.image,
-        execution_context=execution_context,
-        return_format="for_external_api",
-    )
-    # image_b64 = "data:image/png;base64,iVBORw0..." - send to external API
-
-    # OUTPUT: Returning result from block to user/next block
-    result_url = await store_media_file(
-        file=generated_image_url,
-        execution_context=execution_context,
-        return_format="for_block_output",
-    )
-    yield "image_url", result_url
-    # In CoPilot: result_url = "workspace://abc123" (persistent, context-efficient)
-    # In graphs: result_url = "data:image/png;base64,..." (for next block/display)
-```
-
-**Key points:**
-
-- `for_block_output` is the **only** format that auto-adapts to execution context
-- Always use `for_block_output` for block outputs unless you have a specific reason not to
-- Never manually check for `workspace_id` - let `for_block_output` handle the logic
-- The function handles URLs, data URIs, `workspace://` references, and local paths as input

 ### Field Types

@@ -246,7 +246,7 @@ If you encounter any issues, verify that:
 ```bash
 ollama pull llama3.2
 ```
-- If using a custom model, ensure it's added to the model list in `backend/api/model.py`
+- If using a custom model, ensure it's added to the model list in `backend/server/model.py`

 #### Docker Issues
 - Ensure Docker daemon is running: