mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-10 23:05:17 -05:00
Compare commits
7 Commits
pwuts/open
...
swiftyos/d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a86d750cf5 | ||
|
|
13bd648731 | ||
|
|
3d7ee7cc29 | ||
|
|
1ea52934cd | ||
|
|
7b6db6e260 | ||
|
|
2c9563353e | ||
|
|
fb2a70e2d8 |
@@ -1,36 +0,0 @@
|
||||
{
|
||||
"worktreeCopyPatterns": [
|
||||
".env*",
|
||||
".vscode/**",
|
||||
".auth/**",
|
||||
".claude/**",
|
||||
"autogpt_platform/.env*",
|
||||
"autogpt_platform/backend/.env*",
|
||||
"autogpt_platform/frontend/.env*",
|
||||
"autogpt_platform/frontend/.auth/**",
|
||||
"autogpt_platform/db/docker/.env*"
|
||||
],
|
||||
"worktreeCopyIgnores": [
|
||||
"**/node_modules/**",
|
||||
"**/dist/**",
|
||||
"**/.git/**",
|
||||
"**/Thumbs.db",
|
||||
"**/.DS_Store",
|
||||
"**/.next/**",
|
||||
"**/__pycache__/**",
|
||||
"**/.ruff_cache/**",
|
||||
"**/.pytest_cache/**",
|
||||
"**/*.pyc",
|
||||
"**/playwright-report/**",
|
||||
"**/logs/**",
|
||||
"**/site/**"
|
||||
],
|
||||
"worktreePathTemplate": "$BASE_PATH.worktree",
|
||||
"postCreateCmd": [
|
||||
"cd autogpt_platform/autogpt_libs && poetry install",
|
||||
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
||||
"cd autogpt_platform/frontend && pnpm install"
|
||||
],
|
||||
"terminalCommand": "code .",
|
||||
"deleteBranchWithWorktree": false
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,125 +0,0 @@
|
||||
---
|
||||
name: vercel-react-best-practices
|
||||
description: React and Next.js performance optimization guidelines from Vercel Engineering. This skill should be used when writing, reviewing, or refactoring React/Next.js code to ensure optimal performance patterns. Triggers on tasks involving React components, Next.js pages, data fetching, bundle optimization, or performance improvements.
|
||||
license: MIT
|
||||
metadata:
|
||||
author: vercel
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Vercel React Best Practices
|
||||
|
||||
Comprehensive performance optimization guide for React and Next.js applications, maintained by Vercel. Contains 45 rules across 8 categories, prioritized by impact to guide automated refactoring and code generation.
|
||||
|
||||
## When to Apply
|
||||
|
||||
Reference these guidelines when:
|
||||
- Writing new React components or Next.js pages
|
||||
- Implementing data fetching (client or server-side)
|
||||
- Reviewing code for performance issues
|
||||
- Refactoring existing React/Next.js code
|
||||
- Optimizing bundle size or load times
|
||||
|
||||
## Rule Categories by Priority
|
||||
|
||||
| Priority | Category | Impact | Prefix |
|
||||
|----------|----------|--------|--------|
|
||||
| 1 | Eliminating Waterfalls | CRITICAL | `async-` |
|
||||
| 2 | Bundle Size Optimization | CRITICAL | `bundle-` |
|
||||
| 3 | Server-Side Performance | HIGH | `server-` |
|
||||
| 4 | Client-Side Data Fetching | MEDIUM-HIGH | `client-` |
|
||||
| 5 | Re-render Optimization | MEDIUM | `rerender-` |
|
||||
| 6 | Rendering Performance | MEDIUM | `rendering-` |
|
||||
| 7 | JavaScript Performance | LOW-MEDIUM | `js-` |
|
||||
| 8 | Advanced Patterns | LOW | `advanced-` |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### 1. Eliminating Waterfalls (CRITICAL)
|
||||
|
||||
- `async-defer-await` - Move await into branches where actually used
|
||||
- `async-parallel` - Use Promise.all() for independent operations
|
||||
- `async-dependencies` - Use better-all for partial dependencies
|
||||
- `async-api-routes` - Start promises early, await late in API routes
|
||||
- `async-suspense-boundaries` - Use Suspense to stream content
|
||||
|
||||
### 2. Bundle Size Optimization (CRITICAL)
|
||||
|
||||
- `bundle-barrel-imports` - Import directly, avoid barrel files
|
||||
- `bundle-dynamic-imports` - Use next/dynamic for heavy components
|
||||
- `bundle-defer-third-party` - Load analytics/logging after hydration
|
||||
- `bundle-conditional` - Load modules only when feature is activated
|
||||
- `bundle-preload` - Preload on hover/focus for perceived speed
|
||||
|
||||
### 3. Server-Side Performance (HIGH)
|
||||
|
||||
- `server-cache-react` - Use React.cache() for per-request deduplication
|
||||
- `server-cache-lru` - Use LRU cache for cross-request caching
|
||||
- `server-serialization` - Minimize data passed to client components
|
||||
- `server-parallel-fetching` - Restructure components to parallelize fetches
|
||||
- `server-after-nonblocking` - Use after() for non-blocking operations
|
||||
|
||||
### 4. Client-Side Data Fetching (MEDIUM-HIGH)
|
||||
|
||||
- `client-swr-dedup` - Use SWR for automatic request deduplication
|
||||
- `client-event-listeners` - Deduplicate global event listeners
|
||||
|
||||
### 5. Re-render Optimization (MEDIUM)
|
||||
|
||||
- `rerender-defer-reads` - Don't subscribe to state only used in callbacks
|
||||
- `rerender-memo` - Extract expensive work into memoized components
|
||||
- `rerender-dependencies` - Use primitive dependencies in effects
|
||||
- `rerender-derived-state` - Subscribe to derived booleans, not raw values
|
||||
- `rerender-functional-setstate` - Use functional setState for stable callbacks
|
||||
- `rerender-lazy-state-init` - Pass function to useState for expensive values
|
||||
- `rerender-transitions` - Use startTransition for non-urgent updates
|
||||
|
||||
### 6. Rendering Performance (MEDIUM)
|
||||
|
||||
- `rendering-animate-svg-wrapper` - Animate div wrapper, not SVG element
|
||||
- `rendering-content-visibility` - Use content-visibility for long lists
|
||||
- `rendering-hoist-jsx` - Extract static JSX outside components
|
||||
- `rendering-svg-precision` - Reduce SVG coordinate precision
|
||||
- `rendering-hydration-no-flicker` - Use inline script for client-only data
|
||||
- `rendering-activity` - Use Activity component for show/hide
|
||||
- `rendering-conditional-render` - Use ternary, not && for conditionals
|
||||
|
||||
### 7. JavaScript Performance (LOW-MEDIUM)
|
||||
|
||||
- `js-batch-dom-css` - Group CSS changes via classes or cssText
|
||||
- `js-index-maps` - Build Map for repeated lookups
|
||||
- `js-cache-property-access` - Cache object properties in loops
|
||||
- `js-cache-function-results` - Cache function results in module-level Map
|
||||
- `js-cache-storage` - Cache localStorage/sessionStorage reads
|
||||
- `js-combine-iterations` - Combine multiple filter/map into one loop
|
||||
- `js-length-check-first` - Check array length before expensive comparison
|
||||
- `js-early-exit` - Return early from functions
|
||||
- `js-hoist-regexp` - Hoist RegExp creation outside loops
|
||||
- `js-min-max-loop` - Use loop for min/max instead of sort
|
||||
- `js-set-map-lookups` - Use Set/Map for O(1) lookups
|
||||
- `js-tosorted-immutable` - Use toSorted() for immutability
|
||||
|
||||
### 8. Advanced Patterns (LOW)
|
||||
|
||||
- `advanced-event-handler-refs` - Store event handlers in refs
|
||||
- `advanced-use-latest` - useLatest for stable callback refs
|
||||
|
||||
## How to Use
|
||||
|
||||
Read individual rule files for detailed explanations and code examples:
|
||||
|
||||
```
|
||||
rules/async-parallel.md
|
||||
rules/bundle-barrel-imports.md
|
||||
rules/_sections.md
|
||||
```
|
||||
|
||||
Each rule file contains:
|
||||
- Brief explanation of why it matters
|
||||
- Incorrect code example with explanation
|
||||
- Correct code example with explanation
|
||||
- Additional context and references
|
||||
|
||||
## Full Compiled Document
|
||||
|
||||
For the complete guide with all rules expanded: `AGENTS.md`
|
||||
@@ -1,55 +0,0 @@
|
||||
---
|
||||
title: Store Event Handlers in Refs
|
||||
impact: LOW
|
||||
impactDescription: stable subscriptions
|
||||
tags: advanced, hooks, refs, event-handlers, optimization
|
||||
---
|
||||
|
||||
## Store Event Handlers in Refs
|
||||
|
||||
Store callbacks in refs when used in effects that shouldn't re-subscribe on callback changes.
|
||||
|
||||
**Incorrect (re-subscribes on every render):**
|
||||
|
||||
```tsx
|
||||
function useWindowEvent(event: string, handler: () => void) {
|
||||
useEffect(() => {
|
||||
window.addEventListener(event, handler)
|
||||
return () => window.removeEventListener(event, handler)
|
||||
}, [event, handler])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (stable subscription):**
|
||||
|
||||
```tsx
|
||||
function useWindowEvent(event: string, handler: () => void) {
|
||||
const handlerRef = useRef(handler)
|
||||
useEffect(() => {
|
||||
handlerRef.current = handler
|
||||
}, [handler])
|
||||
|
||||
useEffect(() => {
|
||||
const listener = () => handlerRef.current()
|
||||
window.addEventListener(event, listener)
|
||||
return () => window.removeEventListener(event, listener)
|
||||
}, [event])
|
||||
}
|
||||
```
|
||||
|
||||
**Alternative: use `useEffectEvent` if you're on latest React:**
|
||||
|
||||
```tsx
|
||||
import { useEffectEvent } from 'react'
|
||||
|
||||
function useWindowEvent(event: string, handler: () => void) {
|
||||
const onEvent = useEffectEvent(handler)
|
||||
|
||||
useEffect(() => {
|
||||
window.addEventListener(event, onEvent)
|
||||
return () => window.removeEventListener(event, onEvent)
|
||||
}, [event])
|
||||
}
|
||||
```
|
||||
|
||||
`useEffectEvent` provides a cleaner API for the same pattern: it creates a stable function reference that always calls the latest version of the handler.
|
||||
@@ -1,49 +0,0 @@
|
||||
---
|
||||
title: useLatest for Stable Callback Refs
|
||||
impact: LOW
|
||||
impactDescription: prevents effect re-runs
|
||||
tags: advanced, hooks, useLatest, refs, optimization
|
||||
---
|
||||
|
||||
## useLatest for Stable Callback Refs
|
||||
|
||||
Access latest values in callbacks without adding them to dependency arrays. Prevents effect re-runs while avoiding stale closures.
|
||||
|
||||
**Implementation:**
|
||||
|
||||
```typescript
|
||||
function useLatest<T>(value: T) {
|
||||
const ref = useRef(value)
|
||||
useEffect(() => {
|
||||
ref.current = value
|
||||
}, [value])
|
||||
return ref
|
||||
}
|
||||
```
|
||||
|
||||
**Incorrect (effect re-runs on every callback change):**
|
||||
|
||||
```tsx
|
||||
function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
|
||||
const [query, setQuery] = useState('')
|
||||
|
||||
useEffect(() => {
|
||||
const timeout = setTimeout(() => onSearch(query), 300)
|
||||
return () => clearTimeout(timeout)
|
||||
}, [query, onSearch])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (stable effect, fresh callback):**
|
||||
|
||||
```tsx
|
||||
function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
|
||||
const [query, setQuery] = useState('')
|
||||
const onSearchRef = useLatest(onSearch)
|
||||
|
||||
useEffect(() => {
|
||||
const timeout = setTimeout(() => onSearchRef.current(query), 300)
|
||||
return () => clearTimeout(timeout)
|
||||
}, [query])
|
||||
}
|
||||
```
|
||||
@@ -1,38 +0,0 @@
|
||||
---
|
||||
title: Prevent Waterfall Chains in API Routes
|
||||
impact: CRITICAL
|
||||
impactDescription: 2-10× improvement
|
||||
tags: api-routes, server-actions, waterfalls, parallelization
|
||||
---
|
||||
|
||||
## Prevent Waterfall Chains in API Routes
|
||||
|
||||
In API routes and Server Actions, start independent operations immediately, even if you don't await them yet.
|
||||
|
||||
**Incorrect (config waits for auth, data waits for both):**
|
||||
|
||||
```typescript
|
||||
export async function GET(request: Request) {
|
||||
const session = await auth()
|
||||
const config = await fetchConfig()
|
||||
const data = await fetchData(session.user.id)
|
||||
return Response.json({ data, config })
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (auth and config start immediately):**
|
||||
|
||||
```typescript
|
||||
export async function GET(request: Request) {
|
||||
const sessionPromise = auth()
|
||||
const configPromise = fetchConfig()
|
||||
const session = await sessionPromise
|
||||
const [config, data] = await Promise.all([
|
||||
configPromise,
|
||||
fetchData(session.user.id)
|
||||
])
|
||||
return Response.json({ data, config })
|
||||
}
|
||||
```
|
||||
|
||||
For operations with more complex dependency chains, use `better-all` to automatically maximize parallelism (see Dependency-Based Parallelization).
|
||||
@@ -1,80 +0,0 @@
|
||||
---
|
||||
title: Defer Await Until Needed
|
||||
impact: HIGH
|
||||
impactDescription: avoids blocking unused code paths
|
||||
tags: async, await, conditional, optimization
|
||||
---
|
||||
|
||||
## Defer Await Until Needed
|
||||
|
||||
Move `await` operations into the branches where they're actually used to avoid blocking code paths that don't need them.
|
||||
|
||||
**Incorrect (blocks both branches):**
|
||||
|
||||
```typescript
|
||||
async function handleRequest(userId: string, skipProcessing: boolean) {
|
||||
const userData = await fetchUserData(userId)
|
||||
|
||||
if (skipProcessing) {
|
||||
// Returns immediately but still waited for userData
|
||||
return { skipped: true }
|
||||
}
|
||||
|
||||
// Only this branch uses userData
|
||||
return processUserData(userData)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (only blocks when needed):**
|
||||
|
||||
```typescript
|
||||
async function handleRequest(userId: string, skipProcessing: boolean) {
|
||||
if (skipProcessing) {
|
||||
// Returns immediately without waiting
|
||||
return { skipped: true }
|
||||
}
|
||||
|
||||
// Fetch only when needed
|
||||
const userData = await fetchUserData(userId)
|
||||
return processUserData(userData)
|
||||
}
|
||||
```
|
||||
|
||||
**Another example (early return optimization):**
|
||||
|
||||
```typescript
|
||||
// Incorrect: always fetches permissions
|
||||
async function updateResource(resourceId: string, userId: string) {
|
||||
const permissions = await fetchPermissions(userId)
|
||||
const resource = await getResource(resourceId)
|
||||
|
||||
if (!resource) {
|
||||
return { error: 'Not found' }
|
||||
}
|
||||
|
||||
if (!permissions.canEdit) {
|
||||
return { error: 'Forbidden' }
|
||||
}
|
||||
|
||||
return await updateResourceData(resource, permissions)
|
||||
}
|
||||
|
||||
// Correct: fetches only when needed
|
||||
async function updateResource(resourceId: string, userId: string) {
|
||||
const resource = await getResource(resourceId)
|
||||
|
||||
if (!resource) {
|
||||
return { error: 'Not found' }
|
||||
}
|
||||
|
||||
const permissions = await fetchPermissions(userId)
|
||||
|
||||
if (!permissions.canEdit) {
|
||||
return { error: 'Forbidden' }
|
||||
}
|
||||
|
||||
return await updateResourceData(resource, permissions)
|
||||
}
|
||||
```
|
||||
|
||||
This optimization is especially valuable when the skipped branch is frequently taken, or when the deferred operation is expensive.
|
||||
@@ -1,36 +0,0 @@
|
||||
---
|
||||
title: Dependency-Based Parallelization
|
||||
impact: CRITICAL
|
||||
impactDescription: 2-10× improvement
|
||||
tags: async, parallelization, dependencies, better-all
|
||||
---
|
||||
|
||||
## Dependency-Based Parallelization
|
||||
|
||||
For operations with partial dependencies, use `better-all` to maximize parallelism. It automatically starts each task at the earliest possible moment.
|
||||
|
||||
**Incorrect (profile waits for config unnecessarily):**
|
||||
|
||||
```typescript
|
||||
const [user, config] = await Promise.all([
|
||||
fetchUser(),
|
||||
fetchConfig()
|
||||
])
|
||||
const profile = await fetchProfile(user.id)
|
||||
```
|
||||
|
||||
**Correct (config and profile run in parallel):**
|
||||
|
||||
```typescript
|
||||
import { all } from 'better-all'
|
||||
|
||||
const { user, config, profile } = await all({
|
||||
async user() { return fetchUser() },
|
||||
async config() { return fetchConfig() },
|
||||
async profile() {
|
||||
return fetchProfile((await this.$.user).id)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
Reference: [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
|
||||
@@ -1,28 +0,0 @@
|
||||
---
|
||||
title: Promise.all() for Independent Operations
|
||||
impact: CRITICAL
|
||||
impactDescription: 2-10× improvement
|
||||
tags: async, parallelization, promises, waterfalls
|
||||
---
|
||||
|
||||
## Promise.all() for Independent Operations
|
||||
|
||||
When async operations have no interdependencies, execute them concurrently using `Promise.all()`.
|
||||
|
||||
**Incorrect (sequential execution, 3 round trips):**
|
||||
|
||||
```typescript
|
||||
const user = await fetchUser()
|
||||
const posts = await fetchPosts()
|
||||
const comments = await fetchComments()
|
||||
```
|
||||
|
||||
**Correct (parallel execution, 1 round trip):**
|
||||
|
||||
```typescript
|
||||
const [user, posts, comments] = await Promise.all([
|
||||
fetchUser(),
|
||||
fetchPosts(),
|
||||
fetchComments()
|
||||
])
|
||||
```
|
||||
@@ -1,99 +0,0 @@
|
||||
---
|
||||
title: Strategic Suspense Boundaries
|
||||
impact: HIGH
|
||||
impactDescription: faster initial paint
|
||||
tags: async, suspense, streaming, layout-shift
|
||||
---
|
||||
|
||||
## Strategic Suspense Boundaries
|
||||
|
||||
Instead of awaiting data in async components before returning JSX, use Suspense boundaries to show the wrapper UI faster while data loads.
|
||||
|
||||
**Incorrect (wrapper blocked by data fetching):**
|
||||
|
||||
```tsx
|
||||
async function Page() {
|
||||
const data = await fetchData() // Blocks entire page
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div>Sidebar</div>
|
||||
<div>Header</div>
|
||||
<div>
|
||||
<DataDisplay data={data} />
|
||||
</div>
|
||||
<div>Footer</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
The entire layout waits for data even though only the middle section needs it.
|
||||
|
||||
**Correct (wrapper shows immediately, data streams in):**
|
||||
|
||||
```tsx
|
||||
function Page() {
|
||||
return (
|
||||
<div>
|
||||
<div>Sidebar</div>
|
||||
<div>Header</div>
|
||||
<div>
|
||||
<Suspense fallback={<Skeleton />}>
|
||||
<DataDisplay />
|
||||
</Suspense>
|
||||
</div>
|
||||
<div>Footer</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
async function DataDisplay() {
|
||||
const data = await fetchData() // Only blocks this component
|
||||
return <div>{data.content}</div>
|
||||
}
|
||||
```
|
||||
|
||||
Sidebar, Header, and Footer render immediately. Only DataDisplay waits for data.
|
||||
|
||||
**Alternative (share promise across components):**
|
||||
|
||||
```tsx
|
||||
function Page() {
|
||||
// Start fetch immediately, but don't await
|
||||
const dataPromise = fetchData()
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div>Sidebar</div>
|
||||
<div>Header</div>
|
||||
<Suspense fallback={<Skeleton />}>
|
||||
<DataDisplay dataPromise={dataPromise} />
|
||||
<DataSummary dataPromise={dataPromise} />
|
||||
</Suspense>
|
||||
<div>Footer</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function DataDisplay({ dataPromise }: { dataPromise: Promise<Data> }) {
|
||||
const data = use(dataPromise) // Unwraps the promise
|
||||
return <div>{data.content}</div>
|
||||
}
|
||||
|
||||
function DataSummary({ dataPromise }: { dataPromise: Promise<Data> }) {
|
||||
const data = use(dataPromise) // Reuses the same promise
|
||||
return <div>{data.summary}</div>
|
||||
}
|
||||
```
|
||||
|
||||
Both components share the same promise, so only one fetch occurs. Layout renders immediately while both components wait together.
|
||||
|
||||
**When NOT to use this pattern:**
|
||||
|
||||
- Critical data needed for layout decisions (affects positioning)
|
||||
- SEO-critical content above the fold
|
||||
- Small, fast queries where suspense overhead isn't worth it
|
||||
- When you want to avoid layout shift (loading → content jump)
|
||||
|
||||
**Trade-off:** Faster initial paint vs potential layout shift. Choose based on your UX priorities.
|
||||
@@ -1,59 +0,0 @@
|
||||
---
|
||||
title: Avoid Barrel File Imports
|
||||
impact: CRITICAL
|
||||
impactDescription: 200-800ms import cost, slow builds
|
||||
tags: bundle, imports, tree-shaking, barrel-files, performance
|
||||
---
|
||||
|
||||
## Avoid Barrel File Imports
|
||||
|
||||
Import directly from source files instead of barrel files to avoid loading thousands of unused modules. **Barrel files** are entry points that re-export multiple modules (e.g., `index.js` that does `export * from './module'`).
|
||||
|
||||
Popular icon and component libraries can have **up to 10,000 re-exports** in their entry file. For many React packages, **it takes 200-800ms just to import them**, affecting both development speed and production cold starts.
|
||||
|
||||
**Why tree-shaking doesn't help:** When a library is marked as external (not bundled), the bundler can't optimize it. If you bundle it to enable tree-shaking, builds become substantially slower analyzing the entire module graph.
|
||||
|
||||
**Incorrect (imports entire library):**
|
||||
|
||||
```tsx
|
||||
import { Check, X, Menu } from 'lucide-react'
|
||||
// Loads 1,583 modules, takes ~2.8s extra in dev
|
||||
// Runtime cost: 200-800ms on every cold start
|
||||
|
||||
import { Button, TextField } from '@mui/material'
|
||||
// Loads 2,225 modules, takes ~4.2s extra in dev
|
||||
```
|
||||
|
||||
**Correct (imports only what you need):**
|
||||
|
||||
```tsx
|
||||
import Check from 'lucide-react/dist/esm/icons/check'
|
||||
import X from 'lucide-react/dist/esm/icons/x'
|
||||
import Menu from 'lucide-react/dist/esm/icons/menu'
|
||||
// Loads only 3 modules (~2KB vs ~1MB)
|
||||
|
||||
import Button from '@mui/material/Button'
|
||||
import TextField from '@mui/material/TextField'
|
||||
// Loads only what you use
|
||||
```
|
||||
|
||||
**Alternative (Next.js 13.5+):**
|
||||
|
||||
```js
|
||||
// next.config.js - use optimizePackageImports
|
||||
module.exports = {
|
||||
experimental: {
|
||||
optimizePackageImports: ['lucide-react', '@mui/material']
|
||||
}
|
||||
}
|
||||
|
||||
// Then you can keep the ergonomic barrel imports:
|
||||
import { Check, X, Menu } from 'lucide-react'
|
||||
// Automatically transformed to direct imports at build time
|
||||
```
|
||||
|
||||
Direct imports provide 15-70% faster dev boot, 28% faster builds, 40% faster cold starts, and significantly faster HMR.
|
||||
|
||||
Libraries commonly affected: `lucide-react`, `@mui/material`, `@mui/icons-material`, `@tabler/icons-react`, `react-icons`, `@headlessui/react`, `@radix-ui/react-*`, `lodash`, `ramda`, `date-fns`, `rxjs`, `react-use`.
|
||||
|
||||
Reference: [How we optimized package imports in Next.js](https://vercel.com/blog/how-we-optimized-package-imports-in-next-js)
|
||||
@@ -1,31 +0,0 @@
|
||||
---
|
||||
title: Conditional Module Loading
|
||||
impact: HIGH
|
||||
impactDescription: loads large data only when needed
|
||||
tags: bundle, conditional-loading, lazy-loading
|
||||
---
|
||||
|
||||
## Conditional Module Loading
|
||||
|
||||
Load large data or modules only when a feature is activated.
|
||||
|
||||
**Example (lazy-load animation frames):**
|
||||
|
||||
```tsx
|
||||
function AnimationPlayer({ enabled }: { enabled: boolean }) {
|
||||
const [frames, setFrames] = useState<Frame[] | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (enabled && !frames && typeof window !== 'undefined') {
|
||||
import('./animation-frames.js')
|
||||
.then(mod => setFrames(mod.frames))
|
||||
.catch(() => setEnabled(false))
|
||||
}
|
||||
}, [enabled, frames])
|
||||
|
||||
if (!frames) return <Skeleton />
|
||||
return <Canvas frames={frames} />
|
||||
}
|
||||
```
|
||||
|
||||
The `typeof window !== 'undefined'` check prevents bundling this module for SSR, optimizing server bundle size and build speed.
|
||||
@@ -1,49 +0,0 @@
|
||||
---
|
||||
title: Defer Non-Critical Third-Party Libraries
|
||||
impact: MEDIUM
|
||||
impactDescription: loads after hydration
|
||||
tags: bundle, third-party, analytics, defer
|
||||
---
|
||||
|
||||
## Defer Non-Critical Third-Party Libraries
|
||||
|
||||
Analytics, logging, and error tracking don't block user interaction. Load them after hydration.
|
||||
|
||||
**Incorrect (blocks initial bundle):**
|
||||
|
||||
```tsx
|
||||
import { Analytics } from '@vercel/analytics/react'
|
||||
|
||||
export default function RootLayout({ children }) {
|
||||
return (
|
||||
<html>
|
||||
<body>
|
||||
{children}
|
||||
<Analytics />
|
||||
</body>
|
||||
</html>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (loads after hydration):**
|
||||
|
||||
```tsx
|
||||
import dynamic from 'next/dynamic'
|
||||
|
||||
const Analytics = dynamic(
|
||||
() => import('@vercel/analytics/react').then(m => m.Analytics),
|
||||
{ ssr: false }
|
||||
)
|
||||
|
||||
export default function RootLayout({ children }) {
|
||||
return (
|
||||
<html>
|
||||
<body>
|
||||
{children}
|
||||
<Analytics />
|
||||
</body>
|
||||
</html>
|
||||
)
|
||||
}
|
||||
```
|
||||
@@ -1,35 +0,0 @@
|
||||
---
|
||||
title: Dynamic Imports for Heavy Components
|
||||
impact: CRITICAL
|
||||
impactDescription: directly affects TTI and LCP
|
||||
tags: bundle, dynamic-import, code-splitting, next-dynamic
|
||||
---
|
||||
|
||||
## Dynamic Imports for Heavy Components
|
||||
|
||||
Use `next/dynamic` to lazy-load large components not needed on initial render.
|
||||
|
||||
**Incorrect (Monaco bundles with main chunk ~300KB):**
|
||||
|
||||
```tsx
|
||||
import { MonacoEditor } from './monaco-editor'
|
||||
|
||||
function CodePanel({ code }: { code: string }) {
|
||||
return <MonacoEditor value={code} />
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (Monaco loads on demand):**
|
||||
|
||||
```tsx
|
||||
import dynamic from 'next/dynamic'
|
||||
|
||||
const MonacoEditor = dynamic(
|
||||
() => import('./monaco-editor').then(m => m.MonacoEditor),
|
||||
{ ssr: false }
|
||||
)
|
||||
|
||||
function CodePanel({ code }: { code: string }) {
|
||||
return <MonacoEditor value={code} />
|
||||
}
|
||||
```
|
||||
@@ -1,50 +0,0 @@
|
||||
---
|
||||
title: Preload Based on User Intent
|
||||
impact: MEDIUM
|
||||
impactDescription: reduces perceived latency
|
||||
tags: bundle, preload, user-intent, hover
|
||||
---
|
||||
|
||||
## Preload Based on User Intent
|
||||
|
||||
Preload heavy bundles before they're needed to reduce perceived latency.
|
||||
|
||||
**Example (preload on hover/focus):**
|
||||
|
||||
```tsx
|
||||
function EditorButton({ onClick }: { onClick: () => void }) {
|
||||
const preload = () => {
|
||||
if (typeof window !== 'undefined') {
|
||||
void import('./monaco-editor')
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<button
|
||||
onMouseEnter={preload}
|
||||
onFocus={preload}
|
||||
onClick={onClick}
|
||||
>
|
||||
Open Editor
|
||||
</button>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Example (preload when feature flag is enabled):**
|
||||
|
||||
```tsx
|
||||
function FlagsProvider({ children, flags }: Props) {
|
||||
useEffect(() => {
|
||||
if (flags.editorEnabled && typeof window !== 'undefined') {
|
||||
void import('./monaco-editor').then(mod => mod.init())
|
||||
}
|
||||
}, [flags.editorEnabled])
|
||||
|
||||
return <FlagsContext.Provider value={flags}>
|
||||
{children}
|
||||
</FlagsContext.Provider>
|
||||
}
|
||||
```
|
||||
|
||||
The `typeof window !== 'undefined'` check prevents bundling preloaded modules for SSR, optimizing server bundle size and build speed.
|
||||
@@ -1,74 +0,0 @@
|
||||
---
|
||||
title: Deduplicate Global Event Listeners
|
||||
impact: LOW
|
||||
impactDescription: single listener for N components
|
||||
tags: client, swr, event-listeners, subscription
|
||||
---
|
||||
|
||||
## Deduplicate Global Event Listeners
|
||||
|
||||
Use `useSWRSubscription()` to share global event listeners across component instances.
|
||||
|
||||
**Incorrect (N instances = N listeners):**
|
||||
|
||||
```tsx
|
||||
function useKeyboardShortcut(key: string, callback: () => void) {
|
||||
useEffect(() => {
|
||||
const handler = (e: KeyboardEvent) => {
|
||||
if (e.metaKey && e.key === key) {
|
||||
callback()
|
||||
}
|
||||
}
|
||||
window.addEventListener('keydown', handler)
|
||||
return () => window.removeEventListener('keydown', handler)
|
||||
}, [key, callback])
|
||||
}
|
||||
```
|
||||
|
||||
When using the `useKeyboardShortcut` hook multiple times, each instance will register a new listener.
|
||||
|
||||
**Correct (N instances = 1 listener):**
|
||||
|
||||
```tsx
|
||||
import useSWRSubscription from 'swr/subscription'
|
||||
|
||||
// Module-level Map to track callbacks per key
|
||||
const keyCallbacks = new Map<string, Set<() => void>>()
|
||||
|
||||
function useKeyboardShortcut(key: string, callback: () => void) {
|
||||
// Register this callback in the Map
|
||||
useEffect(() => {
|
||||
if (!keyCallbacks.has(key)) {
|
||||
keyCallbacks.set(key, new Set())
|
||||
}
|
||||
keyCallbacks.get(key)!.add(callback)
|
||||
|
||||
return () => {
|
||||
const set = keyCallbacks.get(key)
|
||||
if (set) {
|
||||
set.delete(callback)
|
||||
if (set.size === 0) {
|
||||
keyCallbacks.delete(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [key, callback])
|
||||
|
||||
useSWRSubscription('global-keydown', () => {
|
||||
const handler = (e: KeyboardEvent) => {
|
||||
if (e.metaKey && keyCallbacks.has(e.key)) {
|
||||
keyCallbacks.get(e.key)!.forEach(cb => cb())
|
||||
}
|
||||
}
|
||||
window.addEventListener('keydown', handler)
|
||||
return () => window.removeEventListener('keydown', handler)
|
||||
})
|
||||
}
|
||||
|
||||
function Profile() {
|
||||
// Multiple shortcuts will share the same listener
|
||||
useKeyboardShortcut('p', () => { /* ... */ })
|
||||
useKeyboardShortcut('k', () => { /* ... */ })
|
||||
// ...
|
||||
}
|
||||
```
|
||||
@@ -1,56 +0,0 @@
|
||||
---
|
||||
title: Use SWR for Automatic Deduplication
|
||||
impact: MEDIUM-HIGH
|
||||
impactDescription: automatic deduplication
|
||||
tags: client, swr, deduplication, data-fetching
|
||||
---
|
||||
|
||||
## Use SWR for Automatic Deduplication
|
||||
|
||||
SWR enables request deduplication, caching, and revalidation across component instances.
|
||||
|
||||
**Incorrect (no deduplication, each instance fetches):**
|
||||
|
||||
```tsx
|
||||
function UserList() {
|
||||
const [users, setUsers] = useState([])
|
||||
useEffect(() => {
|
||||
fetch('/api/users')
|
||||
.then(r => r.json())
|
||||
.then(setUsers)
|
||||
}, [])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (multiple instances share one request):**
|
||||
|
||||
```tsx
|
||||
import useSWR from 'swr'
|
||||
|
||||
function UserList() {
|
||||
const { data: users } = useSWR('/api/users', fetcher)
|
||||
}
|
||||
```
|
||||
|
||||
**For immutable data:**
|
||||
|
||||
```tsx
|
||||
import { useImmutableSWR } from '@/lib/swr'
|
||||
|
||||
function StaticContent() {
|
||||
const { data } = useImmutableSWR('/api/config', fetcher)
|
||||
}
|
||||
```
|
||||
|
||||
**For mutations:**
|
||||
|
||||
```tsx
|
||||
import { useSWRMutation } from 'swr/mutation'
|
||||
|
||||
function UpdateButton() {
|
||||
const { trigger } = useSWRMutation('/api/user', updateUser)
|
||||
return <button onClick={() => trigger()}>Update</button>
|
||||
}
|
||||
```
|
||||
|
||||
Reference: [https://swr.vercel.app](https://swr.vercel.app)
|
||||
@@ -1,82 +0,0 @@
|
||||
---
|
||||
title: Batch DOM CSS Changes
|
||||
impact: MEDIUM
|
||||
impactDescription: reduces reflows/repaints
|
||||
tags: javascript, dom, css, performance, reflow
|
||||
---
|
||||
|
||||
## Batch DOM CSS Changes
|
||||
|
||||
Avoid changing styles one property at a time. Group multiple CSS changes together via classes or `cssText` to minimize browser reflows.
|
||||
|
||||
**Incorrect (multiple reflows):**
|
||||
|
||||
```typescript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
// Each line triggers a reflow
|
||||
element.style.width = '100px'
|
||||
element.style.height = '200px'
|
||||
element.style.backgroundColor = 'blue'
|
||||
element.style.border = '1px solid black'
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (add class - single reflow):**
|
||||
|
||||
```typescript
|
||||
// CSS file
|
||||
.highlighted-box {
|
||||
width: 100px;
|
||||
height: 200px;
|
||||
background-color: blue;
|
||||
border: 1px solid black;
|
||||
}
|
||||
|
||||
// JavaScript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
element.classList.add('highlighted-box')
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (change cssText - single reflow):**
|
||||
|
||||
```typescript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
element.style.cssText = `
|
||||
width: 100px;
|
||||
height: 200px;
|
||||
background-color: blue;
|
||||
border: 1px solid black;
|
||||
`
|
||||
}
|
||||
```
|
||||
|
||||
**React example:**
|
||||
|
||||
```tsx
|
||||
// Incorrect: changing styles one by one
|
||||
function Box({ isHighlighted }: { isHighlighted: boolean }) {
|
||||
const ref = useRef<HTMLDivElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (ref.current && isHighlighted) {
|
||||
ref.current.style.width = '100px'
|
||||
ref.current.style.height = '200px'
|
||||
ref.current.style.backgroundColor = 'blue'
|
||||
}
|
||||
}, [isHighlighted])
|
||||
|
||||
return <div ref={ref}>Content</div>
|
||||
}
|
||||
|
||||
// Correct: toggle class
|
||||
function Box({ isHighlighted }: { isHighlighted: boolean }) {
|
||||
return (
|
||||
<div className={isHighlighted ? 'highlighted-box' : ''}>
|
||||
Content
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Prefer CSS classes over inline styles when possible. Classes are cached by the browser and provide better separation of concerns.
|
||||
@@ -1,80 +0,0 @@
|
||||
---
|
||||
title: Cache Repeated Function Calls
|
||||
impact: MEDIUM
|
||||
impactDescription: avoid redundant computation
|
||||
tags: javascript, cache, memoization, performance
|
||||
---
|
||||
|
||||
## Cache Repeated Function Calls
|
||||
|
||||
Use a module-level Map to cache function results when the same function is called repeatedly with the same inputs during render.
|
||||
|
||||
**Incorrect (redundant computation):**
|
||||
|
||||
```typescript
|
||||
function ProjectList({ projects }: { projects: Project[] }) {
|
||||
return (
|
||||
<div>
|
||||
{projects.map(project => {
|
||||
// slugify() called 100+ times for same project names
|
||||
const slug = slugify(project.name)
|
||||
|
||||
return <ProjectCard key={project.id} slug={slug} />
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (cached results):**
|
||||
|
||||
```typescript
|
||||
// Module-level cache
|
||||
const slugifyCache = new Map<string, string>()
|
||||
|
||||
function cachedSlugify(text: string): string {
|
||||
if (slugifyCache.has(text)) {
|
||||
return slugifyCache.get(text)!
|
||||
}
|
||||
const result = slugify(text)
|
||||
slugifyCache.set(text, result)
|
||||
return result
|
||||
}
|
||||
|
||||
function ProjectList({ projects }: { projects: Project[] }) {
|
||||
return (
|
||||
<div>
|
||||
{projects.map(project => {
|
||||
// Computed only once per unique project name
|
||||
const slug = cachedSlugify(project.name)
|
||||
|
||||
return <ProjectCard key={project.id} slug={slug} />
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Simpler pattern for single-value functions:**
|
||||
|
||||
```typescript
|
||||
let isLoggedInCache: boolean | null = null
|
||||
|
||||
function isLoggedIn(): boolean {
|
||||
if (isLoggedInCache !== null) {
|
||||
return isLoggedInCache
|
||||
}
|
||||
|
||||
isLoggedInCache = document.cookie.includes('auth=')
|
||||
return isLoggedInCache
|
||||
}
|
||||
|
||||
// Clear cache when auth changes
|
||||
function onAuthChange() {
|
||||
isLoggedInCache = null
|
||||
}
|
||||
```
|
||||
|
||||
Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
|
||||
|
||||
Reference: [How we made the Vercel Dashboard twice as fast](https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast)
|
||||
@@ -1,28 +0,0 @@
|
||||
---
|
||||
title: Cache Property Access in Loops
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: reduces lookups
|
||||
tags: javascript, loops, optimization, caching
|
||||
---
|
||||
|
||||
## Cache Property Access in Loops
|
||||
|
||||
Cache object property lookups in hot paths.
|
||||
|
||||
**Incorrect (3 lookups × N iterations):**
|
||||
|
||||
```typescript
|
||||
for (let i = 0; i < arr.length; i++) {
|
||||
process(obj.config.settings.value)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (1 lookup total):**
|
||||
|
||||
```typescript
|
||||
const value = obj.config.settings.value
|
||||
const len = arr.length
|
||||
for (let i = 0; i < len; i++) {
|
||||
process(value)
|
||||
}
|
||||
```
|
||||
@@ -1,70 +0,0 @@
|
||||
---
|
||||
title: Cache Storage API Calls
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: reduces expensive I/O
|
||||
tags: javascript, localStorage, storage, caching, performance
|
||||
---
|
||||
|
||||
## Cache Storage API Calls
|
||||
|
||||
`localStorage`, `sessionStorage`, and `document.cookie` are synchronous and expensive. Cache reads in memory.
|
||||
|
||||
**Incorrect (reads storage on every call):**
|
||||
|
||||
```typescript
|
||||
function getTheme() {
|
||||
return localStorage.getItem('theme') ?? 'light'
|
||||
}
|
||||
// Called 10 times = 10 storage reads
|
||||
```
|
||||
|
||||
**Correct (Map cache):**
|
||||
|
||||
```typescript
|
||||
const storageCache = new Map<string, string | null>()
|
||||
|
||||
function getLocalStorage(key: string) {
|
||||
if (!storageCache.has(key)) {
|
||||
storageCache.set(key, localStorage.getItem(key))
|
||||
}
|
||||
return storageCache.get(key)
|
||||
}
|
||||
|
||||
function setLocalStorage(key: string, value: string) {
|
||||
localStorage.setItem(key, value)
|
||||
storageCache.set(key, value) // keep cache in sync
|
||||
}
|
||||
```
|
||||
|
||||
Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
|
||||
|
||||
**Cookie caching:**
|
||||
|
||||
```typescript
|
||||
let cookieCache: Record<string, string> | null = null
|
||||
|
||||
function getCookie(name: string) {
|
||||
if (!cookieCache) {
|
||||
cookieCache = Object.fromEntries(
|
||||
document.cookie.split('; ').map(c => c.split('='))
|
||||
)
|
||||
}
|
||||
return cookieCache[name]
|
||||
}
|
||||
```
|
||||
|
||||
**Important (invalidate on external changes):**
|
||||
|
||||
If storage can change externally (another tab, server-set cookies), invalidate cache:
|
||||
|
||||
```typescript
|
||||
window.addEventListener('storage', (e) => {
|
||||
if (e.key) storageCache.delete(e.key)
|
||||
})
|
||||
|
||||
document.addEventListener('visibilitychange', () => {
|
||||
if (document.visibilityState === 'visible') {
|
||||
storageCache.clear()
|
||||
}
|
||||
})
|
||||
```
|
||||
@@ -1,32 +0,0 @@
|
||||
---
|
||||
title: Combine Multiple Array Iterations
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: reduces iterations
|
||||
tags: javascript, arrays, loops, performance
|
||||
---
|
||||
|
||||
## Combine Multiple Array Iterations
|
||||
|
||||
Multiple `.filter()` or `.map()` calls iterate the array multiple times. Combine into one loop.
|
||||
|
||||
**Incorrect (3 iterations):**
|
||||
|
||||
```typescript
|
||||
const admins = users.filter(u => u.isAdmin)
|
||||
const testers = users.filter(u => u.isTester)
|
||||
const inactive = users.filter(u => !u.isActive)
|
||||
```
|
||||
|
||||
**Correct (1 iteration):**
|
||||
|
||||
```typescript
|
||||
const admins: User[] = []
|
||||
const testers: User[] = []
|
||||
const inactive: User[] = []
|
||||
|
||||
for (const user of users) {
|
||||
if (user.isAdmin) admins.push(user)
|
||||
if (user.isTester) testers.push(user)
|
||||
if (!user.isActive) inactive.push(user)
|
||||
}
|
||||
```
|
||||
@@ -1,50 +0,0 @@
|
||||
---
|
||||
title: Early Return from Functions
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: avoids unnecessary computation
|
||||
tags: javascript, functions, optimization, early-return
|
||||
---
|
||||
|
||||
## Early Return from Functions
|
||||
|
||||
Return early when result is determined to skip unnecessary processing.
|
||||
|
||||
**Incorrect (processes all items even after finding answer):**
|
||||
|
||||
```typescript
|
||||
function validateUsers(users: User[]) {
|
||||
let hasError = false
|
||||
let errorMessage = ''
|
||||
|
||||
for (const user of users) {
|
||||
if (!user.email) {
|
||||
hasError = true
|
||||
errorMessage = 'Email required'
|
||||
}
|
||||
if (!user.name) {
|
||||
hasError = true
|
||||
errorMessage = 'Name required'
|
||||
}
|
||||
// Continues checking all users even after error found
|
||||
}
|
||||
|
||||
return hasError ? { valid: false, error: errorMessage } : { valid: true }
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (returns immediately on first error):**
|
||||
|
||||
```typescript
|
||||
function validateUsers(users: User[]) {
|
||||
for (const user of users) {
|
||||
if (!user.email) {
|
||||
return { valid: false, error: 'Email required' }
|
||||
}
|
||||
if (!user.name) {
|
||||
return { valid: false, error: 'Name required' }
|
||||
}
|
||||
}
|
||||
|
||||
return { valid: true }
|
||||
}
|
||||
```
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
title: Hoist RegExp Creation
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: avoids recreation
|
||||
tags: javascript, regexp, optimization, memoization
|
||||
---
|
||||
|
||||
## Hoist RegExp Creation
|
||||
|
||||
Don't create RegExp inside render. Hoist to module scope or memoize with `useMemo()`.
|
||||
|
||||
**Incorrect (new RegExp every render):**
|
||||
|
||||
```tsx
|
||||
function Highlighter({ text, query }: Props) {
|
||||
const regex = new RegExp(`(${query})`, 'gi')
|
||||
const parts = text.split(regex)
|
||||
return <>{parts.map((part, i) => ...)}</>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (memoize or hoist):**
|
||||
|
||||
```tsx
|
||||
const EMAIL_REGEX = /^[^\s@]+@[^\s@]+\.[^\s@]+$/
|
||||
|
||||
function Highlighter({ text, query }: Props) {
|
||||
const regex = useMemo(
|
||||
() => new RegExp(`(${escapeRegex(query)})`, 'gi'),
|
||||
[query]
|
||||
)
|
||||
const parts = text.split(regex)
|
||||
return <>{parts.map((part, i) => ...)}</>
|
||||
}
|
||||
```
|
||||
|
||||
**Warning (global regex has mutable state):**
|
||||
|
||||
Global regex (`/g`) has mutable `lastIndex` state:
|
||||
|
||||
```typescript
|
||||
const regex = /foo/g
|
||||
regex.test('foo') // true, lastIndex = 3
|
||||
regex.test('foo') // false, lastIndex = 0
|
||||
```
|
||||
@@ -1,37 +0,0 @@
|
||||
---
|
||||
title: Build Index Maps for Repeated Lookups
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: 1M ops to 2K ops
|
||||
tags: javascript, map, indexing, optimization, performance
|
||||
---
|
||||
|
||||
## Build Index Maps for Repeated Lookups
|
||||
|
||||
Multiple `.find()` calls by the same key should use a Map.
|
||||
|
||||
**Incorrect (O(n) per lookup):**
|
||||
|
||||
```typescript
|
||||
function processOrders(orders: Order[], users: User[]) {
|
||||
return orders.map(order => ({
|
||||
...order,
|
||||
user: users.find(u => u.id === order.userId)
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (O(1) per lookup):**
|
||||
|
||||
```typescript
|
||||
function processOrders(orders: Order[], users: User[]) {
|
||||
const userById = new Map(users.map(u => [u.id, u]))
|
||||
|
||||
return orders.map(order => ({
|
||||
...order,
|
||||
user: userById.get(order.userId)
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
Build map once (O(n)), then all lookups are O(1).
|
||||
For 1000 orders × 1000 users: 1M ops → 2K ops.
|
||||
@@ -1,49 +0,0 @@
|
||||
---
|
||||
title: Early Length Check for Array Comparisons
|
||||
impact: MEDIUM-HIGH
|
||||
impactDescription: avoids expensive operations when lengths differ
|
||||
tags: javascript, arrays, performance, optimization, comparison
|
||||
---
|
||||
|
||||
## Early Length Check for Array Comparisons
|
||||
|
||||
When comparing arrays with expensive operations (sorting, deep equality, serialization), check lengths first. If lengths differ, the arrays cannot be equal.
|
||||
|
||||
In real-world applications, this optimization is especially valuable when the comparison runs in hot paths (event handlers, render loops).
|
||||
|
||||
**Incorrect (always runs expensive comparison):**
|
||||
|
||||
```typescript
|
||||
function hasChanges(current: string[], original: string[]) {
|
||||
// Always sorts and joins, even when lengths differ
|
||||
return current.sort().join() !== original.sort().join()
|
||||
}
|
||||
```
|
||||
|
||||
Two O(n log n) sorts run even when `current.length` is 5 and `original.length` is 100. There is also overhead of joining the arrays and comparing the strings.
|
||||
|
||||
**Correct (O(1) length check first):**
|
||||
|
||||
```typescript
|
||||
function hasChanges(current: string[], original: string[]) {
|
||||
// Early return if lengths differ
|
||||
if (current.length !== original.length) {
|
||||
return true
|
||||
}
|
||||
// Only sort/join when lengths match
|
||||
const currentSorted = current.toSorted()
|
||||
const originalSorted = original.toSorted()
|
||||
for (let i = 0; i < currentSorted.length; i++) {
|
||||
if (currentSorted[i] !== originalSorted[i]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
```
|
||||
|
||||
This new approach is more efficient because:
|
||||
- It avoids the overhead of sorting and joining the arrays when lengths differ
|
||||
- It avoids consuming memory for the joined strings (especially important for large arrays)
|
||||
- It avoids mutating the original arrays
|
||||
- It returns early when a difference is found
|
||||
@@ -1,82 +0,0 @@
|
||||
---
|
||||
title: Use Loop for Min/Max Instead of Sort
|
||||
impact: LOW
|
||||
impactDescription: O(n) instead of O(n log n)
|
||||
tags: javascript, arrays, performance, sorting, algorithms
|
||||
---
|
||||
|
||||
## Use Loop for Min/Max Instead of Sort
|
||||
|
||||
Finding the smallest or largest element only requires a single pass through the array. Sorting is wasteful and slower.
|
||||
|
||||
**Incorrect (O(n log n) - sort to find latest):**
|
||||
|
||||
```typescript
|
||||
interface Project {
|
||||
id: string
|
||||
name: string
|
||||
updatedAt: number
|
||||
}
|
||||
|
||||
function getLatestProject(projects: Project[]) {
|
||||
const sorted = [...projects].sort((a, b) => b.updatedAt - a.updatedAt)
|
||||
return sorted[0]
|
||||
}
|
||||
```
|
||||
|
||||
Sorts the entire array just to find the maximum value.
|
||||
|
||||
**Incorrect (O(n log n) - sort for oldest and newest):**
|
||||
|
||||
```typescript
|
||||
function getOldestAndNewest(projects: Project[]) {
|
||||
const sorted = [...projects].sort((a, b) => a.updatedAt - b.updatedAt)
|
||||
return { oldest: sorted[0], newest: sorted[sorted.length - 1] }
|
||||
}
|
||||
```
|
||||
|
||||
Still sorts unnecessarily when only min/max are needed.
|
||||
|
||||
**Correct (O(n) - single loop):**
|
||||
|
||||
```typescript
|
||||
function getLatestProject(projects: Project[]) {
|
||||
if (projects.length === 0) return null
|
||||
|
||||
let latest = projects[0]
|
||||
|
||||
for (let i = 1; i < projects.length; i++) {
|
||||
if (projects[i].updatedAt > latest.updatedAt) {
|
||||
latest = projects[i]
|
||||
}
|
||||
}
|
||||
|
||||
return latest
|
||||
}
|
||||
|
||||
function getOldestAndNewest(projects: Project[]) {
|
||||
if (projects.length === 0) return { oldest: null, newest: null }
|
||||
|
||||
let oldest = projects[0]
|
||||
let newest = projects[0]
|
||||
|
||||
for (let i = 1; i < projects.length; i++) {
|
||||
if (projects[i].updatedAt < oldest.updatedAt) oldest = projects[i]
|
||||
if (projects[i].updatedAt > newest.updatedAt) newest = projects[i]
|
||||
}
|
||||
|
||||
return { oldest, newest }
|
||||
}
|
||||
```
|
||||
|
||||
Single pass through the array, no copying, no sorting.
|
||||
|
||||
**Alternative (Math.min/Math.max for small arrays):**
|
||||
|
||||
```typescript
|
||||
const numbers = [5, 2, 8, 1, 9]
|
||||
const min = Math.min(...numbers)
|
||||
const max = Math.max(...numbers)
|
||||
```
|
||||
|
||||
This works for small arrays but can be slower for very large arrays due to spread operator limitations. Use the loop approach for reliability.
|
||||
@@ -1,24 +0,0 @@
|
||||
---
|
||||
title: Use Set/Map for O(1) Lookups
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: O(n) to O(1)
|
||||
tags: javascript, set, map, data-structures, performance
|
||||
---
|
||||
|
||||
## Use Set/Map for O(1) Lookups
|
||||
|
||||
Convert arrays to Set/Map for repeated membership checks.
|
||||
|
||||
**Incorrect (O(n) per check):**
|
||||
|
||||
```typescript
|
||||
const allowedIds = ['a', 'b', 'c', ...]
|
||||
items.filter(item => allowedIds.includes(item.id))
|
||||
```
|
||||
|
||||
**Correct (O(1) per check):**
|
||||
|
||||
```typescript
|
||||
const allowedIds = new Set(['a', 'b', 'c', ...])
|
||||
items.filter(item => allowedIds.has(item.id))
|
||||
```
|
||||
@@ -1,57 +0,0 @@
|
||||
---
|
||||
title: Use toSorted() Instead of sort() for Immutability
|
||||
impact: MEDIUM-HIGH
|
||||
impactDescription: prevents mutation bugs in React state
|
||||
tags: javascript, arrays, immutability, react, state, mutation
|
||||
---
|
||||
|
||||
## Use toSorted() Instead of sort() for Immutability
|
||||
|
||||
`.sort()` mutates the array in place, which can cause bugs with React state and props. Use `.toSorted()` to create a new sorted array without mutation.
|
||||
|
||||
**Incorrect (mutates original array):**
|
||||
|
||||
```typescript
|
||||
function UserList({ users }: { users: User[] }) {
|
||||
// Mutates the users prop array!
|
||||
const sorted = useMemo(
|
||||
() => users.sort((a, b) => a.name.localeCompare(b.name)),
|
||||
[users]
|
||||
)
|
||||
return <div>{sorted.map(renderUser)}</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (creates new array):**
|
||||
|
||||
```typescript
|
||||
function UserList({ users }: { users: User[] }) {
|
||||
// Creates new sorted array, original unchanged
|
||||
const sorted = useMemo(
|
||||
() => users.toSorted((a, b) => a.name.localeCompare(b.name)),
|
||||
[users]
|
||||
)
|
||||
return <div>{sorted.map(renderUser)}</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Why this matters in React:**
|
||||
|
||||
1. Props/state mutations break React's immutability model - React expects props and state to be treated as read-only
|
||||
2. Causes stale closure bugs - Mutating arrays inside closures (callbacks, effects) can lead to unexpected behavior
|
||||
|
||||
**Browser support (fallback for older browsers):**
|
||||
|
||||
`.toSorted()` is available in all modern browsers (Chrome 110+, Safari 16+, Firefox 115+, Node.js 20+). For older environments, use spread operator:
|
||||
|
||||
```typescript
|
||||
// Fallback for older browsers
|
||||
const sorted = [...items].sort((a, b) => a.value - b.value)
|
||||
```
|
||||
|
||||
**Other immutable array methods:**
|
||||
|
||||
- `.toSorted()` - immutable sort
|
||||
- `.toReversed()` - immutable reverse
|
||||
- `.toSpliced()` - immutable splice
|
||||
- `.with()` - immutable element replacement
|
||||
@@ -1,26 +0,0 @@
|
||||
---
|
||||
title: Use Activity Component for Show/Hide
|
||||
impact: MEDIUM
|
||||
impactDescription: preserves state/DOM
|
||||
tags: rendering, activity, visibility, state-preservation
|
||||
---
|
||||
|
||||
## Use Activity Component for Show/Hide
|
||||
|
||||
Use React's `<Activity>` to preserve state/DOM for expensive components that frequently toggle visibility.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```tsx
|
||||
import { Activity } from 'react'
|
||||
|
||||
function Dropdown({ isOpen }: Props) {
|
||||
return (
|
||||
<Activity mode={isOpen ? 'visible' : 'hidden'}>
|
||||
<ExpensiveMenu />
|
||||
</Activity>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Avoids expensive re-renders and state loss.
|
||||
@@ -1,47 +0,0 @@
|
||||
---
|
||||
title: Animate SVG Wrapper Instead of SVG Element
|
||||
impact: LOW
|
||||
impactDescription: enables hardware acceleration
|
||||
tags: rendering, svg, css, animation, performance
|
||||
---
|
||||
|
||||
## Animate SVG Wrapper Instead of SVG Element
|
||||
|
||||
Many browsers don't have hardware acceleration for CSS3 animations on SVG elements. Wrap SVG in a `<div>` and animate the wrapper instead.
|
||||
|
||||
**Incorrect (animating SVG directly - no hardware acceleration):**
|
||||
|
||||
```tsx
|
||||
function LoadingSpinner() {
|
||||
return (
|
||||
<svg
|
||||
className="animate-spin"
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<circle cx="12" cy="12" r="10" stroke="currentColor" />
|
||||
</svg>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (animating wrapper div - hardware accelerated):**
|
||||
|
||||
```tsx
|
||||
function LoadingSpinner() {
|
||||
return (
|
||||
<div className="animate-spin">
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<circle cx="12" cy="12" r="10" stroke="currentColor" />
|
||||
</svg>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
This applies to all CSS transforms and transitions (`transform`, `opacity`, `translate`, `scale`, `rotate`). The wrapper div allows browsers to use GPU acceleration for smoother animations.
|
||||
@@ -1,40 +0,0 @@
|
||||
---
|
||||
title: Use Explicit Conditional Rendering
|
||||
impact: LOW
|
||||
impactDescription: prevents rendering 0 or NaN
|
||||
tags: rendering, conditional, jsx, falsy-values
|
||||
---
|
||||
|
||||
## Use Explicit Conditional Rendering
|
||||
|
||||
Use explicit ternary operators (`? :`) instead of `&&` for conditional rendering when the condition can be `0`, `NaN`, or other falsy values that render.
|
||||
|
||||
**Incorrect (renders "0" when count is 0):**
|
||||
|
||||
```tsx
|
||||
function Badge({ count }: { count: number }) {
|
||||
return (
|
||||
<div>
|
||||
{count && <span className="badge">{count}</span>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// When count = 0, renders: <div>0</div>
|
||||
// When count = 5, renders: <div><span class="badge">5</span></div>
|
||||
```
|
||||
|
||||
**Correct (renders nothing when count is 0):**
|
||||
|
||||
```tsx
|
||||
function Badge({ count }: { count: number }) {
|
||||
return (
|
||||
<div>
|
||||
{count > 0 ? <span className="badge">{count}</span> : null}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// When count = 0, renders: <div></div>
|
||||
// When count = 5, renders: <div><span class="badge">5</span></div>
|
||||
```
|
||||
@@ -1,38 +0,0 @@
|
||||
---
|
||||
title: CSS content-visibility for Long Lists
|
||||
impact: HIGH
|
||||
impactDescription: faster initial render
|
||||
tags: rendering, css, content-visibility, long-lists
|
||||
---
|
||||
|
||||
## CSS content-visibility for Long Lists
|
||||
|
||||
Apply `content-visibility: auto` to defer off-screen rendering.
|
||||
|
||||
**CSS:**
|
||||
|
||||
```css
|
||||
.message-item {
|
||||
content-visibility: auto;
|
||||
contain-intrinsic-size: 0 80px;
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```tsx
|
||||
function MessageList({ messages }: { messages: Message[] }) {
|
||||
return (
|
||||
<div className="overflow-y-auto h-screen">
|
||||
{messages.map(msg => (
|
||||
<div key={msg.id} className="message-item">
|
||||
<Avatar user={msg.author} />
|
||||
<div>{msg.content}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
For 1000 messages, browser skips layout/paint for ~990 off-screen items (10× faster initial render).
|
||||
@@ -1,46 +0,0 @@
|
||||
---
|
||||
title: Hoist Static JSX Elements
|
||||
impact: LOW
|
||||
impactDescription: avoids re-creation
|
||||
tags: rendering, jsx, static, optimization
|
||||
---
|
||||
|
||||
## Hoist Static JSX Elements
|
||||
|
||||
Extract static JSX outside components to avoid re-creation.
|
||||
|
||||
**Incorrect (recreates element every render):**
|
||||
|
||||
```tsx
|
||||
function LoadingSkeleton() {
|
||||
return <div className="animate-pulse h-20 bg-gray-200" />
|
||||
}
|
||||
|
||||
function Container() {
|
||||
return (
|
||||
<div>
|
||||
{loading && <LoadingSkeleton />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (reuses same element):**
|
||||
|
||||
```tsx
|
||||
const loadingSkeleton = (
|
||||
<div className="animate-pulse h-20 bg-gray-200" />
|
||||
)
|
||||
|
||||
function Container() {
|
||||
return (
|
||||
<div>
|
||||
{loading && loadingSkeleton}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
This is especially helpful for large and static SVG nodes, which can be expensive to recreate on every render.
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler automatically hoists static JSX elements and optimizes component re-renders, making manual hoisting unnecessary.
|
||||
@@ -1,82 +0,0 @@
|
||||
---
|
||||
title: Prevent Hydration Mismatch Without Flickering
|
||||
impact: MEDIUM
|
||||
impactDescription: avoids visual flicker and hydration errors
|
||||
tags: rendering, ssr, hydration, localStorage, flicker
|
||||
---
|
||||
|
||||
## Prevent Hydration Mismatch Without Flickering
|
||||
|
||||
When rendering content that depends on client-side storage (localStorage, cookies), avoid both SSR breakage and post-hydration flickering by injecting a synchronous script that updates the DOM before React hydrates.
|
||||
|
||||
**Incorrect (breaks SSR):**
|
||||
|
||||
```tsx
|
||||
function ThemeWrapper({ children }: { children: ReactNode }) {
|
||||
// localStorage is not available on server - throws error
|
||||
const theme = localStorage.getItem('theme') || 'light'
|
||||
|
||||
return (
|
||||
<div className={theme}>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Server-side rendering will fail because `localStorage` is undefined.
|
||||
|
||||
**Incorrect (visual flickering):**
|
||||
|
||||
```tsx
|
||||
function ThemeWrapper({ children }: { children: ReactNode }) {
|
||||
const [theme, setTheme] = useState('light')
|
||||
|
||||
useEffect(() => {
|
||||
// Runs after hydration - causes visible flash
|
||||
const stored = localStorage.getItem('theme')
|
||||
if (stored) {
|
||||
setTheme(stored)
|
||||
}
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div className={theme}>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Component first renders with default value (`light`), then updates after hydration, causing a visible flash of incorrect content.
|
||||
|
||||
**Correct (no flicker, no hydration mismatch):**
|
||||
|
||||
```tsx
|
||||
function ThemeWrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<>
|
||||
<div id="theme-wrapper">
|
||||
{children}
|
||||
</div>
|
||||
<script
|
||||
dangerouslySetInnerHTML={{
|
||||
__html: `
|
||||
(function() {
|
||||
try {
|
||||
var theme = localStorage.getItem('theme') || 'light';
|
||||
var el = document.getElementById('theme-wrapper');
|
||||
if (el) el.className = theme;
|
||||
} catch (e) {}
|
||||
})();
|
||||
`,
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
The inline script executes synchronously before showing the element, ensuring the DOM already has the correct value. No flickering, no hydration mismatch.
|
||||
|
||||
This pattern is especially useful for theme toggles, user preferences, authentication states, and any client-only data that should render immediately without flashing default values.
|
||||
@@ -1,28 +0,0 @@
|
||||
---
|
||||
title: Optimize SVG Precision
|
||||
impact: LOW
|
||||
impactDescription: reduces file size
|
||||
tags: rendering, svg, optimization, svgo
|
||||
---
|
||||
|
||||
## Optimize SVG Precision
|
||||
|
||||
Reduce SVG coordinate precision to decrease file size. The optimal precision depends on the viewBox size, but in general reducing precision should be considered.
|
||||
|
||||
**Incorrect (excessive precision):**
|
||||
|
||||
```svg
|
||||
<path d="M 10.293847 20.847362 L 30.938472 40.192837" />
|
||||
```
|
||||
|
||||
**Correct (1 decimal place):**
|
||||
|
||||
```svg
|
||||
<path d="M 10.3 20.8 L 30.9 40.2" />
|
||||
```
|
||||
|
||||
**Automate with SVGO:**
|
||||
|
||||
```bash
|
||||
npx svgo --precision=1 --multipass icon.svg
|
||||
```
|
||||
@@ -1,39 +0,0 @@
|
||||
---
|
||||
title: Defer State Reads to Usage Point
|
||||
impact: MEDIUM
|
||||
impactDescription: avoids unnecessary subscriptions
|
||||
tags: rerender, searchParams, localStorage, optimization
|
||||
---
|
||||
|
||||
## Defer State Reads to Usage Point
|
||||
|
||||
Don't subscribe to dynamic state (searchParams, localStorage) if you only read it inside callbacks.
|
||||
|
||||
**Incorrect (subscribes to all searchParams changes):**
|
||||
|
||||
```tsx
|
||||
function ShareButton({ chatId }: { chatId: string }) {
|
||||
const searchParams = useSearchParams()
|
||||
|
||||
const handleShare = () => {
|
||||
const ref = searchParams.get('ref')
|
||||
shareChat(chatId, { ref })
|
||||
}
|
||||
|
||||
return <button onClick={handleShare}>Share</button>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (reads on demand, no subscription):**
|
||||
|
||||
```tsx
|
||||
function ShareButton({ chatId }: { chatId: string }) {
|
||||
const handleShare = () => {
|
||||
const params = new URLSearchParams(window.location.search)
|
||||
const ref = params.get('ref')
|
||||
shareChat(chatId, { ref })
|
||||
}
|
||||
|
||||
return <button onClick={handleShare}>Share</button>
|
||||
}
|
||||
```
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
title: Narrow Effect Dependencies
|
||||
impact: LOW
|
||||
impactDescription: minimizes effect re-runs
|
||||
tags: rerender, useEffect, dependencies, optimization
|
||||
---
|
||||
|
||||
## Narrow Effect Dependencies
|
||||
|
||||
Specify primitive dependencies instead of objects to minimize effect re-runs.
|
||||
|
||||
**Incorrect (re-runs on any user field change):**
|
||||
|
||||
```tsx
|
||||
useEffect(() => {
|
||||
console.log(user.id)
|
||||
}, [user])
|
||||
```
|
||||
|
||||
**Correct (re-runs only when id changes):**
|
||||
|
||||
```tsx
|
||||
useEffect(() => {
|
||||
console.log(user.id)
|
||||
}, [user.id])
|
||||
```
|
||||
|
||||
**For derived state, compute outside effect:**
|
||||
|
||||
```tsx
|
||||
// Incorrect: runs on width=767, 766, 765...
|
||||
useEffect(() => {
|
||||
if (width < 768) {
|
||||
enableMobileMode()
|
||||
}
|
||||
}, [width])
|
||||
|
||||
// Correct: runs only on boolean transition
|
||||
const isMobile = width < 768
|
||||
useEffect(() => {
|
||||
if (isMobile) {
|
||||
enableMobileMode()
|
||||
}
|
||||
}, [isMobile])
|
||||
```
|
||||
@@ -1,29 +0,0 @@
|
||||
---
|
||||
title: Subscribe to Derived State
|
||||
impact: MEDIUM
|
||||
impactDescription: reduces re-render frequency
|
||||
tags: rerender, derived-state, media-query, optimization
|
||||
---
|
||||
|
||||
## Subscribe to Derived State
|
||||
|
||||
Subscribe to derived boolean state instead of continuous values to reduce re-render frequency.
|
||||
|
||||
**Incorrect (re-renders on every pixel change):**
|
||||
|
||||
```tsx
|
||||
function Sidebar() {
|
||||
const width = useWindowWidth() // updates continuously
|
||||
const isMobile = width < 768
|
||||
return <nav className={isMobile ? 'mobile' : 'desktop'}>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (re-renders only when boolean changes):**
|
||||
|
||||
```tsx
|
||||
function Sidebar() {
|
||||
const isMobile = useMediaQuery('(max-width: 767px)')
|
||||
return <nav className={isMobile ? 'mobile' : 'desktop'}>
|
||||
}
|
||||
```
|
||||
@@ -1,74 +0,0 @@
|
||||
---
|
||||
title: Use Functional setState Updates
|
||||
impact: MEDIUM
|
||||
impactDescription: prevents stale closures and unnecessary callback recreations
|
||||
tags: react, hooks, useState, useCallback, callbacks, closures
|
||||
---
|
||||
|
||||
## Use Functional setState Updates
|
||||
|
||||
When updating state based on the current state value, use the functional update form of setState instead of directly referencing the state variable. This prevents stale closures, eliminates unnecessary dependencies, and creates stable callback references.
|
||||
|
||||
**Incorrect (requires state as dependency):**
|
||||
|
||||
```tsx
|
||||
function TodoList() {
|
||||
const [items, setItems] = useState(initialItems)
|
||||
|
||||
// Callback must depend on items, recreated on every items change
|
||||
const addItems = useCallback((newItems: Item[]) => {
|
||||
setItems([...items, ...newItems])
|
||||
}, [items]) // ❌ items dependency causes recreations
|
||||
|
||||
// Risk of stale closure if dependency is forgotten
|
||||
const removeItem = useCallback((id: string) => {
|
||||
setItems(items.filter(item => item.id !== id))
|
||||
}, []) // ❌ Missing items dependency - will use stale items!
|
||||
|
||||
return <ItemsEditor items={items} onAdd={addItems} onRemove={removeItem} />
|
||||
}
|
||||
```
|
||||
|
||||
The first callback is recreated every time `items` changes, which can cause child components to re-render unnecessarily. The second callback has a stale closure bug—it will always reference the initial `items` value.
|
||||
|
||||
**Correct (stable callbacks, no stale closures):**
|
||||
|
||||
```tsx
|
||||
function TodoList() {
|
||||
const [items, setItems] = useState(initialItems)
|
||||
|
||||
// Stable callback, never recreated
|
||||
const addItems = useCallback((newItems: Item[]) => {
|
||||
setItems(curr => [...curr, ...newItems])
|
||||
}, []) // ✅ No dependencies needed
|
||||
|
||||
// Always uses latest state, no stale closure risk
|
||||
const removeItem = useCallback((id: string) => {
|
||||
setItems(curr => curr.filter(item => item.id !== id))
|
||||
}, []) // ✅ Safe and stable
|
||||
|
||||
return <ItemsEditor items={items} onAdd={addItems} onRemove={removeItem} />
|
||||
}
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
|
||||
1. **Stable callback references** - Callbacks don't need to be recreated when state changes
|
||||
2. **No stale closures** - Always operates on the latest state value
|
||||
3. **Fewer dependencies** - Simplifies dependency arrays and reduces memory leaks
|
||||
4. **Prevents bugs** - Eliminates the most common source of React closure bugs
|
||||
|
||||
**When to use functional updates:**
|
||||
|
||||
- Any setState that depends on the current state value
|
||||
- Inside useCallback/useMemo when state is needed
|
||||
- Event handlers that reference state
|
||||
- Async operations that update state
|
||||
|
||||
**When direct updates are fine:**
|
||||
|
||||
- Setting state to a static value: `setCount(0)`
|
||||
- Setting state from props/arguments only: `setName(newName)`
|
||||
- State doesn't depend on previous value
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler can automatically optimize some cases, but functional updates are still recommended for correctness and to prevent stale closure bugs.
|
||||
@@ -1,58 +0,0 @@
|
||||
---
|
||||
title: Use Lazy State Initialization
|
||||
impact: MEDIUM
|
||||
impactDescription: wasted computation on every render
|
||||
tags: react, hooks, useState, performance, initialization
|
||||
---
|
||||
|
||||
## Use Lazy State Initialization
|
||||
|
||||
Pass a function to `useState` for expensive initial values. Without the function form, the initializer runs on every render even though the value is only used once.
|
||||
|
||||
**Incorrect (runs on every render):**
|
||||
|
||||
```tsx
|
||||
function FilteredList({ items }: { items: Item[] }) {
|
||||
// buildSearchIndex() runs on EVERY render, even after initialization
|
||||
const [searchIndex, setSearchIndex] = useState(buildSearchIndex(items))
|
||||
const [query, setQuery] = useState('')
|
||||
|
||||
// When query changes, buildSearchIndex runs again unnecessarily
|
||||
return <SearchResults index={searchIndex} query={query} />
|
||||
}
|
||||
|
||||
function UserProfile() {
|
||||
// JSON.parse runs on every render
|
||||
const [settings, setSettings] = useState(
|
||||
JSON.parse(localStorage.getItem('settings') || '{}')
|
||||
)
|
||||
|
||||
return <SettingsForm settings={settings} onChange={setSettings} />
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (runs only once):**
|
||||
|
||||
```tsx
|
||||
function FilteredList({ items }: { items: Item[] }) {
|
||||
// buildSearchIndex() runs ONLY on initial render
|
||||
const [searchIndex, setSearchIndex] = useState(() => buildSearchIndex(items))
|
||||
const [query, setQuery] = useState('')
|
||||
|
||||
return <SearchResults index={searchIndex} query={query} />
|
||||
}
|
||||
|
||||
function UserProfile() {
|
||||
// JSON.parse runs only on initial render
|
||||
const [settings, setSettings] = useState(() => {
|
||||
const stored = localStorage.getItem('settings')
|
||||
return stored ? JSON.parse(stored) : {}
|
||||
})
|
||||
|
||||
return <SettingsForm settings={settings} onChange={setSettings} />
|
||||
}
|
||||
```
|
||||
|
||||
Use lazy initialization when computing initial values from localStorage/sessionStorage, building data structures (indexes, maps), reading from the DOM, or performing heavy transformations.
|
||||
|
||||
For simple primitives (`useState(0)`), direct references (`useState(props.value)`), or cheap literals (`useState({})`), the function form is unnecessary.
|
||||
@@ -1,44 +0,0 @@
|
||||
---
|
||||
title: Extract to Memoized Components
|
||||
impact: MEDIUM
|
||||
impactDescription: enables early returns
|
||||
tags: rerender, memo, useMemo, optimization
|
||||
---
|
||||
|
||||
## Extract to Memoized Components
|
||||
|
||||
Extract expensive work into memoized components to enable early returns before computation.
|
||||
|
||||
**Incorrect (computes avatar even when loading):**
|
||||
|
||||
```tsx
|
||||
function Profile({ user, loading }: Props) {
|
||||
const avatar = useMemo(() => {
|
||||
const id = computeAvatarId(user)
|
||||
return <Avatar id={id} />
|
||||
}, [user])
|
||||
|
||||
if (loading) return <Skeleton />
|
||||
return <div>{avatar}</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (skips computation when loading):**
|
||||
|
||||
```tsx
|
||||
const UserAvatar = memo(function UserAvatar({ user }: { user: User }) {
|
||||
const id = useMemo(() => computeAvatarId(user), [user])
|
||||
return <Avatar id={id} />
|
||||
})
|
||||
|
||||
function Profile({ user, loading }: Props) {
|
||||
if (loading) return <Skeleton />
|
||||
return (
|
||||
<div>
|
||||
<UserAvatar user={user} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, manual memoization with `memo()` and `useMemo()` is not necessary. The compiler automatically optimizes re-renders.
|
||||
@@ -1,40 +0,0 @@
|
||||
---
|
||||
title: Use Transitions for Non-Urgent Updates
|
||||
impact: MEDIUM
|
||||
impactDescription: maintains UI responsiveness
|
||||
tags: rerender, transitions, startTransition, performance
|
||||
---
|
||||
|
||||
## Use Transitions for Non-Urgent Updates
|
||||
|
||||
Mark frequent, non-urgent state updates as transitions to maintain UI responsiveness.
|
||||
|
||||
**Incorrect (blocks UI on every scroll):**
|
||||
|
||||
```tsx
|
||||
function ScrollTracker() {
|
||||
const [scrollY, setScrollY] = useState(0)
|
||||
useEffect(() => {
|
||||
const handler = () => setScrollY(window.scrollY)
|
||||
window.addEventListener('scroll', handler, { passive: true })
|
||||
return () => window.removeEventListener('scroll', handler)
|
||||
}, [])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (non-blocking updates):**
|
||||
|
||||
```tsx
|
||||
import { startTransition } from 'react'
|
||||
|
||||
function ScrollTracker() {
|
||||
const [scrollY, setScrollY] = useState(0)
|
||||
useEffect(() => {
|
||||
const handler = () => {
|
||||
startTransition(() => setScrollY(window.scrollY))
|
||||
}
|
||||
window.addEventListener('scroll', handler, { passive: true })
|
||||
return () => window.removeEventListener('scroll', handler)
|
||||
}, [])
|
||||
}
|
||||
```
|
||||
@@ -1,73 +0,0 @@
|
||||
---
|
||||
title: Use after() for Non-Blocking Operations
|
||||
impact: MEDIUM
|
||||
impactDescription: faster response times
|
||||
tags: server, async, logging, analytics, side-effects
|
||||
---
|
||||
|
||||
## Use after() for Non-Blocking Operations
|
||||
|
||||
Use Next.js's `after()` to schedule work that should execute after a response is sent. This prevents logging, analytics, and other side effects from blocking the response.
|
||||
|
||||
**Incorrect (blocks response):**
|
||||
|
||||
```tsx
|
||||
import { logUserAction } from '@/app/utils'
|
||||
|
||||
export async function POST(request: Request) {
|
||||
// Perform mutation
|
||||
await updateDatabase(request)
|
||||
|
||||
// Logging blocks the response
|
||||
const userAgent = request.headers.get('user-agent') || 'unknown'
|
||||
await logUserAction({ userAgent })
|
||||
|
||||
return new Response(JSON.stringify({ status: 'success' }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (non-blocking):**
|
||||
|
||||
```tsx
|
||||
import { after } from 'next/server'
|
||||
import { headers, cookies } from 'next/headers'
|
||||
import { logUserAction } from '@/app/utils'
|
||||
|
||||
export async function POST(request: Request) {
|
||||
// Perform mutation
|
||||
await updateDatabase(request)
|
||||
|
||||
// Log after response is sent
|
||||
after(async () => {
|
||||
const userAgent = (await headers()).get('user-agent') || 'unknown'
|
||||
const sessionCookie = (await cookies()).get('session-id')?.value || 'anonymous'
|
||||
|
||||
logUserAction({ sessionCookie, userAgent })
|
||||
})
|
||||
|
||||
return new Response(JSON.stringify({ status: 'success' }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
The response is sent immediately while logging happens in the background.
|
||||
|
||||
**Common use cases:**
|
||||
|
||||
- Analytics tracking
|
||||
- Audit logging
|
||||
- Sending notifications
|
||||
- Cache invalidation
|
||||
- Cleanup tasks
|
||||
|
||||
**Important notes:**
|
||||
|
||||
- `after()` runs even if the response fails or redirects
|
||||
- Works in Server Actions, Route Handlers, and Server Components
|
||||
|
||||
Reference: [https://nextjs.org/docs/app/api-reference/functions/after](https://nextjs.org/docs/app/api-reference/functions/after)
|
||||
@@ -1,41 +0,0 @@
|
||||
---
|
||||
title: Cross-Request LRU Caching
|
||||
impact: HIGH
|
||||
impactDescription: caches across requests
|
||||
tags: server, cache, lru, cross-request
|
||||
---
|
||||
|
||||
## Cross-Request LRU Caching
|
||||
|
||||
`React.cache()` only works within one request. For data shared across sequential requests (user clicks button A then button B), use an LRU cache.
|
||||
|
||||
**Implementation:**
|
||||
|
||||
```typescript
|
||||
import { LRUCache } from 'lru-cache'
|
||||
|
||||
const cache = new LRUCache<string, any>({
|
||||
max: 1000,
|
||||
ttl: 5 * 60 * 1000 // 5 minutes
|
||||
})
|
||||
|
||||
export async function getUser(id: string) {
|
||||
const cached = cache.get(id)
|
||||
if (cached) return cached
|
||||
|
||||
const user = await db.user.findUnique({ where: { id } })
|
||||
cache.set(id, user)
|
||||
return user
|
||||
}
|
||||
|
||||
// Request 1: DB query, result cached
|
||||
// Request 2: cache hit, no DB query
|
||||
```
|
||||
|
||||
Use when sequential user actions hit multiple endpoints needing the same data within seconds.
|
||||
|
||||
**With Vercel's [Fluid Compute](https://vercel.com/docs/fluid-compute):** LRU caching is especially effective because multiple concurrent requests can share the same function instance and cache. This means the cache persists across requests without needing external storage like Redis.
|
||||
|
||||
**In traditional serverless:** Each invocation runs in isolation, so consider Redis for cross-process caching.
|
||||
|
||||
Reference: [https://github.com/isaacs/node-lru-cache](https://github.com/isaacs/node-lru-cache)
|
||||
@@ -1,26 +0,0 @@
|
||||
---
|
||||
title: Per-Request Deduplication with React.cache()
|
||||
impact: MEDIUM
|
||||
impactDescription: deduplicates within request
|
||||
tags: server, cache, react-cache, deduplication
|
||||
---
|
||||
|
||||
## Per-Request Deduplication with React.cache()
|
||||
|
||||
Use `React.cache()` for server-side request deduplication. Authentication and database queries benefit most.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```typescript
|
||||
import { cache } from 'react'
|
||||
|
||||
export const getCurrentUser = cache(async () => {
|
||||
const session = await auth()
|
||||
if (!session?.user?.id) return null
|
||||
return await db.user.findUnique({
|
||||
where: { id: session.user.id }
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
Within a single request, multiple calls to `getCurrentUser()` execute the query only once.
|
||||
@@ -1,79 +0,0 @@
|
||||
---
|
||||
title: Parallel Data Fetching with Component Composition
|
||||
impact: CRITICAL
|
||||
impactDescription: eliminates server-side waterfalls
|
||||
tags: server, rsc, parallel-fetching, composition
|
||||
---
|
||||
|
||||
## Parallel Data Fetching with Component Composition
|
||||
|
||||
React Server Components execute sequentially within a tree. Restructure with composition to parallelize data fetching.
|
||||
|
||||
**Incorrect (Sidebar waits for Page's fetch to complete):**
|
||||
|
||||
```tsx
|
||||
export default async function Page() {
|
||||
const header = await fetchHeader()
|
||||
return (
|
||||
<div>
|
||||
<div>{header}</div>
|
||||
<Sidebar />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
async function Sidebar() {
|
||||
const items = await fetchSidebarItems()
|
||||
return <nav>{items.map(renderItem)}</nav>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (both fetch simultaneously):**
|
||||
|
||||
```tsx
|
||||
async function Header() {
|
||||
const data = await fetchHeader()
|
||||
return <div>{data}</div>
|
||||
}
|
||||
|
||||
async function Sidebar() {
|
||||
const items = await fetchSidebarItems()
|
||||
return <nav>{items.map(renderItem)}</nav>
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<div>
|
||||
<Header />
|
||||
<Sidebar />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Alternative with children prop:**
|
||||
|
||||
```tsx
|
||||
async function Layout({ children }: { children: ReactNode }) {
|
||||
const header = await fetchHeader()
|
||||
return (
|
||||
<div>
|
||||
<div>{header}</div>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
async function Sidebar() {
|
||||
const items = await fetchSidebarItems()
|
||||
return <nav>{items.map(renderItem)}</nav>
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<Layout>
|
||||
<Sidebar />
|
||||
</Layout>
|
||||
)
|
||||
}
|
||||
```
|
||||
@@ -1,38 +0,0 @@
|
||||
---
|
||||
title: Minimize Serialization at RSC Boundaries
|
||||
impact: HIGH
|
||||
impactDescription: reduces data transfer size
|
||||
tags: server, rsc, serialization, props
|
||||
---
|
||||
|
||||
## Minimize Serialization at RSC Boundaries
|
||||
|
||||
The React Server/Client boundary serializes all object properties into strings and embeds them in the HTML response and subsequent RSC requests. This serialized data directly impacts page weight and load time, so **size matters a lot**. Only pass fields that the client actually uses.
|
||||
|
||||
**Incorrect (serializes all 50 fields):**
|
||||
|
||||
```tsx
|
||||
async function Page() {
|
||||
const user = await fetchUser() // 50 fields
|
||||
return <Profile user={user} />
|
||||
}
|
||||
|
||||
'use client'
|
||||
function Profile({ user }: { user: User }) {
|
||||
return <div>{user.name}</div> // uses 1 field
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (serializes only 1 field):**
|
||||
|
||||
```tsx
|
||||
async function Page() {
|
||||
const user = await fetchUser()
|
||||
return <Profile name={user.name} />
|
||||
}
|
||||
|
||||
'use client'
|
||||
function Profile({ name }: { name: string }) {
|
||||
return <div>{name}</div>
|
||||
}
|
||||
```
|
||||
@@ -1,9 +1,6 @@
|
||||
# Ignore everything by default, selectively add things to context
|
||||
*
|
||||
|
||||
# Documentation (for embeddings/search)
|
||||
!docs/
|
||||
|
||||
# Platform - Libs
|
||||
!autogpt_platform/autogpt_libs/autogpt_libs/
|
||||
!autogpt_platform/autogpt_libs/pyproject.toml
|
||||
@@ -19,7 +16,6 @@
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
|
||||
6
.github/copilot-instructions.md
vendored
6
.github/copilot-instructions.md
vendored
@@ -160,7 +160,7 @@ pnpm storybook # Start component development server
|
||||
|
||||
**Backend Entry Points:**
|
||||
|
||||
- `backend/backend/api/rest_api.py` - FastAPI application setup
|
||||
- `backend/backend/server/server.py` - FastAPI application setup
|
||||
- `backend/backend/data/` - Database models and user management
|
||||
- `backend/blocks/` - Agent execution blocks and logic
|
||||
|
||||
@@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s
|
||||
|
||||
### API Development
|
||||
|
||||
1. Update routes in `/backend/backend/api/features/`
|
||||
1. Update routes in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside route files
|
||||
4. For `data/*.py` changes, validate user ID checks
|
||||
@@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s
|
||||
|
||||
### Security Guidelines
|
||||
|
||||
**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`):
|
||||
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
|
||||
|
||||
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
|
||||
|
||||
2
.github/workflows/classic-frontend-ci.yml
vendored
2
.github/workflows/classic-frontend-ci.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v8
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
|
||||
- name: Get CI failure details
|
||||
id: failure_details
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const run = await github.rest.actions.getWorkflowRun({
|
||||
@@ -93,5 +93,5 @@ jobs:
|
||||
|
||||
Error logs:
|
||||
${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"
|
||||
|
||||
15
.github/workflows/claude-dependabot.yml
vendored
15
.github/workflows/claude-dependabot.yml
vendored
@@ -7,7 +7,7 @@
|
||||
# - Provide actionable recommendations for the development team
|
||||
#
|
||||
# Triggered on: Dependabot PRs (opened, synchronize)
|
||||
# Requirements: CLAUDE_CODE_OAUTH_TOKEN secret must be configured
|
||||
# Requirements: ANTHROPIC_API_KEY secret must be configured
|
||||
|
||||
name: Claude Dependabot PR Review
|
||||
|
||||
@@ -41,7 +41,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -74,11 +74,11 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -91,7 +91,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -124,7 +124,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
@@ -308,8 +308,7 @@ jobs:
|
||||
id: claude_review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
allowed_bots: "dependabot[bot]"
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
|
||||
12
.github/workflows/claude.yml
vendored
12
.github/workflows/claude.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -90,11 +90,11 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -107,7 +107,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -140,7 +140,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
@@ -323,7 +323,7 @@ jobs:
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
|
||||
--model opus
|
||||
|
||||
20
.github/workflows/copilot-setup-steps.yml
vendored
20
.github/workflows/copilot-setup-steps.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -72,11 +72,11 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
@@ -89,7 +89,7 @@ jobs:
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
@@ -108,16 +108,6 @@ jobs:
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
# Remove large unused tools to free disk space for Docker builds
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -132,7 +122,7 @@ jobs:
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
|
||||
78
.github/workflows/docs-block-sync.yml
vendored
78
.github/workflows/docs-block-sync.yml
vendored
@@ -1,78 +0,0 @@
|
||||
name: Block Documentation Sync Check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- "autogpt_platform/backend/backend/blocks/**"
|
||||
- "docs/integrations/**"
|
||||
- "autogpt_platform/backend/scripts/generate_block_docs.py"
|
||||
- ".github/workflows/docs-block-sync.yml"
|
||||
pull_request:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- "autogpt_platform/backend/backend/blocks/**"
|
||||
- "docs/integrations/**"
|
||||
- "autogpt_platform/backend/scripts/generate_block_docs.py"
|
||||
- ".github/workflows/docs-block-sync.yml"
|
||||
|
||||
jobs:
|
||||
check-docs-sync:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
poetry-${{ runner.os }}-
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry install --only main
|
||||
poetry run prisma generate
|
||||
|
||||
- name: Check block documentation is in sync
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
echo "Checking if block documentation is in sync with code..."
|
||||
poetry run python scripts/generate_block_docs.py --check
|
||||
|
||||
- name: Show diff if out of sync
|
||||
if: failure()
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
echo "::error::Block documentation is out of sync with code!"
|
||||
echo ""
|
||||
echo "To fix this, run the following command locally:"
|
||||
echo " cd autogpt_platform/backend && poetry run python scripts/generate_block_docs.py"
|
||||
echo ""
|
||||
echo "Then commit the updated documentation files."
|
||||
echo ""
|
||||
echo "Regenerating docs to show diff..."
|
||||
poetry run python scripts/generate_block_docs.py
|
||||
echo ""
|
||||
echo "Changes detected:"
|
||||
git diff ../../docs/integrations/ || true
|
||||
95
.github/workflows/docs-claude-review.yml
vendored
95
.github/workflows/docs-claude-review.yml
vendored
@@ -1,95 +0,0 @@
|
||||
name: Claude Block Docs Review
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
paths:
|
||||
- "docs/integrations/**"
|
||||
- "autogpt_platform/backend/backend/blocks/**"
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
# Only run for PRs from members/collaborators
|
||||
if: |
|
||||
github.event.pull_request.author_association == 'OWNER' ||
|
||||
github.event.pull_request.author_association == 'MEMBER' ||
|
||||
github.event.pull_request.author_association == 'COLLABORATOR'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
poetry-${{ runner.os }}-
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry install --only main
|
||||
poetry run prisma generate
|
||||
|
||||
- name: Run Claude Code Review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
claude_args: |
|
||||
--allowedTools "Read,Glob,Grep,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
You are reviewing a PR that modifies block documentation or block code for AutoGPT.
|
||||
|
||||
## Your Task
|
||||
Review the changes in this PR and provide constructive feedback. Focus on:
|
||||
|
||||
1. **Documentation Accuracy**: For any block code changes, verify that:
|
||||
- Input/output tables in docs match the actual block schemas
|
||||
- Description text accurately reflects what the block does
|
||||
- Any new blocks have corresponding documentation
|
||||
|
||||
2. **Manual Content Quality**: Check manual sections (marked with `<!-- MANUAL: -->` markers):
|
||||
- "How it works" sections should have clear technical explanations
|
||||
- "Possible use case" sections should have practical, real-world examples
|
||||
- Content should be helpful for users trying to understand the blocks
|
||||
|
||||
3. **Template Compliance**: Ensure docs follow the standard template:
|
||||
- What it is (brief intro)
|
||||
- What it does (description)
|
||||
- How it works (technical explanation)
|
||||
- Inputs table
|
||||
- Outputs table
|
||||
- Possible use case
|
||||
|
||||
4. **Cross-references**: Check that links and anchors are correct
|
||||
|
||||
## Review Process
|
||||
1. First, get the PR diff to see what changed: `gh pr diff ${{ github.event.pull_request.number }}`
|
||||
2. Read any modified block files to understand the implementation
|
||||
3. Read corresponding documentation files to verify accuracy
|
||||
4. Provide your feedback as a PR comment
|
||||
|
||||
Be constructive and specific. If everything looks good, say so!
|
||||
If there are issues, explain what's wrong and suggest how to fix it.
|
||||
194
.github/workflows/docs-enhance.yml
vendored
194
.github/workflows/docs-enhance.yml
vendored
@@ -1,194 +0,0 @@
|
||||
name: Enhance Block Documentation
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
block_pattern:
|
||||
description: 'Block file pattern to enhance (e.g., "google/*.md" or "*" for all blocks)'
|
||||
required: true
|
||||
default: '*'
|
||||
type: string
|
||||
dry_run:
|
||||
description: 'Dry run mode - show proposed changes without committing'
|
||||
type: boolean
|
||||
default: true
|
||||
max_blocks:
|
||||
description: 'Maximum number of blocks to process (0 for unlimited)'
|
||||
type: number
|
||||
default: 10
|
||||
|
||||
jobs:
|
||||
enhance-docs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
poetry-${{ runner.os }}-
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry install --only main
|
||||
poetry run prisma generate
|
||||
|
||||
- name: Run Claude Enhancement
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
claude_args: |
|
||||
--allowedTools "Read,Edit,Glob,Grep,Write,Bash(git:*),Bash(gh:*),Bash(find:*),Bash(ls:*)"
|
||||
prompt: |
|
||||
You are enhancing block documentation for AutoGPT. Your task is to improve the MANUAL sections
|
||||
of block documentation files by reading the actual block implementations and writing helpful content.
|
||||
|
||||
## Configuration
|
||||
- Block pattern: ${{ inputs.block_pattern }}
|
||||
- Dry run: ${{ inputs.dry_run }}
|
||||
- Max blocks to process: ${{ inputs.max_blocks }}
|
||||
|
||||
## Your Task
|
||||
|
||||
1. **Find Documentation Files**
|
||||
Find block documentation files matching the pattern in `docs/integrations/`
|
||||
Pattern: ${{ inputs.block_pattern }}
|
||||
|
||||
Use: `find docs/integrations -name "*.md" -type f`
|
||||
|
||||
2. **For Each Documentation File** (up to ${{ inputs.max_blocks }} files):
|
||||
|
||||
a. Read the documentation file
|
||||
|
||||
b. Identify which block(s) it documents (look for the block class name)
|
||||
|
||||
c. Find and read the corresponding block implementation in `autogpt_platform/backend/backend/blocks/`
|
||||
|
||||
d. Improve the MANUAL sections:
|
||||
|
||||
**"How it works" section** (within `<!-- MANUAL: how_it_works -->` markers):
|
||||
- Explain the technical flow of the block
|
||||
- Describe what APIs or services it connects to
|
||||
- Note any important configuration or prerequisites
|
||||
- Keep it concise but informative (2-4 paragraphs)
|
||||
|
||||
**"Possible use case" section** (within `<!-- MANUAL: use_case -->` markers):
|
||||
- Provide 2-3 practical, real-world examples
|
||||
- Make them specific and actionable
|
||||
- Show how this block could be used in an automation workflow
|
||||
|
||||
3. **Important Rules**
|
||||
- ONLY modify content within `<!-- MANUAL: -->` and `<!-- END MANUAL -->` markers
|
||||
- Do NOT modify auto-generated sections (inputs/outputs tables, descriptions)
|
||||
- Keep content accurate based on the actual block implementation
|
||||
- Write for users who may not be technical experts
|
||||
|
||||
4. **Output**
|
||||
${{ inputs.dry_run == true && 'DRY RUN MODE: Show proposed changes for each file but do NOT actually edit the files. Describe what you would change.' || 'LIVE MODE: Actually edit the files to improve the documentation.' }}
|
||||
|
||||
## Example Improvements
|
||||
|
||||
**Before (How it works):**
|
||||
```
|
||||
_Add technical explanation here._
|
||||
```
|
||||
|
||||
**After (How it works):**
|
||||
```
|
||||
This block connects to the GitHub API to retrieve issue information. When executed,
|
||||
it authenticates using your GitHub credentials and fetches issue details including
|
||||
title, body, labels, and assignees.
|
||||
|
||||
The block requires a valid GitHub OAuth connection with repository access permissions.
|
||||
It supports both public and private repositories you have access to.
|
||||
```
|
||||
|
||||
**Before (Possible use case):**
|
||||
```
|
||||
_Add practical use case examples here._
|
||||
```
|
||||
|
||||
**After (Possible use case):**
|
||||
```
|
||||
**Customer Support Automation**: Monitor a GitHub repository for new issues with
|
||||
the "bug" label, then automatically create a ticket in your support system and
|
||||
notify the on-call engineer via Slack.
|
||||
|
||||
**Release Notes Generation**: When a new release is published, gather all closed
|
||||
issues since the last release and generate a summary for your changelog.
|
||||
```
|
||||
|
||||
Begin by finding and listing the documentation files to process.
|
||||
|
||||
- name: Create PR with enhanced documentation
|
||||
if: ${{ inputs.dry_run == false }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
# Check if there are changes
|
||||
if git diff --quiet docs/integrations/; then
|
||||
echo "No changes to commit"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Configure git
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Create branch and commit
|
||||
BRANCH_NAME="docs/enhance-blocks-$(date +%Y%m%d-%H%M%S)"
|
||||
git checkout -b "$BRANCH_NAME"
|
||||
git add docs/integrations/
|
||||
git commit -m "docs: enhance block documentation with LLM-generated content
|
||||
|
||||
Pattern: ${{ inputs.block_pattern }}
|
||||
Max blocks: ${{ inputs.max_blocks }}
|
||||
|
||||
🤖 Generated with [Claude Code](https://claude.com/claude-code)
|
||||
|
||||
Co-Authored-By: Claude <noreply@anthropic.com>"
|
||||
|
||||
# Push and create PR
|
||||
git push -u origin "$BRANCH_NAME"
|
||||
gh pr create \
|
||||
--title "docs: LLM-enhanced block documentation" \
|
||||
--body "## Summary
|
||||
This PR contains LLM-enhanced documentation for block files matching pattern: \`${{ inputs.block_pattern }}\`
|
||||
|
||||
The following manual sections were improved:
|
||||
- **How it works**: Technical explanations based on block implementations
|
||||
- **Possible use case**: Practical, real-world examples
|
||||
|
||||
## Review Checklist
|
||||
- [ ] Content is accurate based on block implementations
|
||||
- [ ] Examples are practical and helpful
|
||||
- [ ] No auto-generated sections were modified
|
||||
|
||||
---
|
||||
🤖 Generated with [Claude Code](https://claude.com/claude-code)" \
|
||||
--base dev
|
||||
6
.github/workflows/platform-backend-ci.yml
vendored
6
.github/workflows/platform-backend-ci.yml
vendored
@@ -88,7 +88,7 @@ jobs:
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -134,7 +134,7 @@ jobs:
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
run: poetry run prisma generate
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
@@ -176,7 +176,7 @@ jobs:
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate deploy
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Check comment permissions and deployment status
|
||||
id: check_status
|
||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const commentBody = context.payload.comment.body.trim();
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
|
||||
- name: Post permission denied comment
|
||||
if: steps.check_status.outputs.permission_denied == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
- name: Get PR details for deployment
|
||||
id: pr_details
|
||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const pr = await github.rest.pulls.get({
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
|
||||
- name: Post deploy success comment
|
||||
if: steps.check_status.outputs.should_deploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -126,7 +126,7 @@ jobs:
|
||||
|
||||
- name: Post undeploy success comment
|
||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
- name: Check deployment status on PR close
|
||||
id: check_pr_close
|
||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const comments = await github.rest.issues.listComments({
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
github.event_name == 'pull_request' &&
|
||||
github.event.action == 'closed' &&
|
||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
|
||||
97
.github/workflows/platform-frontend-ci.yml
vendored
97
.github/workflows/platform-frontend-ci.yml
vendored
@@ -11,7 +11,6 @@ on:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
|
||||
@@ -27,22 +26,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
components-changed: ${{ steps.filter.outputs.components }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check for component changes
|
||||
uses: dorny/paths-filter@v3
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
components:
|
||||
- 'autogpt_platform/frontend/src/components/**'
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -54,7 +44,7 @@ jobs:
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
@@ -74,7 +64,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -82,7 +72,7 @@ jobs:
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
@@ -99,11 +89,8 @@ jobs:
|
||||
chromatic:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
# Disabled: to re-enable, remove 'false &&' from the condition below
|
||||
if: >-
|
||||
false
|
||||
&& (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
|
||||
&& needs.setup.outputs.components-changed == 'true'
|
||||
# Only run on dev branch pushes or PRs targeting dev
|
||||
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -112,7 +99,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -120,7 +107,7 @@ jobs:
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
@@ -140,7 +127,7 @@ jobs:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
e2e_test:
|
||||
test:
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
@@ -153,7 +140,7 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -164,19 +151,11 @@ jobs:
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||
@@ -231,7 +210,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
@@ -247,62 +226,14 @@ jobs:
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
- name: Upload Playwright artifacts
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.yml logs
|
||||
|
||||
integration_test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Generate API client
|
||||
run: pnpm generate:api
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: pnpm test:unit
|
||||
|
||||
12
.github/workflows/platform-fullstack-ci.yml
vendored
12
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: big-boi
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
@@ -85,10 +85,10 @@ jobs:
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v5
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
|
||||
@@ -11,7 +11,7 @@ jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
# operations-per-run: 5000
|
||||
stale-issue-message: >
|
||||
|
||||
2
.github/workflows/repo-pr-label.yml
vendored
2
.github/workflows/repo-pr-label.yml
vendored
@@ -61,6 +61,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v6
|
||||
- uses: actions/labeler@v5
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -178,6 +178,4 @@ autogpt_platform/backend/settings.py
|
||||
*.ign.*
|
||||
.test-contents
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
.next
|
||||
46
AGENTS.md
46
AGENTS.md
@@ -16,34 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
|
||||
- Format Python code with `poetry run format`.
|
||||
- Format frontend code using `pnpm format`.
|
||||
|
||||
## Frontend guidelines:
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Add `usePageName.ts` hook for logic
|
||||
- Put sub-components in local `components/` folder
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||
- Use function declarations for components, arrow functions only for callbacks
|
||||
- No barrel files or `index.ts` re-exports
|
||||
- Avoid comments at all times unless the code is very complex
|
||||
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
|
||||
- Do not type hook returns, let Typescript infer as much as possible
|
||||
- Never type with `any`, if not types available use `unknown`
|
||||
|
||||
## Testing
|
||||
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
@@ -51,8 +23,22 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
Always run the relevant linters and tests before committing.
|
||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||
Types: - feat - fix - refactor - ci - dx (developer experience)
|
||||
Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
|
||||
Types:
|
||||
- feat
|
||||
- fix
|
||||
- refactor
|
||||
- ci
|
||||
- dx (developer experience)
|
||||
Scopes:
|
||||
- platform
|
||||
- platform/library
|
||||
- platform/marketplace
|
||||
- backend
|
||||
- backend/executor
|
||||
- frontend
|
||||
- frontend/library
|
||||
- frontend/marketplace
|
||||
- blocks
|
||||
|
||||
## Pull requests
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
|
||||
### Updated Setup Instructions:
|
||||
We've moved to a fully maintained and regularly updated documentation site.
|
||||
|
||||
👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
|
||||
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
|
||||
|
||||
|
||||
This tutorial assumes you have Docker, VSCode, git and npm installed.
|
||||
|
||||
@@ -6,30 +6,152 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
- **Backend** (`/backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`/frontend`): Next.js React application
|
||||
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
||||
|
||||
## Component Documentation
|
||||
## Essential Commands
|
||||
|
||||
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
||||
### Backend Development
|
||||
|
||||
## Key Concepts
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd backend && poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend server
|
||||
poetry run serve
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in TESTING.md
|
||||
|
||||
#### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
### Frontend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && pnpm i
|
||||
|
||||
# Generate API client from OpenAPI spec
|
||||
pnpm generate:api
|
||||
|
||||
# Start development server
|
||||
pnpm dev
|
||||
|
||||
# Run E2E tests
|
||||
pnpm test
|
||||
|
||||
# Run Storybook for component development
|
||||
pnpm storybook
|
||||
|
||||
# Build production
|
||||
pnpm build
|
||||
|
||||
# Format and lint
|
||||
pnpm format
|
||||
|
||||
# Type checking
|
||||
pnpm types
|
||||
```
|
||||
|
||||
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
|
||||
|
||||
**Key Frontend Conventions:**
|
||||
|
||||
- Separate render logic from data/behavior in components
|
||||
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Use function declarations (not arrow functions) for components/handlers
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Only use Phosphor Icons
|
||||
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
### Frontend Architecture
|
||||
|
||||
- **Framework**: Next.js 15 App Router (client-first approach)
|
||||
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
||||
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
||||
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
||||
- **Icons**: Phosphor Icons only
|
||||
- **Feature Flags**: LaunchDarkly integration
|
||||
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
||||
- **Testing**: Playwright for E2E, Storybook for component development
|
||||
|
||||
### Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Testing Approach
|
||||
|
||||
- Backend uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Frontend uses Playwright for E2E tests
|
||||
- Component testing via Storybook
|
||||
|
||||
### Database Schema
|
||||
|
||||
Key models (defined in `/backend/schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
||||
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
||||
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
@@ -45,12 +167,75 @@ AutoGPT Platform is a monorepo containing:
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Common Development Tasks
|
||||
|
||||
**Adding a new block:**
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
**Modifying the API:**
|
||||
|
||||
1. Update route in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
**Frontend feature development:**
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Add `usePageName.ts` hook for logic
|
||||
- Put sub-components in local `components/` folder
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
|
||||
### Security Implementation
|
||||
|
||||
**Cache Protection Middleware:**
|
||||
|
||||
- Located in `/backend/backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Create the PR aginst the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
|
||||
- Use conventional commit messages (see below)/
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
@@ -6,14 +6,12 @@ start-core:
|
||||
|
||||
# Stop core services
|
||||
stop-core:
|
||||
docker compose stop
|
||||
docker compose stop deps
|
||||
|
||||
reset-db:
|
||||
docker compose stop db
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
# View logs for core services
|
||||
logs-core:
|
||||
@@ -35,7 +33,6 @@ init-env:
|
||||
migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
@@ -47,7 +44,7 @@ test-data:
|
||||
cd backend && poetry run python test/test_data_creator.py
|
||||
|
||||
load-store-agents:
|
||||
cd backend && poetry run load-store-agents
|
||||
cd backend && poetry run python test/load_store_agents.py
|
||||
|
||||
help:
|
||||
@echo "Usage: make <target>"
|
||||
@@ -61,4 +58,4 @@ help:
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@echo " test-data - Run the test data creator"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@@ -57,9 +57,6 @@ class APIKeySmith:
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
if not raw_key.startswith(self.PREFIX):
|
||||
raise ValueError("Key without 'agpt_' prefix would fail validation")
|
||||
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Patch a FastAPI instance's `openapi()` method to add 401 responses
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
# Wrap current method to allow stacking OpenAPI schema modifiers like this
|
||||
wrapped_openapi = app.openapi
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = wrapped_openapi()
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
|
||||
1854
autogpt_platform/autogpt_libs/poetry.lock
generated
1854
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^46.0"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.128.0"
|
||||
google-cloud-logging = "^3.13.0"
|
||||
launchdarkly-server-sdk = "^9.14.1"
|
||||
pydantic = "^2.12.5"
|
||||
pydantic-settings = "^2.12.0"
|
||||
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.27.2"
|
||||
uvicorn = "^0.40.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pyright = "^1.1.408"
|
||||
pyright = "^1.1.404"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
pytest-mock = "^3.15.1"
|
||||
pytest-cov = "^7.0.0"
|
||||
ruff = "^0.15.0"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.12.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -58,13 +58,6 @@ V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Langfuse Prompt Management
|
||||
# Used for managing the CoPilot system prompt externally
|
||||
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_HOST=https://cloud.langfuse.com
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
@@ -152,7 +145,6 @@ REPLICATE_API_KEY=
|
||||
REVID_API_KEY=
|
||||
SCREENSHOTONE_API_KEY=
|
||||
UNREAL_SPEECH_API_KEY=
|
||||
ELEVENLABS_API_KEY=
|
||||
|
||||
# Data & Search Services
|
||||
E2B_API_KEY=
|
||||
@@ -179,10 +171,5 @@ AYRSHARE_JWT_KEY=
|
||||
SMARTLEAD_API_KEY=
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
# PostHog Analytics
|
||||
# Get API key from https://posthog.com - Project Settings > Project API Key
|
||||
POSTHOG_API_KEY=
|
||||
POSTHOG_HOST=https://eu.i.posthog.com
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
|
||||
4
autogpt_platform/backend/.gitignore
vendored
4
autogpt_platform/backend/.gitignore
vendored
@@ -18,7 +18,3 @@ load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
migrations/*/rollback*.sql
|
||||
|
||||
# Workspace files
|
||||
workspaces/
|
||||
|
||||
@@ -1,170 +0,0 @@
|
||||
# CLAUDE.md - Backend
|
||||
|
||||
This file provides guidance to Claude Code when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@@ -48,8 +48,7 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
RUN poetry run prisma generate
|
||||
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
@@ -62,12 +61,10 @@ ENV POETRY_HOME=/opt/poetry \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
||||
# Install Python without upgrading system-managed packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.13 \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
imagemagick \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy only necessary files from builder
|
||||
@@ -102,7 +99,6 @@ COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migration
|
||||
FROM server_dependencies AS server
|
||||
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
COPY docs /app/docs
|
||||
RUN poetry install --no-ansi --only-root
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
@@ -108,7 +108,7 @@ import fastapi.testclient
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.api.features.myroute import router
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
@@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as
|
||||
|
||||
#### Using Global Auth Fixtures
|
||||
|
||||
Two global auth fixtures are provided by `backend/api/conftest.py`:
|
||||
Two global auth fixtures are provided by `backend/server/conftest.py`:
|
||||
|
||||
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
||||
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
||||
@@ -149,7 +149,7 @@ These provide the easiest way to set up authentication mocking in test modules:
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from backend.api.features.myroute import router
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
|
||||
from .v1.routes import v1_router
|
||||
|
||||
external_api = FastAPI(
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_api.add_middleware(SecurityHeadersMiddleware)
|
||||
external_api.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_api,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
@@ -1,107 +0,0 @@
|
||||
from fastapi import HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.auth.api_key import APIKeyInfo, validate_api_key
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidTokenError,
|
||||
OAuthAccessTokenInfo,
|
||||
validate_access_token,
|
||||
)
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
bearer_auth = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
"""Middleware for API key authentication only"""
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API key"
|
||||
)
|
||||
|
||||
api_key_obj = await validate_api_key(api_key)
|
||||
|
||||
if not api_key_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
return api_key_obj
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> OAuthAccessTokenInfo:
|
||||
"""Middleware for OAuth access token authentication only"""
|
||||
if bearer is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Authorization header",
|
||||
)
|
||||
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
return token_info
|
||||
|
||||
|
||||
async def require_auth(
|
||||
api_key: str | None = Security(api_key_header),
|
||||
bearer: HTTPAuthorizationCredentials | None = Security(bearer_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
"""
|
||||
Unified authentication middleware supporting both API keys and OAuth tokens.
|
||||
|
||||
Supports two authentication methods, which are checked in order:
|
||||
1. X-API-Key header (existing API key authentication)
|
||||
2. Authorization: Bearer <token> header (OAuth access token)
|
||||
|
||||
Returns:
|
||||
APIAuthorizationInfo: base class of both APIKeyInfo and OAuthAccessTokenInfo.
|
||||
"""
|
||||
# Try API key first
|
||||
if api_key is not None:
|
||||
api_key_info = await validate_api_key(api_key)
|
||||
if api_key_info:
|
||||
return api_key_info
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
|
||||
# Try OAuth bearer token
|
||||
if bearer is not None:
|
||||
try:
|
||||
token_info, _ = await validate_access_token(bearer.credentials)
|
||||
return token_info
|
||||
except (InvalidClientError, InvalidTokenError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
# No credentials provided
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication. Provide API key or access token.",
|
||||
)
|
||||
|
||||
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""
|
||||
Dependency function for checking specific permissions
|
||||
(works with API keys and OAuth tokens)
|
||||
"""
|
||||
|
||||
async def check_permission(
|
||||
auth: APIAuthorizationInfo = Security(require_auth),
|
||||
) -> APIAuthorizationInfo:
|
||||
if permission not in auth.scopes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission: {permission.value}",
|
||||
)
|
||||
return auth
|
||||
|
||||
return check_permission
|
||||
@@ -1,340 +0,0 @@
|
||||
"""Tests for analytics API endpoints."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from .analytics import router as analytics_router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(analytics_router)
|
||||
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_app_auth(mock_jwt_user):
|
||||
"""Setup auth overrides for all tests in this module."""
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_metric endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_metric_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw metric logging."""
|
||||
mock_result = Mock(id="metric-123-uuid")
|
||||
mock_log_metric = mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "page_load_time",
|
||||
"metric_value": 2.5,
|
||||
"data_string": "/dashboard",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "metric-123-uuid"
|
||||
|
||||
mock_log_metric.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
metric_name="page_load_time",
|
||||
metric_value=2.5,
|
||||
data_string="/dashboard",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"metric_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_metric_success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric_value,metric_name,data_string,test_id",
|
||||
[
|
||||
(100, "api_calls_count", "external_api", "integer_value"),
|
||||
(0, "error_count", "no_errors", "zero_value"),
|
||||
(-5.2, "temperature_delta", "cooling", "negative_value"),
|
||||
(1.23456789, "precision_test", "float_precision", "float_precision"),
|
||||
(999999999, "large_number", "max_value", "large_number"),
|
||||
(0.0000001, "tiny_number", "min_value", "tiny_number"),
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_various_values(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
metric_value: float,
|
||||
metric_name: str,
|
||||
data_string: str,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
"""Test raw metric logging with various metric values."""
|
||||
mock_result = Mock(id=f"metric-{test_id}-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": metric_name,
|
||||
"metric_value": metric_value,
|
||||
"data_string": data_string,
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Failed for {test_id}: {response.text}"
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"metric_id": response.json(), "test_case": test_id},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
f"analytics_metric_{test_id}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"metric_name": "test"}, "Field required"),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": "not_a_number", "data_string": "x"},
|
||||
"Input should be a valid number",
|
||||
),
|
||||
(
|
||||
{"metric_name": "", "metric_value": 1.0, "data_string": "test"},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
(
|
||||
{"metric_name": "test", "metric_value": 1.0, "data_string": ""},
|
||||
"String should have at least 1 character",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_metric_value_and_data_string",
|
||||
"invalid_metric_value_type",
|
||||
"empty_metric_name",
|
||||
"empty_data_string",
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid metric requests."""
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_metric_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database connection failed"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"metric_name": "test_metric",
|
||||
"metric_value": 1.0,
|
||||
"data_string": "test",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_metric", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Database connection failed" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# /log_raw_analytics endpoint tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_log_raw_analytics_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test successful raw analytics logging."""
|
||||
mock_result = Mock(id="analytics-789-uuid")
|
||||
mock_log_analytics = mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "user_action",
|
||||
"data": {
|
||||
"action": "button_click",
|
||||
"button_id": "submit_form",
|
||||
"timestamp": "2023-01-01T00:00:00Z",
|
||||
"metadata": {"form_type": "registration", "fields_filled": 5},
|
||||
},
|
||||
"data_index": "button_click_submit_form",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200, f"Unexpected response: {response.text}"
|
||||
assert response.json() == "analytics-789-uuid"
|
||||
|
||||
mock_log_analytics.assert_called_once_with(
|
||||
test_user_id,
|
||||
"user_action",
|
||||
request_data["data"],
|
||||
"button_click_submit_form",
|
||||
)
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps({"analytics_id": response.json()}, indent=2, sort_keys=True),
|
||||
"analytics_log_analytics_success",
|
||||
)
|
||||
|
||||
|
||||
def test_log_raw_analytics_complex_data(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
configured_snapshot: Snapshot,
|
||||
) -> None:
|
||||
"""Test raw analytics logging with complex nested data structures."""
|
||||
mock_result = Mock(id="analytics-complex-uuid")
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "agent_execution",
|
||||
"data": {
|
||||
"agent_id": "agent_123",
|
||||
"execution_id": "exec_456",
|
||||
"status": "completed",
|
||||
"duration_ms": 3500,
|
||||
"nodes_executed": 15,
|
||||
"blocks_used": [
|
||||
{"block_id": "llm_block", "count": 3},
|
||||
{"block_id": "http_block", "count": 5},
|
||||
{"block_id": "code_block", "count": 2},
|
||||
],
|
||||
"errors": [],
|
||||
"metadata": {
|
||||
"trigger": "manual",
|
||||
"user_tier": "premium",
|
||||
"environment": "production",
|
||||
},
|
||||
},
|
||||
"data_index": "agent_123_exec_456",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
configured_snapshot.assert_match(
|
||||
json.dumps(
|
||||
{"analytics_id": response.json(), "logged_data": request_data["data"]},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
"analytics_log_analytics_complex_data",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,expected_error",
|
||||
[
|
||||
({}, "Field required"),
|
||||
({"type": "test"}, "Field required"),
|
||||
(
|
||||
{"type": "test", "data": "not_a_dict", "data_index": "test"},
|
||||
"Input should be a valid dictionary",
|
||||
),
|
||||
({"type": "test", "data": {"key": "value"}}, "Field required"),
|
||||
],
|
||||
ids=[
|
||||
"empty_request",
|
||||
"missing_data_and_data_index",
|
||||
"invalid_data_type",
|
||||
"missing_data_index",
|
||||
],
|
||||
)
|
||||
def test_log_raw_analytics_validation_errors(
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test validation errors for invalid analytics requests."""
|
||||
response = client.post("/log_raw_analytics", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()
|
||||
assert "detail" in error_detail, f"Missing 'detail' in error: {error_detail}"
|
||||
|
||||
error_text = json.dumps(error_detail)
|
||||
assert (
|
||||
expected_error in error_text
|
||||
), f"Expected '{expected_error}' in error response: {error_text}"
|
||||
|
||||
|
||||
def test_log_raw_analytics_service_error(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
"""Test error handling when analytics service fails."""
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_analytics",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Analytics DB unreachable"),
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"type": "test_event",
|
||||
"data": {"key": "value"},
|
||||
"data_index": "test_index",
|
||||
}
|
||||
|
||||
response = client.post("/log_raw_analytics", json=request_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_detail = response.json()["detail"]
|
||||
assert "Analytics DB unreachable" in error_detail["message"]
|
||||
assert "hint" in error_detail
|
||||
@@ -1,689 +0,0 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Sequence
|
||||
|
||||
import prisma
|
||||
|
||||
import backend.api.features.library.db as library_db
|
||||
import backend.api.features.library.model as library_model
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.data.block
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.cache import cached
|
||||
from backend.util.models import Pagination
|
||||
|
||||
from .model import (
|
||||
BlockCategoryResponse,
|
||||
BlockResponse,
|
||||
BlockType,
|
||||
CountResponse,
|
||||
FilterType,
|
||||
Provider,
|
||||
ProviderResponse,
|
||||
SearchEntry,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
|
||||
MAX_LIBRARY_AGENT_RESULTS = 100
|
||||
MAX_MARKETPLACE_AGENT_RESULTS = 100
|
||||
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
|
||||
|
||||
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ScoredItem:
|
||||
item: SearchResultItem
|
||||
filter_type: FilterType
|
||||
score: float
|
||||
sort_key: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SearchCacheEntry:
|
||||
items: list[SearchResultItem]
|
||||
total_items: dict[FilterType, int]
|
||||
|
||||
|
||||
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
|
||||
categories: dict[BlockCategory, BlockCategoryResponse] = {}
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip blocks that don't have categories (all should have at least one)
|
||||
if not block.categories:
|
||||
continue
|
||||
|
||||
# Add block to the categories
|
||||
for category in block.categories:
|
||||
if category not in categories:
|
||||
categories[category] = BlockCategoryResponse(
|
||||
name=category.name.lower(),
|
||||
total_blocks=0,
|
||||
blocks=[],
|
||||
)
|
||||
|
||||
categories[category].total_blocks += 1
|
||||
|
||||
# Append if the category has less than the specified number of blocks
|
||||
if len(categories[category].blocks) < category_blocks:
|
||||
categories[category].blocks.append(block.get_info())
|
||||
|
||||
# Sort categories by name
|
||||
return sorted(categories.values(), key=lambda x: x.name)
|
||||
|
||||
|
||||
def get_blocks(
|
||||
*,
|
||||
category: str | None = None,
|
||||
type: BlockType | None = None,
|
||||
provider: ProviderName | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> BlockResponse:
|
||||
"""
|
||||
Get blocks based on either category, type or provider.
|
||||
Providing nothing fetches all block types.
|
||||
"""
|
||||
# Only one of category, type, or provider can be specified
|
||||
if (category and type) or (category and provider) or (type and provider):
|
||||
raise ValueError("Only one of category, type, or provider can be specified")
|
||||
|
||||
blocks: list[AnyBlockSchema] = []
|
||||
skip = (page - 1) * page_size
|
||||
take = page_size
|
||||
total = 0
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
# Skip blocks that don't match the category
|
||||
if category and category not in {c.name.lower() for c in block.categories}:
|
||||
continue
|
||||
# Skip blocks that don't match the type
|
||||
if (
|
||||
(type == "input" and block.block_type.value != "Input")
|
||||
or (type == "output" and block.block_type.value != "Output")
|
||||
or (type == "action" and block.block_type.value in ("Input", "Output"))
|
||||
):
|
||||
continue
|
||||
# Skip blocks that don't match the provider
|
||||
if provider:
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
if not any(provider in info.provider for info in credentials_info):
|
||||
continue
|
||||
|
||||
total += 1
|
||||
if skip > 0:
|
||||
skip -= 1
|
||||
continue
|
||||
if take > 0:
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
return BlockResponse(
|
||||
blocks=[b.get_info() for b in blocks],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_block_by_id(block_id: str) -> BlockInfo | None:
|
||||
"""
|
||||
Get a specific block by its ID.
|
||||
"""
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
if block.id == block_id:
|
||||
return block.get_info()
|
||||
return None
|
||||
|
||||
|
||||
async def update_search(user_id: str, search: SearchEntry) -> str:
|
||||
"""
|
||||
Upsert a search request for the user and return the search ID.
|
||||
"""
|
||||
if search.search_id:
|
||||
# Update existing search
|
||||
await prisma.models.BuilderSearchHistory.prisma().update(
|
||||
where={
|
||||
"id": search.search_id,
|
||||
},
|
||||
data={
|
||||
"searchQuery": search.search_query or "",
|
||||
"filter": search.filter or [], # type: ignore
|
||||
"byCreator": search.by_creator or [],
|
||||
},
|
||||
)
|
||||
return search.search_id
|
||||
else:
|
||||
# Create new search
|
||||
new_search = await prisma.models.BuilderSearchHistory.prisma().create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"searchQuery": search.search_query or "",
|
||||
"filter": search.filter or [], # type: ignore
|
||||
"byCreator": search.by_creator or [],
|
||||
}
|
||||
)
|
||||
return new_search.id
|
||||
|
||||
|
||||
async def get_recent_searches(user_id: str, limit: int = 5) -> list[SearchEntry]:
|
||||
"""
|
||||
Get the user's most recent search requests.
|
||||
"""
|
||||
searches = await prisma.models.BuilderSearchHistory.prisma().find_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
},
|
||||
order={
|
||||
"updatedAt": "desc",
|
||||
},
|
||||
take=limit,
|
||||
)
|
||||
return [
|
||||
SearchEntry(
|
||||
search_query=s.searchQuery,
|
||||
filter=s.filter, # type: ignore
|
||||
by_creator=s.byCreator,
|
||||
search_id=s.id,
|
||||
)
|
||||
for s in searches
|
||||
]
|
||||
|
||||
|
||||
async def get_sorted_search_results(
|
||||
*,
|
||||
user_id: str,
|
||||
search_query: str | None,
|
||||
filters: Sequence[FilterType],
|
||||
by_creator: Sequence[str] | None = None,
|
||||
) -> _SearchCacheEntry:
|
||||
normalized_filters: tuple[FilterType, ...] = tuple(sorted(set(filters or [])))
|
||||
normalized_creators: tuple[str, ...] = tuple(sorted(set(by_creator or [])))
|
||||
return await _build_cached_search_results(
|
||||
user_id=user_id,
|
||||
search_query=search_query or "",
|
||||
filters=normalized_filters,
|
||||
by_creator=normalized_creators,
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=300, shared_cache=True)
|
||||
async def _build_cached_search_results(
|
||||
user_id: str,
|
||||
search_query: str,
|
||||
filters: tuple[FilterType, ...],
|
||||
by_creator: tuple[str, ...],
|
||||
) -> _SearchCacheEntry:
|
||||
normalized_query = (search_query or "").strip().lower()
|
||||
|
||||
include_blocks = "blocks" in filters
|
||||
include_integrations = "integrations" in filters
|
||||
include_library_agents = "my_agents" in filters
|
||||
include_marketplace_agents = "marketplace_agents" in filters
|
||||
|
||||
scored_items: list[_ScoredItem] = []
|
||||
total_items: dict[FilterType, int] = {
|
||||
"blocks": 0,
|
||||
"integrations": 0,
|
||||
"marketplace_agents": 0,
|
||||
"my_agents": 0,
|
||||
}
|
||||
|
||||
block_results, block_total, integration_total = _collect_block_results(
|
||||
normalized_query=normalized_query,
|
||||
include_blocks=include_blocks,
|
||||
include_integrations=include_integrations,
|
||||
)
|
||||
scored_items.extend(block_results)
|
||||
total_items["blocks"] = block_total
|
||||
total_items["integrations"] = integration_total
|
||||
|
||||
if include_library_agents:
|
||||
library_response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query or None,
|
||||
page=1,
|
||||
page_size=MAX_LIBRARY_AGENT_RESULTS,
|
||||
)
|
||||
total_items["my_agents"] = library_response.pagination.total_items
|
||||
scored_items.extend(
|
||||
_build_library_items(
|
||||
agents=library_response.agents,
|
||||
normalized_query=normalized_query,
|
||||
)
|
||||
)
|
||||
|
||||
if include_marketplace_agents:
|
||||
marketplace_response = await store_db.get_store_agents(
|
||||
creators=list(by_creator) or None,
|
||||
search_query=search_query or None,
|
||||
page=1,
|
||||
page_size=MAX_MARKETPLACE_AGENT_RESULTS,
|
||||
)
|
||||
total_items["marketplace_agents"] = marketplace_response.pagination.total_items
|
||||
scored_items.extend(
|
||||
_build_marketplace_items(
|
||||
agents=marketplace_response.agents,
|
||||
normalized_query=normalized_query,
|
||||
)
|
||||
)
|
||||
|
||||
sorted_items = sorted(
|
||||
scored_items,
|
||||
key=lambda entry: (-entry.score, entry.sort_key, entry.filter_type),
|
||||
)
|
||||
|
||||
return _SearchCacheEntry(
|
||||
items=[entry.item for entry in sorted_items],
|
||||
total_items=total_items,
|
||||
)
|
||||
|
||||
|
||||
def _collect_block_results(
|
||||
*,
|
||||
normalized_query: str,
|
||||
include_blocks: bool,
|
||||
include_integrations: bool,
|
||||
) -> tuple[list[_ScoredItem], int, int]:
|
||||
results: list[_ScoredItem] = []
|
||||
block_count = 0
|
||||
integration_count = 0
|
||||
|
||||
if not include_blocks and not include_integrations:
|
||||
return results, block_count, integration_count
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
block_info = block.get_info()
|
||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||
is_integration = len(credentials) > 0
|
||||
|
||||
if is_integration and not include_integrations:
|
||||
continue
|
||||
if not is_integration and not include_blocks:
|
||||
continue
|
||||
|
||||
score = _score_block(block, block_info, normalized_query)
|
||||
if not _should_include_item(score, normalized_query):
|
||||
continue
|
||||
|
||||
filter_type: FilterType = "integrations" if is_integration else "blocks"
|
||||
if is_integration:
|
||||
integration_count += 1
|
||||
else:
|
||||
block_count += 1
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=block_info,
|
||||
filter_type=filter_type,
|
||||
score=score,
|
||||
sort_key=_get_item_name(block_info),
|
||||
)
|
||||
)
|
||||
|
||||
return results, block_count, integration_count
|
||||
|
||||
|
||||
def _build_library_items(
|
||||
*,
|
||||
agents: list[library_model.LibraryAgent],
|
||||
normalized_query: str,
|
||||
) -> list[_ScoredItem]:
|
||||
results: list[_ScoredItem] = []
|
||||
|
||||
for agent in agents:
|
||||
score = _score_library_agent(agent, normalized_query)
|
||||
if not _should_include_item(score, normalized_query):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=agent,
|
||||
filter_type="my_agents",
|
||||
score=score,
|
||||
sort_key=_get_item_name(agent),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _build_marketplace_items(
|
||||
*,
|
||||
agents: list[store_model.StoreAgent],
|
||||
normalized_query: str,
|
||||
) -> list[_ScoredItem]:
|
||||
results: list[_ScoredItem] = []
|
||||
|
||||
for agent in agents:
|
||||
score = _score_store_agent(agent, normalized_query)
|
||||
if not _should_include_item(score, normalized_query):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
_ScoredItem(
|
||||
item=agent,
|
||||
filter_type="marketplace_agents",
|
||||
score=score,
|
||||
sort_key=_get_item_name(agent),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_providers(
|
||||
query: str = "",
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> ProviderResponse:
|
||||
providers = []
|
||||
query = query.lower()
|
||||
|
||||
skip = (page - 1) * page_size
|
||||
take = page_size
|
||||
|
||||
all_providers = _get_all_providers()
|
||||
|
||||
for provider in all_providers.values():
|
||||
if (
|
||||
query not in provider.name.value.lower()
|
||||
and query not in provider.description.lower()
|
||||
):
|
||||
continue
|
||||
if skip > 0:
|
||||
skip -= 1
|
||||
continue
|
||||
if take > 0:
|
||||
take -= 1
|
||||
providers.append(provider)
|
||||
|
||||
total = len(all_providers)
|
||||
|
||||
return ProviderResponse(
|
||||
providers=providers,
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def get_counts(user_id: str) -> CountResponse:
|
||||
my_agents = await prisma.models.LibraryAgent.prisma().count(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
}
|
||||
)
|
||||
counts = await _get_static_counts()
|
||||
return CountResponse(
|
||||
my_agents=my_agents,
|
||||
**counts,
|
||||
)
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
async def _get_static_counts():
|
||||
"""
|
||||
Get counts of blocks, integrations, and marketplace agents.
|
||||
This is cached to avoid unnecessary database queries and calculations.
|
||||
"""
|
||||
all_blocks = 0
|
||||
input_blocks = 0
|
||||
action_blocks = 0
|
||||
output_blocks = 0
|
||||
integrations = 0
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
all_blocks += 1
|
||||
|
||||
if block.block_type.value == "Input":
|
||||
input_blocks += 1
|
||||
elif block.block_type.value == "Output":
|
||||
output_blocks += 1
|
||||
else:
|
||||
action_blocks += 1
|
||||
|
||||
credentials = list(block.input_schema.get_credentials_fields().values())
|
||||
if len(credentials) > 0:
|
||||
integrations += 1
|
||||
|
||||
marketplace_agents = await prisma.models.StoreAgent.prisma().count()
|
||||
|
||||
return {
|
||||
"all_blocks": all_blocks,
|
||||
"input_blocks": input_blocks,
|
||||
"action_blocks": action_blocks,
|
||||
"output_blocks": output_blocks,
|
||||
"integrations": integrations,
|
||||
"marketplace_agents": marketplace_agents,
|
||||
}
|
||||
|
||||
|
||||
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
|
||||
for field in schema_cls.model_fields.values():
|
||||
if field.annotation == LlmModel:
|
||||
# Check if query matches any value in llm_models
|
||||
if any(query in name for name in llm_models):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _score_block(
|
||||
block: AnyBlockSchema,
|
||||
block_info: BlockInfo,
|
||||
normalized_query: str,
|
||||
) -> float:
|
||||
if not normalized_query:
|
||||
return 0.0
|
||||
|
||||
name = block_info.name.lower()
|
||||
description = block_info.description.lower()
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
|
||||
category_text = " ".join(
|
||||
category.get("category", "").lower() for category in block_info.categories
|
||||
)
|
||||
score += _score_additional_field(category_text, normalized_query, 12, 6)
|
||||
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
provider_names = [
|
||||
provider.value.lower()
|
||||
for info in credentials_info
|
||||
for provider in info.provider
|
||||
]
|
||||
provider_text = " ".join(provider_names)
|
||||
score += _score_additional_field(provider_text, normalized_query, 15, 6)
|
||||
|
||||
if _matches_llm_model(block.input_schema, normalized_query):
|
||||
score += 20
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def _score_library_agent(
|
||||
agent: library_model.LibraryAgent,
|
||||
normalized_query: str,
|
||||
) -> float:
|
||||
if not normalized_query:
|
||||
return 0.0
|
||||
|
||||
name = agent.name.lower()
|
||||
description = (agent.description or "").lower()
|
||||
instructions = (agent.instructions or "").lower()
|
||||
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
score += _score_additional_field(instructions, normalized_query, 15, 6)
|
||||
score += _score_additional_field(
|
||||
agent.creator_name.lower(), normalized_query, 10, 5
|
||||
)
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def _score_store_agent(
|
||||
agent: store_model.StoreAgent,
|
||||
normalized_query: str,
|
||||
) -> float:
|
||||
if not normalized_query:
|
||||
return 0.0
|
||||
|
||||
name = agent.agent_name.lower()
|
||||
description = agent.description.lower()
|
||||
sub_heading = agent.sub_heading.lower()
|
||||
|
||||
score = _score_primary_fields(name, description, normalized_query)
|
||||
score += _score_additional_field(sub_heading, normalized_query, 12, 6)
|
||||
score += _score_additional_field(agent.creator.lower(), normalized_query, 10, 5)
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def _score_primary_fields(name: str, description: str, query: str) -> float:
|
||||
score = 0.0
|
||||
if name == query:
|
||||
score += 120
|
||||
elif name.startswith(query):
|
||||
score += 90
|
||||
elif query in name:
|
||||
score += 60
|
||||
|
||||
score += SequenceMatcher(None, name, query).ratio() * 50
|
||||
if description:
|
||||
if query in description:
|
||||
score += 30
|
||||
score += SequenceMatcher(None, description, query).ratio() * 25
|
||||
return score
|
||||
|
||||
|
||||
def _score_additional_field(
|
||||
value: str,
|
||||
query: str,
|
||||
contains_weight: float,
|
||||
similarity_weight: float,
|
||||
) -> float:
|
||||
if not value or not query:
|
||||
return 0.0
|
||||
|
||||
score = 0.0
|
||||
if query in value:
|
||||
score += contains_weight
|
||||
score += SequenceMatcher(None, value, query).ratio() * similarity_weight
|
||||
return score
|
||||
|
||||
|
||||
def _should_include_item(score: float, normalized_query: str) -> bool:
|
||||
if not normalized_query:
|
||||
return True
|
||||
return score >= MIN_SCORE_FOR_FILTERED_RESULTS
|
||||
|
||||
|
||||
def _get_item_name(item: SearchResultItem) -> str:
|
||||
if isinstance(item, BlockInfo):
|
||||
return item.name.lower()
|
||||
if isinstance(item, library_model.LibraryAgent):
|
||||
return item.name.lower()
|
||||
return item.agent_name.lower()
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
providers: dict[ProviderName, Provider] = {}
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
credentials_info = block.input_schema.get_credentials_fields_info().values()
|
||||
for info in credentials_info:
|
||||
for provider in info.provider: # provider is a ProviderName enum member
|
||||
if provider in providers:
|
||||
providers[provider].integration_count += 1
|
||||
else:
|
||||
providers[provider] = Provider(
|
||||
name=provider, description="", integration_count=1
|
||||
)
|
||||
return providers
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
suggested_blocks = []
|
||||
# Sum the number of executions for each block type
|
||||
# Prisma cannot group by nested relations, so we do a raw query
|
||||
# Calculate the cutoff timestamp
|
||||
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
|
||||
results = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
agent_node."agentBlockId" AS block_id,
|
||||
COUNT(execution.id) AS execution_count
|
||||
FROM {schema_prefix}"AgentNodeExecution" execution
|
||||
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
|
||||
WHERE execution."endedTime" >= $1::timestamp
|
||||
GROUP BY agent_node."agentBlockId"
|
||||
ORDER BY execution_count DESC;
|
||||
""",
|
||||
timestamp_threshold,
|
||||
)
|
||||
|
||||
# Get the top blocks based on execution count
|
||||
# But ignore Input and Output blocks
|
||||
blocks: list[tuple[BlockInfo, int]] = []
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: AnyBlockSchema = block_type()
|
||||
if block.disabled or block.block_type in (
|
||||
backend.data.block.BlockType.INPUT,
|
||||
backend.data.block.BlockType.OUTPUT,
|
||||
backend.data.block.BlockType.AGENT,
|
||||
):
|
||||
continue
|
||||
# Find the execution count for this block
|
||||
execution_count = next(
|
||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
||||
0,
|
||||
)
|
||||
blocks.append((block.get_info(), execution_count))
|
||||
# Sort blocks by execution count
|
||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
suggested_blocks = [block[0] for block in blocks]
|
||||
|
||||
# Return the top blocks
|
||||
return suggested_blocks[:count]
|
||||
@@ -1,939 +0,0 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.completion_handler import (
|
||||
process_operation_failure,
|
||||
process_operation_success,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_copilot_task
|
||||
from backend.copilot.model import (
|
||||
ChatSession,
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
)
|
||||
from backend.copilot.response_model import StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AgentsFoundResponse,
|
||||
BlockListResponse,
|
||||
BlockOutputResponse,
|
||||
ClarificationNeededResponse,
|
||||
DocPageResponse,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
ExecutionStartedResponse,
|
||||
InputValidationErrorResponse,
|
||||
NeedLoginResponse,
|
||||
NoResultsResponse,
|
||||
OperationInProgressResponse,
|
||||
OperationPendingResponse,
|
||||
OperationStartedResponse,
|
||||
SetupRequirementsResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession:
|
||||
"""Validate session exists and belongs to user."""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
return session
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class StreamChatRequest(BaseModel):
|
||||
"""Request model for streaming chat with optional context."""
|
||||
|
||||
message: str
|
||||
is_user_message: bool = True
|
||||
context: dict[str, str] | None = None # {url: str, content: str}
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""Response model containing information on a newly created chat session."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
"""Information about an active stream for reconnection."""
|
||||
|
||||
task_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
operation_id: str # Operation ID for completion tracking
|
||||
tool_name: str # Name of the tool being executed
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
"""Response model for a session summary (without messages)."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
"""Response model for listing chat sessions."""
|
||||
|
||||
sessions: list[SessionSummaryResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: int = Query(default=50, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> ListSessionsResponse:
|
||||
"""
|
||||
List chat sessions for the authenticated user.
|
||||
|
||||
Returns a paginated list of chat sessions belonging to the current user,
|
||||
ordered by most recently updated.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user's ID.
|
||||
limit: Maximum number of sessions to return (1-100).
|
||||
offset: Number of sessions to skip for pagination.
|
||||
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
SessionSummaryResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=session.title,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
total=total_count,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Initiates a new chat session for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: The authenticated user ID parsed from the JWT (required).
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
# stream replay instead, preventing duplicate content.
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
messages = messages[:-1]
|
||||
|
||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||
# Since we filtered out the cached assistant message, the client needs
|
||||
# the full stream to reconstruct the response.
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
task_id=active_task.task_id,
|
||||
last_message_id="0-0",
|
||||
operation_id=active_task.operation_id,
|
||||
tool_name=active_task.tool_name,
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat_post(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream chat responses for a session (POST with context support).
|
||||
|
||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
||||
- Text fragments as they are generated
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
stream_start_time = time.perf_counter()
|
||||
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||
if user_id:
|
||||
log_meta["user_id"] = user_id
|
||||
|
||||
logger.info(
|
||||
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||
f"user={user_id}, message_len={len(request.message)}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
|
||||
_session = await _validate_and_get_session(session_id, user_id) # noqa: F841
|
||||
logger.info(
|
||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
log_meta["task_id"] = task_id
|
||||
|
||||
task_create_start = time.perf_counter()
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Enqueue the task to RabbitMQ for processing by the CoPilot executor
|
||||
await enqueue_copilot_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
)
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import time as time_module
|
||||
|
||||
event_gen_start = time_module.perf_counter()
|
||||
logger.info(
|
||||
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||
f"user={user_id}",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
"[TIMING] Starting to read from subscriber_queue",
|
||||
extra={"json_fields": log_meta},
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
chunks_yielded += 1
|
||||
|
||||
if not first_chunk_yielded:
|
||||
first_chunk_yielded = True
|
||||
elapsed = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
||||
f"type={type(chunk).__name__}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunk_type": type(chunk).__name__,
|
||||
"elapsed_ms": elapsed * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
||||
f"n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"total_time_ms": total_time * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.info(
|
||||
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
"reason": "client_disconnect",
|
||||
}
|
||||
},
|
||||
)
|
||||
pass # Client disconnected - background task continues
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
||||
extra={
|
||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||
},
|
||||
)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
task_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"total_time_ms": total_time * 1000,
|
||||
"chunks_yielded": chunks_yielded,
|
||||
}
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def resume_session_stream(
|
||||
session_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Resume an active stream for a session.
|
||||
|
||||
Called by the AI SDK's ``useChat(resume: true)`` on page load.
|
||||
Checks for an active (in-progress) task on the session and either replays
|
||||
the full SSE stream or returns 204 No Content if nothing is running.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier.
|
||||
user_id: Optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
StreamingResponse (SSE) when an active stream exists,
|
||||
or 204 No Content when there is nothing to resume.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
|
||||
if not active_task:
|
||||
return Response(status_code=204)
|
||||
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=active_task.task_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
return Response(status_code=204)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Resume stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||
finally:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
active_task.task_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(
|
||||
"Resume stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"n_chunks": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=200,
|
||||
)
|
||||
async def session_assign_user(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""
|
||||
Assign an authenticated user to a chat session.
|
||||
|
||||
Used (typically post-login) to claim an existing anonymous session as the current authenticated user.
|
||||
|
||||
Args:
|
||||
session_id: The identifier for the (previously anonymous) session.
|
||||
user_id: The authenticated user's ID to associate with the session.
|
||||
|
||||
Returns:
|
||||
dict: Status of the assignment.
|
||||
|
||||
"""
|
||||
await chat_service.assign_user_to_session(session_id, user_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
||||
"""
|
||||
# Check task existence and expiry before subscribing
|
||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
||||
|
||||
if error_code == "TASK_EXPIRED":
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail={
|
||||
"code": "TASK_EXPIRED",
|
||||
"message": "This operation has expired. Please try again.",
|
||||
},
|
||||
)
|
||||
|
||||
if error_code == "TASK_NOT_FOUND":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found.",
|
||||
},
|
||||
)
|
||||
|
||||
# Validate ownership if task has an owner
|
||||
if task and task.user_id and user_id != task.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"code": "ACCESS_DENIED",
|
||||
"message": "You do not have access to this task.",
|
||||
},
|
||||
)
|
||||
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found or access denied.",
|
||||
},
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task.user_id and user_id != task.user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
await process_operation_success(task, request.result)
|
||||
else:
|
||||
await process_operation_failure(task, request.error)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@router.get("/config/ttl", status_code=200)
|
||||
async def get_ttl_config() -> dict:
|
||||
"""
|
||||
Get the stream TTL configuration.
|
||||
|
||||
Returns the Time-To-Live settings for chat streams, which determines
|
||||
how long clients can reconnect to an active stream.
|
||||
|
||||
Returns:
|
||||
dict: TTL configuration with seconds and milliseconds values.
|
||||
"""
|
||||
return {
|
||||
"stream_ttl_seconds": config.stream_ttl,
|
||||
"stream_ttl_ms": config.stream_ttl * 1000,
|
||||
}
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
@router.get("/health", status_code=200)
|
||||
async def health_check() -> dict:
|
||||
"""
|
||||
Health check endpoint for the chat service.
|
||||
|
||||
Performs a full cycle test of session creation and retrieval. Should always return healthy
|
||||
if the service and data layer are operational.
|
||||
|
||||
Returns:
|
||||
dict: A status dictionary indicating health, service name, and API version.
|
||||
|
||||
"""
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Ensure health check user exists (required for FK constraint)
|
||||
health_check_user_id = "health-check-user"
|
||||
await get_or_create_user(
|
||||
{
|
||||
"sub": health_check_user_id,
|
||||
"email": "health-check@system.local",
|
||||
"user_metadata": {"name": "Health Check User"},
|
||||
}
|
||||
)
|
||||
|
||||
# Create and retrieve session to verify full data layer
|
||||
session = await create_chat_session(health_check_user_id)
|
||||
await get_chat_session(session.session_id, health_check_user_id)
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "chat",
|
||||
"version": "0.1.0",
|
||||
}
|
||||
|
||||
|
||||
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
|
||||
|
||||
ToolResponseUnion = (
|
||||
AgentsFoundResponse
|
||||
| NoResultsResponse
|
||||
| AgentDetailsResponse
|
||||
| SetupRequirementsResponse
|
||||
| ExecutionStartedResponse
|
||||
| NeedLoginResponse
|
||||
| ErrorResponse
|
||||
| InputValidationErrorResponse
|
||||
| AgentOutputResponse
|
||||
| UnderstandingUpdatedResponse
|
||||
| AgentPreviewResponse
|
||||
| AgentSavedResponse
|
||||
| ClarificationNeededResponse
|
||||
| BlockListResponse
|
||||
| BlockOutputResponse
|
||||
| DocSearchResultsResponse
|
||||
| DocPageResponse
|
||||
| OperationStartedResponse
|
||||
| OperationPendingResponse
|
||||
| OperationInProgressResponse
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/schema/tool-responses",
|
||||
response_model=ToolResponseUnion,
|
||||
include_in_schema=True,
|
||||
summary="[Dummy] Tool response type export for codegen",
|
||||
description="This endpoint is not meant to be called. It exists solely to "
|
||||
"expose tool response models in the OpenAPI schema for frontend codegen.",
|
||||
)
|
||||
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
|
||||
"""Never called at runtime. Exists only so Orval generates TS types."""
|
||||
raise HTTPException(status_code=501, detail="Schema-only endpoint")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,353 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, List
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
get_graph_execution_meta,
|
||||
)
|
||||
from backend.data.graph import get_graph_settings
|
||||
from backend.data.human_review import (
|
||||
create_auto_approval_record,
|
||||
get_pending_reviews_for_execution,
|
||||
get_pending_reviews_for_user,
|
||||
get_reviews_by_node_exec_ids,
|
||||
has_pending_reviews_for_graph_exec,
|
||||
process_all_reviews_for_execution,
|
||||
)
|
||||
from backend.data.model import USER_TIMEZONE_NOT_SET
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor.utils import add_graph_execution
|
||||
|
||||
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["v2", "executions", "review"],
|
||||
dependencies=[Security(autogpt_auth_lib.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/pending",
|
||||
summary="Get Pending Reviews",
|
||||
response_model=List[PendingHumanReviewModel],
|
||||
responses={
|
||||
200: {"description": "List of pending reviews"},
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
async def list_pending_reviews(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
page: int = Query(1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(25, ge=1, le=100, description="Number of reviews per page"),
|
||||
) -> List[PendingHumanReviewModel]:
|
||||
"""Get all pending reviews for the current user.
|
||||
|
||||
Retrieves all reviews with status "WAITING" that belong to the authenticated user.
|
||||
Results are ordered by creation time (newest first).
|
||||
|
||||
Args:
|
||||
user_id: Authenticated user ID from security dependency
|
||||
|
||||
Returns:
|
||||
List of pending review objects with status converted to typed literals
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails or database error occurs
|
||||
|
||||
Note:
|
||||
Reviews with invalid status values are logged as warnings but excluded
|
||||
from results rather than failing the entire request.
|
||||
"""
|
||||
|
||||
return await get_pending_reviews_for_user(user_id, page, page_size)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/execution/{graph_exec_id}",
|
||||
summary="Get Pending Reviews for Execution",
|
||||
response_model=List[PendingHumanReviewModel],
|
||||
responses={
|
||||
200: {"description": "List of pending reviews for the execution"},
|
||||
404: {"description": "Graph execution not found"},
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
async def list_pending_reviews_for_execution(
|
||||
graph_exec_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> List[PendingHumanReviewModel]:
|
||||
"""Get all pending reviews for a specific graph execution.
|
||||
|
||||
Retrieves all reviews with status "WAITING" for the specified graph execution
|
||||
that belong to the authenticated user. Results are ordered by creation time
|
||||
(oldest first) to preserve review order within the execution.
|
||||
|
||||
Args:
|
||||
graph_exec_id: ID of the graph execution to get reviews for
|
||||
user_id: Authenticated user ID from security dependency
|
||||
|
||||
Returns:
|
||||
List of pending review objects for the specified execution
|
||||
|
||||
Raises:
|
||||
HTTPException:
|
||||
- 404: If the graph execution doesn't exist or isn't owned by this user
|
||||
- 500: If authentication fails or database error occurs
|
||||
|
||||
Note:
|
||||
Only returns reviews owned by the authenticated user for security.
|
||||
Reviews with invalid status are excluded with warning logs.
|
||||
"""
|
||||
|
||||
# Verify user owns the graph execution before returning reviews
|
||||
graph_exec = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
return await get_pending_reviews_for_execution(graph_exec_id, user_id)
|
||||
|
||||
|
||||
@router.post("/action", response_model=ReviewResponse)
|
||||
async def process_review_action(
|
||||
request: ReviewRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> ReviewResponse:
|
||||
"""Process reviews with approve or reject actions."""
|
||||
|
||||
# Collect all node exec IDs from the request
|
||||
all_request_node_ids = {review.node_exec_id for review in request.reviews}
|
||||
|
||||
if not all_request_node_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one review must be provided",
|
||||
)
|
||||
|
||||
# Batch fetch all requested reviews (regardless of status for idempotent handling)
|
||||
reviews_map = await get_reviews_by_node_exec_ids(
|
||||
list(all_request_node_ids), user_id
|
||||
)
|
||||
|
||||
# Validate all reviews were found (must exist, any status is OK for now)
|
||||
missing_ids = all_request_node_ids - set(reviews_map.keys())
|
||||
if missing_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Review(s) not found: {', '.join(missing_ids)}",
|
||||
)
|
||||
|
||||
# Validate all reviews belong to the same execution
|
||||
graph_exec_ids = {review.graph_exec_id for review in reviews_map.values()}
|
||||
if len(graph_exec_ids) > 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="All reviews in a single request must belong to the same execution.",
|
||||
)
|
||||
|
||||
graph_exec_id = next(iter(graph_exec_ids))
|
||||
|
||||
# Validate execution status before processing reviews
|
||||
graph_exec_meta = await get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
|
||||
if not graph_exec_meta:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Graph execution #{graph_exec_id} not found",
|
||||
)
|
||||
|
||||
# Only allow processing reviews if execution is paused for review
|
||||
# or incomplete (partial execution with some reviews already processed)
|
||||
if graph_exec_meta.status not in (
|
||||
ExecutionStatus.REVIEW,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot process reviews while execution status is {graph_exec_meta.status}. "
|
||||
f"Reviews can only be processed when execution is paused (REVIEW status). "
|
||||
f"Current status: {graph_exec_meta.status}",
|
||||
)
|
||||
|
||||
# Build review decisions map and track which reviews requested auto-approval
|
||||
# Auto-approved reviews use original data (no modifications allowed)
|
||||
review_decisions = {}
|
||||
auto_approve_requests = {} # Map node_exec_id -> auto_approve_future flag
|
||||
|
||||
for review in request.reviews:
|
||||
review_status = (
|
||||
ReviewStatus.APPROVED if review.approved else ReviewStatus.REJECTED
|
||||
)
|
||||
# If this review requested auto-approval, don't allow data modifications
|
||||
reviewed_data = None if review.auto_approve_future else review.reviewed_data
|
||||
review_decisions[review.node_exec_id] = (
|
||||
review_status,
|
||||
reviewed_data,
|
||||
review.message,
|
||||
)
|
||||
auto_approve_requests[review.node_exec_id] = review.auto_approve_future
|
||||
|
||||
# Process all reviews
|
||||
updated_reviews = await process_all_reviews_for_execution(
|
||||
user_id=user_id,
|
||||
review_decisions=review_decisions,
|
||||
)
|
||||
|
||||
# Create auto-approval records for approved reviews that requested it
|
||||
# Deduplicate by node_id to avoid race conditions when multiple reviews
|
||||
# for the same node are processed in parallel
|
||||
async def create_auto_approval_for_node(
|
||||
node_id: str, review_result
|
||||
) -> tuple[str, bool]:
|
||||
"""
|
||||
Create auto-approval record for a node.
|
||||
Returns (node_id, success) tuple for tracking failures.
|
||||
"""
|
||||
try:
|
||||
await create_auto_approval_record(
|
||||
user_id=user_id,
|
||||
graph_exec_id=review_result.graph_exec_id,
|
||||
graph_id=review_result.graph_id,
|
||||
graph_version=review_result.graph_version,
|
||||
node_id=node_id,
|
||||
payload=review_result.payload,
|
||||
)
|
||||
return (node_id, True)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for node {node_id}",
|
||||
exc_info=e,
|
||||
)
|
||||
return (node_id, False)
|
||||
|
||||
# Collect node_exec_ids that need auto-approval
|
||||
node_exec_ids_needing_auto_approval = [
|
||||
node_exec_id
|
||||
for node_exec_id, review_result in updated_reviews.items()
|
||||
if review_result.status == ReviewStatus.APPROVED
|
||||
and auto_approve_requests.get(node_exec_id, False)
|
||||
]
|
||||
|
||||
# Batch-fetch node executions to get node_ids
|
||||
nodes_needing_auto_approval: dict[str, Any] = {}
|
||||
if node_exec_ids_needing_auto_approval:
|
||||
from backend.data.execution import get_node_executions
|
||||
|
||||
node_execs = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id, include_exec_data=False
|
||||
)
|
||||
node_exec_map = {node_exec.node_exec_id: node_exec for node_exec in node_execs}
|
||||
|
||||
for node_exec_id in node_exec_ids_needing_auto_approval:
|
||||
node_exec = node_exec_map.get(node_exec_id)
|
||||
if node_exec:
|
||||
review_result = updated_reviews[node_exec_id]
|
||||
# Use the first approved review for this node (deduplicate by node_id)
|
||||
if node_exec.node_id not in nodes_needing_auto_approval:
|
||||
nodes_needing_auto_approval[node_exec.node_id] = review_result
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to create auto-approval record for {node_exec_id}: "
|
||||
f"Node execution not found. This may indicate a race condition "
|
||||
f"or data inconsistency."
|
||||
)
|
||||
|
||||
# Execute all auto-approval creations in parallel (deduplicated by node_id)
|
||||
auto_approval_results = await asyncio.gather(
|
||||
*[
|
||||
create_auto_approval_for_node(node_id, review_result)
|
||||
for node_id, review_result in nodes_needing_auto_approval.items()
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Count auto-approval failures
|
||||
auto_approval_failed_count = 0
|
||||
for result in auto_approval_results:
|
||||
if isinstance(result, Exception):
|
||||
# Unexpected exception during auto-approval creation
|
||||
auto_approval_failed_count += 1
|
||||
logger.error(
|
||||
f"Unexpected exception during auto-approval creation: {result}"
|
||||
)
|
||||
elif isinstance(result, tuple) and len(result) == 2 and not result[1]:
|
||||
# Auto-approval creation failed (returned False)
|
||||
auto_approval_failed_count += 1
|
||||
|
||||
# Count results
|
||||
approved_count = sum(
|
||||
1
|
||||
for review in updated_reviews.values()
|
||||
if review.status == ReviewStatus.APPROVED
|
||||
)
|
||||
rejected_count = sum(
|
||||
1
|
||||
for review in updated_reviews.values()
|
||||
if review.status == ReviewStatus.REJECTED
|
||||
)
|
||||
|
||||
# Resume execution only if ALL pending reviews for this execution have been processed
|
||||
if updated_reviews:
|
||||
still_has_pending = await has_pending_reviews_for_graph_exec(graph_exec_id)
|
||||
|
||||
if not still_has_pending:
|
||||
# Get the graph_id from any processed review
|
||||
first_review = next(iter(updated_reviews.values()))
|
||||
|
||||
try:
|
||||
# Fetch user and settings to build complete execution context
|
||||
user = await get_user_by_id(user_id)
|
||||
settings = await get_graph_settings(
|
||||
user_id=user_id, graph_id=first_review.graph_id
|
||||
)
|
||||
|
||||
# Preserve user's timezone preference when resuming execution
|
||||
user_timezone = (
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
user_timezone=user_timezone,
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
graph_id=first_review.graph_id,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Resumed execution {graph_exec_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resume execution {graph_exec_id}: {str(e)}")
|
||||
|
||||
# Build error message if auto-approvals failed
|
||||
error_message = None
|
||||
if auto_approval_failed_count > 0:
|
||||
error_message = (
|
||||
f"{auto_approval_failed_count} auto-approval setting(s) could not be saved. "
|
||||
f"You may need to manually approve these reviews in future executions."
|
||||
)
|
||||
|
||||
return ReviewResponse(
|
||||
approved_count=approved_count,
|
||||
rejected_count=rejected_count,
|
||||
failed_count=auto_approval_failed_count,
|
||||
error=error_message,
|
||||
)
|
||||
@@ -1,199 +0,0 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
|
||||
from fastapi.responses import Response
|
||||
from prisma.enums import OnboardingStep
|
||||
|
||||
from backend.data.onboarding import complete_onboarding_step
|
||||
|
||||
from .. import db as library_db
|
||||
from .. import model as library_model
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/agents",
|
||||
tags=["library", "private"],
|
||||
dependencies=[Security(autogpt_auth_lib.requires_user)],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
summary="List Library Agents",
|
||||
response_model=library_model.LibraryAgentResponse,
|
||||
)
|
||||
async def list_library_agents(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
search_term: Optional[str] = Query(
|
||||
None, description="Search term to filter agents"
|
||||
),
|
||||
sort_by: library_model.LibraryAgentSort = Query(
|
||||
library_model.LibraryAgentSort.UPDATED_AT,
|
||||
description="Criteria to sort results by",
|
||||
),
|
||||
page: int = Query(
|
||||
1,
|
||||
ge=1,
|
||||
description="Page number to retrieve (must be >= 1)",
|
||||
),
|
||||
page_size: int = Query(
|
||||
15,
|
||||
ge=1,
|
||||
description="Number of agents per page (must be >= 1)",
|
||||
),
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Get all agents in the user's library (both created and saved).
|
||||
"""
|
||||
return await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_term,
|
||||
sort_by=sort_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/favorites",
|
||||
summary="List Favorite Library Agents",
|
||||
)
|
||||
async def list_favorite_library_agents(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
page: int = Query(
|
||||
1,
|
||||
ge=1,
|
||||
description="Page number to retrieve (must be >= 1)",
|
||||
),
|
||||
page_size: int = Query(
|
||||
15,
|
||||
ge=1,
|
||||
description="Number of agents per page (must be >= 1)",
|
||||
),
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Get all favorite agents in the user's library.
|
||||
"""
|
||||
return await library_db.list_favorite_library_agents(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
||||
async def get_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.get("/by-graph/{graph_id}")
|
||||
async def get_library_agent_by_graph_id(
|
||||
graph_id: str,
|
||||
version: Optional[int] = Query(default=None),
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, graph_id, version
|
||||
)
|
||||
if not library_agent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Library agent for graph #{graph_id} and user #{user_id} not found",
|
||||
)
|
||||
return library_agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/marketplace/{store_listing_version_id}",
|
||||
summary="Get Agent By Store ID",
|
||||
tags=["store", "library"],
|
||||
response_model=library_model.LibraryAgent | None,
|
||||
)
|
||||
async def get_library_agent_by_store_listing_version_id(
|
||||
store_listing_version_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryAgent | None:
|
||||
"""
|
||||
Get Library Agent from Store Listing Version ID.
|
||||
"""
|
||||
return await library_db.get_library_agent_by_store_version_id(
|
||||
store_listing_version_id, user_id
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
summary="Add Marketplace Agent",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def add_marketplace_agent_to_library(
|
||||
store_listing_version_id: str = Body(embed=True),
|
||||
source: Literal["onboarding", "marketplace"] = Body(
|
||||
default="marketplace", embed=True
|
||||
),
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Add an agent from the marketplace to the user's library.
|
||||
"""
|
||||
agent = await library_db.add_store_agent_to_library(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if source != "onboarding":
|
||||
await complete_onboarding_step(user_id, OnboardingStep.MARKETPLACE_ADD_AGENT)
|
||||
return agent
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{library_agent_id}",
|
||||
summary="Update Library Agent",
|
||||
)
|
||||
async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
payload: library_model.LibraryAgentUpdateRequest,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Update the library agent with the given fields.
|
||||
"""
|
||||
return await library_db.update_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
auto_update_version=payload.auto_update_version,
|
||||
graph_version=payload.graph_version,
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
settings=payload.settings,
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{library_agent_id}",
|
||||
summary="Delete Library Agent",
|
||||
)
|
||||
async def delete_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> Response:
|
||||
"""
|
||||
Soft-delete the specified library agent.
|
||||
"""
|
||||
await library_db.delete_library_agent(
|
||||
library_agent_id=library_agent_id, user_id=user_id
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
|
||||
async def fork_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
return await library_db.fork_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -1,833 +0,0 @@
|
||||
"""
|
||||
OAuth 2.0 Provider Endpoints
|
||||
|
||||
Implements OAuth 2.0 Authorization Code flow with PKCE support.
|
||||
|
||||
Flow:
|
||||
1. User clicks "Login with AutoGPT" in 3rd party app
|
||||
2. App redirects user to /auth/authorize with client_id, redirect_uri, scope, state
|
||||
3. User sees consent screen (if not already logged in, redirects to login first)
|
||||
4. User approves → backend creates authorization code
|
||||
5. User redirected back to app with code
|
||||
6. App exchanges code for access/refresh tokens at /api/oauth/token
|
||||
7. App uses access token to call external API endpoints
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from autogpt_libs.auth import get_user_id
|
||||
from fastapi import APIRouter, Body, HTTPException, Security, UploadFile, status
|
||||
from gcloud.aio import storage as async_storage
|
||||
from PIL import Image
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.auth.oauth import (
|
||||
InvalidClientError,
|
||||
InvalidGrantError,
|
||||
OAuthApplicationInfo,
|
||||
TokenIntrospectionResult,
|
||||
consume_authorization_code,
|
||||
create_access_token,
|
||||
create_authorization_code,
|
||||
create_refresh_token,
|
||||
get_oauth_application,
|
||||
get_oauth_application_by_id,
|
||||
introspect_token,
|
||||
list_user_oauth_applications,
|
||||
refresh_tokens,
|
||||
revoke_access_token,
|
||||
revoke_refresh_token,
|
||||
update_oauth_application,
|
||||
validate_client_credentials,
|
||||
validate_redirect_uri,
|
||||
validate_scopes,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""OAuth 2.0 token response"""
|
||||
|
||||
token_type: Literal["Bearer"] = "Bearer"
|
||||
access_token: str
|
||||
access_token_expires_at: datetime
|
||||
refresh_token: str
|
||||
refresh_token_expires_at: datetime
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""OAuth 2.0 error response"""
|
||||
|
||||
error: str
|
||||
error_description: Optional[str] = None
|
||||
|
||||
|
||||
class OAuthApplicationPublicInfo(BaseModel):
|
||||
"""Public information about an OAuth application (for consent screen)"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
logo_url: Optional[str] = None
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Info Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/app/{client_id}",
|
||||
responses={
|
||||
404: {"description": "Application not found or disabled"},
|
||||
},
|
||||
)
|
||||
async def get_oauth_app_info(
|
||||
client_id: str, user_id: str = Security(get_user_id)
|
||||
) -> OAuthApplicationPublicInfo:
|
||||
"""
|
||||
Get public information about an OAuth application.
|
||||
|
||||
This endpoint is used by the consent screen to display application details
|
||||
to the user before they authorize access.
|
||||
|
||||
Returns:
|
||||
- name: Application name
|
||||
- description: Application description (if provided)
|
||||
- scopes: List of scopes the application is allowed to request
|
||||
"""
|
||||
app = await get_oauth_application(client_id)
|
||||
if not app or not app.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found",
|
||||
)
|
||||
|
||||
return OAuthApplicationPublicInfo(
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
logo_url=app.logo_url,
|
||||
scopes=[s.value for s in app.scopes],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AuthorizeRequest(BaseModel):
|
||||
"""OAuth 2.0 authorization request"""
|
||||
|
||||
client_id: str = Field(description="Client identifier")
|
||||
redirect_uri: str = Field(description="Redirect URI")
|
||||
scopes: list[str] = Field(description="List of scopes")
|
||||
state: str = Field(description="Anti-CSRF token from client")
|
||||
response_type: str = Field(
|
||||
default="code", description="Must be 'code' for authorization code flow"
|
||||
)
|
||||
code_challenge: str = Field(description="PKCE code challenge (required)")
|
||||
code_challenge_method: Literal["S256", "plain"] = Field(
|
||||
default="S256", description="PKCE code challenge method (S256 recommended)"
|
||||
)
|
||||
|
||||
|
||||
class AuthorizeResponse(BaseModel):
|
||||
"""OAuth 2.0 authorization response with redirect URL"""
|
||||
|
||||
redirect_url: str = Field(description="URL to redirect the user to")
|
||||
|
||||
|
||||
@router.post("/authorize")
|
||||
async def authorize(
|
||||
request: AuthorizeRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> AuthorizeResponse:
|
||||
"""
|
||||
OAuth 2.0 Authorization Endpoint
|
||||
|
||||
User must be logged in (authenticated with Supabase JWT).
|
||||
This endpoint creates an authorization code and returns a redirect URL.
|
||||
|
||||
PKCE (Proof Key for Code Exchange) is REQUIRED for all authorization requests.
|
||||
|
||||
The frontend consent screen should call this endpoint after the user approves,
|
||||
then redirect the user to the returned `redirect_url`.
|
||||
|
||||
Request Body:
|
||||
- client_id: The OAuth application's client ID
|
||||
- redirect_uri: Where to redirect after authorization (must match registered URI)
|
||||
- scopes: List of permissions (e.g., "EXECUTE_GRAPH READ_GRAPH")
|
||||
- state: Anti-CSRF token provided by client (will be returned in redirect)
|
||||
- response_type: Must be "code" (for authorization code flow)
|
||||
- code_challenge: PKCE code challenge (required)
|
||||
- code_challenge_method: "S256" (recommended) or "plain"
|
||||
|
||||
Returns:
|
||||
- redirect_url: The URL to redirect the user to (includes authorization code)
|
||||
|
||||
Error cases return a redirect_url with error parameters, or raise HTTPException
|
||||
for critical errors (like invalid redirect_uri).
|
||||
"""
|
||||
try:
|
||||
# Validate response_type
|
||||
if request.response_type != "code":
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"unsupported_response_type",
|
||||
"Only 'code' response type is supported",
|
||||
)
|
||||
|
||||
# Get application
|
||||
app = await get_oauth_application(request.client_id)
|
||||
if not app:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Unknown client_id",
|
||||
)
|
||||
|
||||
if not app.is_active:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_client",
|
||||
"Application is not active",
|
||||
)
|
||||
|
||||
# Validate redirect URI
|
||||
if not validate_redirect_uri(app, request.redirect_uri):
|
||||
# For invalid redirect_uri, we can't redirect safely
|
||||
# Must return error instead
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"Invalid redirect_uri. "
|
||||
f"Must be one of: {', '.join(app.redirect_uris)}"
|
||||
),
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
try:
|
||||
requested_scopes = [APIKeyPermission(s.strip()) for s in request.scopes]
|
||||
except ValueError as e:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
f"Invalid scope: {e}",
|
||||
)
|
||||
|
||||
if not requested_scopes:
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"At least one scope is required",
|
||||
)
|
||||
|
||||
if not validate_scopes(app, requested_scopes):
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"invalid_scope",
|
||||
"Application is not authorized for all requested scopes. "
|
||||
f"Allowed: {', '.join(s.value for s in app.scopes)}",
|
||||
)
|
||||
|
||||
# Create authorization code
|
||||
auth_code = await create_authorization_code(
|
||||
application_id=app.id,
|
||||
user_id=user_id,
|
||||
scopes=requested_scopes,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_challenge=request.code_challenge,
|
||||
code_challenge_method=request.code_challenge_method,
|
||||
)
|
||||
|
||||
# Build redirect URL with authorization code
|
||||
params = {
|
||||
"code": auth_code.code,
|
||||
"state": request.state,
|
||||
}
|
||||
redirect_url = f"{request.redirect_uri}?{urlencode(params)}"
|
||||
|
||||
logger.info(
|
||||
f"Authorization code issued for user #{user_id} "
|
||||
f"and app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in authorization endpoint: {e}", exc_info=True)
|
||||
return _error_redirect_url(
|
||||
request.redirect_uri,
|
||||
request.state,
|
||||
"server_error",
|
||||
"An unexpected error occurred",
|
||||
)
|
||||
|
||||
|
||||
def _error_redirect_url(
|
||||
redirect_uri: str,
|
||||
state: str,
|
||||
error: str,
|
||||
error_description: Optional[str] = None,
|
||||
) -> AuthorizeResponse:
|
||||
"""Helper to build redirect URL with OAuth error parameters"""
|
||||
params = {
|
||||
"error": error,
|
||||
"state": state,
|
||||
}
|
||||
if error_description:
|
||||
params["error_description"] = error_description
|
||||
|
||||
redirect_url = f"{redirect_uri}?{urlencode(params)}"
|
||||
return AuthorizeResponse(redirect_url=redirect_url)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenRequestByCode(BaseModel):
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(description="Authorization code")
|
||||
redirect_uri: str = Field(
|
||||
description="Redirect URI (must match authorization request)"
|
||||
)
|
||||
client_id: str
|
||||
client_secret: str
|
||||
code_verifier: str = Field(description="PKCE code verifier")
|
||||
|
||||
|
||||
class TokenRequestByRefreshToken(BaseModel):
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
async def token(
|
||||
request: TokenRequestByCode | TokenRequestByRefreshToken = Body(),
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
OAuth 2.0 Token Endpoint
|
||||
|
||||
Exchanges authorization code or refresh token for access token.
|
||||
|
||||
Grant Types:
|
||||
1. authorization_code: Exchange authorization code for tokens
|
||||
- Required: grant_type, code, redirect_uri, client_id, client_secret
|
||||
- Optional: code_verifier (required if PKCE was used)
|
||||
|
||||
2. refresh_token: Exchange refresh token for new access token
|
||||
- Required: grant_type, refresh_token, client_id, client_secret
|
||||
|
||||
Returns:
|
||||
- access_token: Bearer token for API access (1 hour TTL)
|
||||
- token_type: "Bearer"
|
||||
- expires_in: Seconds until access token expires
|
||||
- refresh_token: Token for refreshing access (30 days TTL)
|
||||
- scopes: List of scopes
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(
|
||||
request.client_id, request.client_secret
|
||||
)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Handle authorization_code grant
|
||||
if request.grant_type == "authorization_code":
|
||||
# Consume authorization code
|
||||
try:
|
||||
user_id, scopes = await consume_authorization_code(
|
||||
code=request.code,
|
||||
application_id=app.id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
code_verifier=request.code_verifier,
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Create access and refresh tokens
|
||||
access_token = await create_access_token(app.id, user_id, scopes)
|
||||
refresh_token = await create_refresh_token(app.id, user_id, scopes)
|
||||
|
||||
logger.info(
|
||||
f"Access token issued for user #{user_id} and app {app.name} (#{app.id})"
|
||||
"via authorization code"
|
||||
)
|
||||
|
||||
if not access_token.token or not refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=access_token.token.get_secret_value(),
|
||||
access_token_expires_at=access_token.expires_at,
|
||||
refresh_token=refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=refresh_token.expires_at,
|
||||
scopes=list(s.value for s in scopes),
|
||||
)
|
||||
|
||||
# Handle refresh_token grant
|
||||
elif request.grant_type == "refresh_token":
|
||||
# Refresh access token
|
||||
try:
|
||||
new_access_token, new_refresh_token = await refresh_tokens(
|
||||
request.refresh_token, app.id
|
||||
)
|
||||
except InvalidGrantError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Tokens refreshed for user #{new_access_token.user_id} "
|
||||
f"by app {app.name} (#{app.id})"
|
||||
)
|
||||
|
||||
if not new_access_token.token or not new_refresh_token.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate tokens",
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
token_type="Bearer",
|
||||
access_token=new_access_token.token.get_secret_value(),
|
||||
access_token_expires_at=new_access_token.expires_at,
|
||||
refresh_token=new_refresh_token.token.get_secret_value(),
|
||||
refresh_token_expires_at=new_refresh_token.expires_at,
|
||||
scopes=list(s.value for s in new_access_token.scopes),
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported grant_type: {request.grant_type}. "
|
||||
"Must be 'authorization_code' or 'refresh_token'",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/introspect")
|
||||
async def introspect(
|
||||
token: str = Body(description="Token to introspect"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
) -> TokenIntrospectionResult:
|
||||
"""
|
||||
OAuth 2.0 Token Introspection Endpoint (RFC 7662)
|
||||
|
||||
Allows clients to check if a token is valid and get its metadata.
|
||||
|
||||
Returns:
|
||||
- active: Whether the token is currently active
|
||||
- scopes: List of authorized scopes (if active)
|
||||
- client_id: The client the token was issued to (if active)
|
||||
- user_id: The user the token represents (if active)
|
||||
- exp: Expiration timestamp (if active)
|
||||
- token_type: "access_token" or "refresh_token" (if active)
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Introspect the token
|
||||
return await introspect_token(token, token_type_hint)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post("/revoke")
|
||||
async def revoke(
|
||||
token: str = Body(description="Token to revoke"),
|
||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Body(
|
||||
None, description="Hint about token type ('access_token' or 'refresh_token')"
|
||||
),
|
||||
client_id: str = Body(description="Client identifier"),
|
||||
client_secret: str = Body(description="Client secret"),
|
||||
):
|
||||
"""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009)
|
||||
|
||||
Allows clients to revoke an access or refresh token.
|
||||
|
||||
Note: Revoking a refresh token does NOT revoke associated access tokens.
|
||||
Revoking an access token does NOT revoke the associated refresh token.
|
||||
"""
|
||||
# Validate client credentials
|
||||
try:
|
||||
app = await validate_client_credentials(client_id, client_secret)
|
||||
except InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Try to revoke as access token first
|
||||
# Note: We pass app.id to ensure the token belongs to the authenticated app
|
||||
if token_type_hint != "refresh_token":
|
||||
revoked = await revoke_access_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Access token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Try to revoke as refresh token
|
||||
revoked = await revoke_refresh_token(token, app.id)
|
||||
if revoked:
|
||||
logger.info(
|
||||
f"Refresh token revoked for app {app.name} (#{app.id}); "
|
||||
f"user #{revoked.user_id}"
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
# Per RFC 7009, revocation endpoint returns 200 even if token not found
|
||||
# or if token belongs to a different application.
|
||||
# This prevents token scanning attacks.
|
||||
logger.warning(f"Unsuccessful token revocation attempt by app {app.name} #{app.id}")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Application Management Endpoints (for app owners)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get("/apps/mine")
|
||||
async def list_my_oauth_apps(
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> list[OAuthApplicationInfo]:
|
||||
"""
|
||||
List all OAuth applications owned by the current user.
|
||||
|
||||
Returns a list of OAuth applications with their details including:
|
||||
- id, name, description, logo_url
|
||||
- client_id (public identifier)
|
||||
- redirect_uris, grant_types, scopes
|
||||
- is_active status
|
||||
- created_at, updated_at timestamps
|
||||
|
||||
Note: client_secret is never returned for security reasons.
|
||||
"""
|
||||
return await list_user_oauth_applications(user_id)
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/status")
|
||||
async def update_app_status(
|
||||
app_id: str,
|
||||
user_id: str = Security(get_user_id),
|
||||
is_active: bool = Body(description="Whether the app should be active", embed=True),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Enable or disable an OAuth application.
|
||||
|
||||
Only the application owner can update the status.
|
||||
When disabled, the application cannot be used for new authorizations
|
||||
and existing access tokens will fail validation.
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
action = "enabled" if is_active else "disabled"
|
||||
logger.info(f"OAuth app {updated_app.name} (#{app_id}) {action} by user #{user_id}")
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
class UpdateAppLogoRequest(BaseModel):
|
||||
logo_url: str = Field(description="URL of the uploaded logo image")
|
||||
|
||||
|
||||
@router.patch("/apps/{app_id}/logo")
|
||||
async def update_app_logo(
|
||||
app_id: str,
|
||||
request: UpdateAppLogoRequest = Body(),
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Update the logo URL for an OAuth application.
|
||||
|
||||
Only the application owner can update the logo.
|
||||
The logo should be uploaded first using the media upload endpoint,
|
||||
then this endpoint is called with the resulting URL.
|
||||
|
||||
Logo requirements:
|
||||
- Must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
|
||||
Returns the updated application info.
|
||||
"""
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=request.logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo updated by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
# Logo upload constraints
|
||||
LOGO_MIN_SIZE = 512
|
||||
LOGO_MAX_SIZE = 2048
|
||||
LOGO_ALLOWED_TYPES = {"image/jpeg", "image/png", "image/webp"}
|
||||
LOGO_MAX_FILE_SIZE = 3 * 1024 * 1024 # 3MB
|
||||
|
||||
|
||||
@router.post("/apps/{app_id}/logo/upload")
|
||||
async def upload_app_logo(
|
||||
app_id: str,
|
||||
file: UploadFile,
|
||||
user_id: str = Security(get_user_id),
|
||||
) -> OAuthApplicationInfo:
|
||||
"""
|
||||
Upload a logo image for an OAuth application.
|
||||
|
||||
Requirements:
|
||||
- Image must be square (1:1 aspect ratio)
|
||||
- Minimum 512x512 pixels
|
||||
- Maximum 2048x2048 pixels
|
||||
- Allowed formats: JPEG, PNG, WebP
|
||||
- Maximum file size: 3MB
|
||||
|
||||
The image is uploaded to cloud storage and the app's logoUrl is updated.
|
||||
Returns the updated application info.
|
||||
"""
|
||||
# Verify ownership to reduce vulnerability to DoS(torage) or DoM(oney) attacks
|
||||
if (
|
||||
not (app := await get_oauth_application_by_id(app_id))
|
||||
or app.owner_id != user_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth App not found",
|
||||
)
|
||||
|
||||
# Check GCS configuration
|
||||
if not settings.config.media_gcs_bucket_name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Media storage is not configured",
|
||||
)
|
||||
|
||||
# Validate content type
|
||||
content_type = file.content_type
|
||||
if content_type not in LOGO_ALLOWED_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid file type. Allowed: JPEG, PNG, WebP. Got: {content_type}",
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
file_bytes = await file.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading logo file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to read uploaded file",
|
||||
)
|
||||
|
||||
# Check file size
|
||||
if len(file_bytes) > LOGO_MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
"File too large. "
|
||||
f"Maximum size is {LOGO_MAX_FILE_SIZE // 1024 // 1024}MB"
|
||||
),
|
||||
)
|
||||
|
||||
# Validate image dimensions
|
||||
try:
|
||||
image = Image.open(io.BytesIO(file_bytes))
|
||||
width, height = image.size
|
||||
|
||||
if width != height:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo must be square. Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width < LOGO_MIN_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too small. Minimum {LOGO_MIN_SIZE}x{LOGO_MIN_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
|
||||
if width > LOGO_MAX_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Logo too large. Maximum {LOGO_MAX_SIZE}x{LOGO_MAX_SIZE}. "
|
||||
f"Got {width}x{height}",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating logo image: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid image file",
|
||||
)
|
||||
|
||||
# Scan for viruses
|
||||
filename = file.filename or "logo"
|
||||
await scan_content_safe(file_bytes, filename=filename)
|
||||
|
||||
# Generate unique filename
|
||||
file_ext = os.path.splitext(filename)[1].lower() or ".png"
|
||||
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||
storage_path = f"oauth-apps/{app_id}/logo/{unique_filename}"
|
||||
|
||||
# Upload to GCS
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
|
||||
await async_client.upload(
|
||||
bucket_name, storage_path, file_bytes, content_type=content_type
|
||||
)
|
||||
|
||||
logo_url = f"https://storage.googleapis.com/{bucket_name}/{storage_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading logo to GCS: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to upload logo",
|
||||
)
|
||||
|
||||
# Delete the current app logo file (if any and it's in our cloud storage)
|
||||
await _delete_app_current_logo_file(app)
|
||||
|
||||
# Update the app with the new logo URL
|
||||
updated_app = await update_oauth_application(
|
||||
app_id=app_id,
|
||||
owner_id=user_id,
|
||||
logo_url=logo_url,
|
||||
)
|
||||
|
||||
if not updated_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Application not found or you don't have permission to update it",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth app {updated_app.name} (#{app_id}) logo uploaded by user #{user_id}"
|
||||
)
|
||||
|
||||
return updated_app
|
||||
|
||||
|
||||
async def _delete_app_current_logo_file(app: OAuthApplicationInfo):
|
||||
"""
|
||||
Delete the current logo file for the given app, if there is one in our cloud storage
|
||||
"""
|
||||
bucket_name = settings.config.media_gcs_bucket_name
|
||||
storage_base_url = f"https://storage.googleapis.com/{bucket_name}/"
|
||||
|
||||
if app.logo_url and app.logo_url.startswith(storage_base_url):
|
||||
# Parse blob path from URL: https://storage.googleapis.com/{bucket}/{path}
|
||||
old_path = app.logo_url.replace(storage_base_url, "")
|
||||
try:
|
||||
async with async_storage.Storage() as async_client:
|
||||
await async_client.delete(bucket_name, old_path)
|
||||
logger.info(f"Deleted old logo for OAuth app #{app.id}: {old_path}")
|
||||
except Exception as e:
|
||||
# Log but don't fail - the new logo was uploaded successfully
|
||||
logger.warning(
|
||||
f"Failed to delete old logo for OAuth app #{app.id}: {e}", exc_info=e
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,621 +0,0 @@
|
||||
"""
|
||||
Content Type Handlers for Unified Embeddings
|
||||
|
||||
Pluggable system for different content sources (store agents, blocks, docs).
|
||||
Each handler knows how to fetch and process its content type for embedding.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentItem:
|
||||
"""Represents a piece of content to be embedded."""
|
||||
|
||||
content_id: str # Unique identifier (DB ID or file path)
|
||||
content_type: ContentType
|
||||
searchable_text: str # Combined text for embedding
|
||||
metadata: dict[str, Any] # Content-specific metadata
|
||||
user_id: str | None = None # For user-scoped content
|
||||
|
||||
|
||||
class ContentHandler(ABC):
|
||||
"""Base handler for fetching and processing content for embeddings."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def content_type(self) -> ContentType:
|
||||
"""The ContentType this handler manages."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""
|
||||
Fetch items that don't have embeddings yet.
|
||||
|
||||
Args:
|
||||
batch_size: Maximum number of items to return
|
||||
|
||||
Returns:
|
||||
List of ContentItem objects ready for embedding
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns:
|
||||
Dict with keys: total, with_embeddings, without_embeddings
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StoreAgentHandler(ContentHandler):
|
||||
"""Handler for marketplace store agent listings."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.STORE_AGENT
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch approved store listings without embeddings."""
|
||||
from backend.api.features.store.embeddings import build_searchable_text
|
||||
|
||||
missing = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
|
||||
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND uce."contentId" IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return [
|
||||
ContentItem(
|
||||
content_id=row["id"],
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
searchable_text=build_searchable_text(
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
),
|
||||
metadata={
|
||||
"name": row["name"],
|
||||
"categories": row["categories"] or [],
|
||||
},
|
||||
user_id=None, # Store agents are public
|
||||
)
|
||||
for row in missing
|
||||
]
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about store agent embedding coverage."""
|
||||
# Count approved versions
|
||||
approved_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
|
||||
"""Handler for block definitions (Python classes)."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.BLOCK
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch blocks without embeddings."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# Get all available blocks
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Check which ones have embeddings
|
||||
if not all_blocks:
|
||||
return []
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
|
||||
# Query for existing embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_blocks = [
|
||||
(block_id, block_cls)
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if block_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
items = []
|
||||
for block_id, block_cls in missing_blocks[:batch_size]:
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
# Skip disabled blocks - they shouldn't be indexed
|
||||
if block_instance.disabled:
|
||||
continue
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if hasattr(block_instance, "name") and block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if (
|
||||
hasattr(block_instance, "description")
|
||||
and block_instance.description
|
||||
):
|
||||
parts.append(block_instance.description)
|
||||
if hasattr(block_instance, "categories") and block_instance.categories:
|
||||
# Convert BlockCategory enum to strings
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
)
|
||||
|
||||
# Add input/output schema info
|
||||
if hasattr(block_instance, "input_schema"):
|
||||
schema = block_instance.input_schema
|
||||
if hasattr(schema, "model_json_schema"):
|
||||
schema_dict = schema.model_json_schema()
|
||||
if "properties" in schema_dict:
|
||||
for prop_name, prop_info in schema_dict[
|
||||
"properties"
|
||||
].items():
|
||||
if "description" in prop_info:
|
||||
parts.append(
|
||||
f"{prop_name}: {prop_info['description']}"
|
||||
)
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
# Convert categories set of enums to list of strings for JSON serialization
|
||||
categories = getattr(block_instance, "categories", set())
|
||||
categories_list = (
|
||||
[cat.value for cat in categories] if categories else []
|
||||
)
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=block_id,
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": getattr(block_instance, "name", ""),
|
||||
"categories": categories_list,
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about block embedding coverage."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Filter out disabled blocks - they're not indexed
|
||||
enabled_block_ids = [
|
||||
block_id
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if not block_cls().disabled
|
||||
]
|
||||
total_blocks = len(enabled_block_ids)
|
||||
|
||||
if total_blocks == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
block_ids = enabled_block_ids
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_blocks,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_blocks - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarkdownSection:
|
||||
"""Represents a section of a markdown document."""
|
||||
|
||||
title: str # Section heading text
|
||||
content: str # Section content (including the heading line)
|
||||
level: int # Heading level (1 for #, 2 for ##, etc.)
|
||||
index: int # Section index within the document
|
||||
|
||||
|
||||
class DocumentationHandler(ContentHandler):
|
||||
"""Handler for documentation files (.md/.mdx).
|
||||
|
||||
Chunks documents by markdown headings to create multiple embeddings per file.
|
||||
Each section (## heading) becomes a separate embedding for better retrieval.
|
||||
"""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.DOCUMENTATION
|
||||
|
||||
def _get_docs_root(self) -> Path:
|
||||
"""Get the documentation root directory."""
|
||||
# content_handlers.py is at: backend/backend/api/features/store/content_handlers.py
|
||||
# Need to go up to project root then into docs/
|
||||
# In container: /app/autogpt_platform/backend/backend/api/features/store -> /app/docs
|
||||
# In development: /repo/autogpt_platform/backend/backend/api/features/store -> /repo/docs
|
||||
this_file = Path(
|
||||
__file__
|
||||
) # .../backend/backend/api/features/store/content_handlers.py
|
||||
project_root = (
|
||||
this_file.parent.parent.parent.parent.parent.parent.parent
|
||||
) # -> /app or /repo
|
||||
docs_root = project_root / "docs"
|
||||
return docs_root
|
||||
|
||||
def _extract_doc_title(self, file_path: Path) -> str:
|
||||
"""Extract the document title from a markdown file."""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
lines = content.split("\n")
|
||||
|
||||
# Try to extract title from first # heading
|
||||
for line in lines:
|
||||
if line.startswith("# "):
|
||||
return line[2:].strip()
|
||||
|
||||
# If no title found, use filename
|
||||
return file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read title from {file_path}: {e}")
|
||||
return file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
|
||||
def _chunk_markdown_by_headings(
|
||||
self, file_path: Path, min_heading_level: int = 2
|
||||
) -> list[MarkdownSection]:
|
||||
"""
|
||||
Split a markdown file into sections based on headings.
|
||||
|
||||
Args:
|
||||
file_path: Path to the markdown file
|
||||
min_heading_level: Minimum heading level to split on (default: 2 for ##)
|
||||
|
||||
Returns:
|
||||
List of MarkdownSection objects, one per section.
|
||||
If no headings found, returns a single section with all content.
|
||||
"""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {file_path}: {e}")
|
||||
return []
|
||||
|
||||
lines = content.split("\n")
|
||||
sections: list[MarkdownSection] = []
|
||||
current_section_lines: list[str] = []
|
||||
current_title = ""
|
||||
current_level = 0
|
||||
section_index = 0
|
||||
doc_title = ""
|
||||
|
||||
for line in lines:
|
||||
# Check if line is a heading
|
||||
if line.startswith("#"):
|
||||
# Count heading level
|
||||
level = 0
|
||||
for char in line:
|
||||
if char == "#":
|
||||
level += 1
|
||||
else:
|
||||
break
|
||||
|
||||
heading_text = line[level:].strip()
|
||||
|
||||
# Track document title (level 1 heading)
|
||||
if level == 1 and not doc_title:
|
||||
doc_title = heading_text
|
||||
# Don't create a section for just the title - add it to first section
|
||||
current_section_lines.append(line)
|
||||
continue
|
||||
|
||||
# Check if this heading should start a new section
|
||||
if level >= min_heading_level:
|
||||
# Save previous section if it has content
|
||||
if current_section_lines:
|
||||
section_content = "\n".join(current_section_lines).strip()
|
||||
if section_content:
|
||||
# Use doc title for first section if no specific title
|
||||
title = current_title if current_title else doc_title
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace(
|
||||
"_", " "
|
||||
)
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=section_content,
|
||||
level=current_level if current_level else 1,
|
||||
index=section_index,
|
||||
)
|
||||
)
|
||||
section_index += 1
|
||||
|
||||
# Start new section
|
||||
current_section_lines = [line]
|
||||
current_title = heading_text
|
||||
current_level = level
|
||||
else:
|
||||
# Lower level heading (e.g., # when splitting on ##)
|
||||
current_section_lines.append(line)
|
||||
else:
|
||||
current_section_lines.append(line)
|
||||
|
||||
# Don't forget the last section
|
||||
if current_section_lines:
|
||||
section_content = "\n".join(current_section_lines).strip()
|
||||
if section_content:
|
||||
title = current_title if current_title else doc_title
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace("_", " ")
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=section_content,
|
||||
level=current_level if current_level else 1,
|
||||
index=section_index,
|
||||
)
|
||||
)
|
||||
|
||||
# If no sections were created (no headings found), create one section with all content
|
||||
if not sections and content.strip():
|
||||
title = (
|
||||
doc_title
|
||||
if doc_title
|
||||
else file_path.stem.replace("-", " ").replace("_", " ")
|
||||
)
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=content.strip(),
|
||||
level=1,
|
||||
index=0,
|
||||
)
|
||||
)
|
||||
|
||||
return sections
|
||||
|
||||
def _make_section_content_id(self, doc_path: str, section_index: int) -> str:
|
||||
"""Create a unique content ID for a document section.
|
||||
|
||||
Format: doc_path::section_index
|
||||
Example: 'platform/getting-started.md::0'
|
||||
"""
|
||||
return f"{doc_path}::{section_index}"
|
||||
|
||||
def _parse_section_content_id(self, content_id: str) -> tuple[str, int]:
|
||||
"""Parse a section content ID back into doc_path and section_index.
|
||||
|
||||
Returns: (doc_path, section_index)
|
||||
"""
|
||||
if "::" in content_id:
|
||||
parts = content_id.rsplit("::", 1)
|
||||
return parts[0], int(parts[1])
|
||||
# Legacy format (whole document)
|
||||
return content_id, 0
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch documentation sections without embeddings.
|
||||
|
||||
Chunks each document by markdown headings and creates embeddings for each section.
|
||||
Content IDs use the format: 'path/to/doc.md::section_index'
|
||||
"""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
logger.warning(f"Documentation root not found: {docs_root}")
|
||||
return []
|
||||
|
||||
# Find all .md and .mdx files
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
|
||||
if not all_docs:
|
||||
return []
|
||||
|
||||
# Build list of all sections from all documents
|
||||
all_sections: list[tuple[str, Path, MarkdownSection]] = []
|
||||
for doc_file in all_docs:
|
||||
doc_path = str(doc_file.relative_to(docs_root))
|
||||
sections = self._chunk_markdown_by_headings(doc_file)
|
||||
for section in sections:
|
||||
all_sections.append((doc_path, doc_file, section))
|
||||
|
||||
if not all_sections:
|
||||
return []
|
||||
|
||||
# Generate content IDs for all sections
|
||||
section_content_ids = [
|
||||
self._make_section_content_id(doc_path, section.index)
|
||||
for doc_path, _, section in all_sections
|
||||
]
|
||||
|
||||
# Check which ones have embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(section_content_ids))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*section_content_ids,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
|
||||
# Filter to missing sections
|
||||
missing_sections = [
|
||||
(doc_path, doc_file, section, content_id)
|
||||
for (doc_path, doc_file, section), content_id in zip(
|
||||
all_sections, section_content_ids
|
||||
)
|
||||
if content_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem (up to batch_size)
|
||||
items = []
|
||||
for doc_path, doc_file, section, content_id in missing_sections[:batch_size]:
|
||||
try:
|
||||
# Get document title for context
|
||||
doc_title = self._extract_doc_title(doc_file)
|
||||
|
||||
# Build searchable text with context
|
||||
# Include doc title and section title for better search relevance
|
||||
searchable_text = f"{doc_title} - {section.title}\n\n{section.content}"
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=content_id,
|
||||
content_type=ContentType.DOCUMENTATION,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"doc_title": doc_title,
|
||||
"section_title": section.title,
|
||||
"section_index": section.index,
|
||||
"heading_level": section.level,
|
||||
"path": doc_path,
|
||||
},
|
||||
user_id=None, # Documentation is public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process section {content_id}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
def _get_all_section_content_ids(self, docs_root: Path) -> set[str]:
|
||||
"""Get all current section content IDs from the docs directory.
|
||||
|
||||
Used for stats and cleanup to know what sections should exist.
|
||||
"""
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
content_ids = set()
|
||||
|
||||
for doc_file in all_docs:
|
||||
doc_path = str(doc_file.relative_to(docs_root))
|
||||
sections = self._chunk_markdown_by_headings(doc_file)
|
||||
for section in sections:
|
||||
content_ids.add(self._make_section_content_id(doc_path, section.index))
|
||||
|
||||
return content_ids
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about documentation embedding coverage.
|
||||
|
||||
Counts sections (not documents) since each section gets its own embedding.
|
||||
"""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
# Get all section content IDs
|
||||
all_section_ids = self._get_all_section_content_ids(docs_root)
|
||||
total_sections = len(all_section_ids)
|
||||
|
||||
if total_sections == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
# Count embeddings in database for DOCUMENTATION type
|
||||
embedded_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{schema_prefix}"ContentType"
|
||||
"""
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_sections,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_sections - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
# Content handler registry
|
||||
CONTENT_HANDLERS: dict[ContentType, ContentHandler] = {
|
||||
ContentType.STORE_AGENT: StoreAgentHandler(),
|
||||
ContentType.BLOCK: BlockHandler(),
|
||||
ContentType.DOCUMENTATION: DocumentationHandler(),
|
||||
}
|
||||
@@ -1,215 +0,0 @@
|
||||
"""
|
||||
Integration tests for content handlers using real DB.
|
||||
|
||||
Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs
|
||||
|
||||
These tests use the real database but mock OpenAI calls.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
from backend.api.features.store.embeddings import (
|
||||
EMBEDDING_DIM,
|
||||
backfill_all_content_types,
|
||||
ensure_content_embedding,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_real_db():
|
||||
"""Test StoreAgentHandler with real database queries."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list (may be empty if all have embeddings)
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None
|
||||
assert item.content_type.value == "STORE_AGENT"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_real_db():
|
||||
"""Test BlockHandler with real database queries."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0 # Should have at least some blocks
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be block UUID
|
||||
assert item.content_type.value == "BLOCK"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_real_fs():
|
||||
"""Test DocumentationHandler with real filesystem."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Get stats from real filesystem
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be relative path
|
||||
assert item.content_type.value == "DOCUMENTATION"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_stats_all_types():
|
||||
"""Test get_embedding_stats aggregates all content types."""
|
||||
stats = await get_embedding_stats()
|
||||
|
||||
# Should have structure with by_type and totals
|
||||
assert "by_type" in stats
|
||||
assert "totals" in stats
|
||||
|
||||
# Check each content type is present
|
||||
by_type = stats["by_type"]
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "BLOCK" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Check totals are aggregated
|
||||
totals = stats["totals"]
|
||||
assert totals["total"] >= 0
|
||||
assert totals["with_embeddings"] >= 0
|
||||
assert totals["without_embeddings"] >= 0
|
||||
assert "coverage_percent" in totals
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_ensure_content_embedding_blocks(mock_generate):
|
||||
"""Test creating embeddings for blocks (mocked OpenAI)."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * EMBEDDING_DIM
|
||||
|
||||
# Get one block without embedding
|
||||
handler = BlockHandler()
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
if not items:
|
||||
pytest.skip("No blocks without embeddings")
|
||||
|
||||
item = items[0]
|
||||
|
||||
# Try to create embedding (OpenAI mocked)
|
||||
result = await ensure_content_embedding(
|
||||
content_type=item.content_type,
|
||||
content_id=item.content_id,
|
||||
searchable_text=item.searchable_text,
|
||||
metadata=item.metadata,
|
||||
user_id=item.user_id,
|
||||
)
|
||||
|
||||
# Should succeed with mocked OpenAI
|
||||
assert result is True
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_backfill_all_content_types_dry_run(mock_generate):
|
||||
"""Test backfill_all_content_types processes all handlers in order."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * EMBEDDING_DIM
|
||||
|
||||
# Run backfill with batch_size=1 to process max 1 per type
|
||||
result = await backfill_all_content_types(batch_size=1)
|
||||
|
||||
# Should have results for all content types
|
||||
assert "by_type" in result
|
||||
assert "totals" in result
|
||||
|
||||
by_type = result["by_type"]
|
||||
assert "BLOCK" in by_type
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Each type should have correct structure
|
||||
for content_type, type_result in by_type.items():
|
||||
assert "processed" in type_result
|
||||
assert "success" in type_result
|
||||
assert "failed" in type_result
|
||||
|
||||
# Totals should aggregate
|
||||
totals = result["totals"]
|
||||
assert totals["processed"] >= 0
|
||||
assert totals["success"] >= 0
|
||||
assert totals["failed"] >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handler_registry():
|
||||
"""Test all handlers are registered in correct order."""
|
||||
from prisma.enums import ContentType
|
||||
|
||||
# All three types should be registered
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
# Check handler types
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
@@ -1,391 +0,0 @@
|
||||
"""
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_missing_items(mocker):
|
||||
"""Test StoreAgentHandler fetches approved agents without embeddings."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock database query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "agent-1",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"subHeading": "Test heading",
|
||||
"categories": ["AI", "Testing"],
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_missing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "agent-1"
|
||||
assert items[0].content_type == ContentType.STORE_AGENT
|
||||
assert "Test Agent" in items[0].searchable_text
|
||||
assert "A test agent" in items[0].searchable_text
|
||||
assert items[0].metadata["name"] == "Test Agent"
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_stats(mocker):
|
||||
"""Test StoreAgentHandler returns correct stats."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved = [{"count": 50}]
|
||||
# Mock embedded count query
|
||||
mock_embedded = [{"count": 30}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
side_effect=[mock_approved, mock_embedded],
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 50
|
||||
assert stats["with_embeddings"] == 30
|
||||
assert stats["without_embeddings"] == 20
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items(mocker):
|
||||
"""Test BlockHandler discovers blocks without embeddings."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks to return test blocks
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Calculator Block"
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.disabled = False
|
||||
mock_block_instance.input_schema.model_json_schema.return_value = {
|
||||
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
||||
}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
mock_existing = []
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_existing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "block-uuid-1"
|
||||
assert items[0].content_type == ContentType.BLOCK
|
||||
assert "Calculator Block" in items[0].searchable_text
|
||||
assert "Performs calculations" in items[0].searchable_text
|
||||
assert "MATH" in items[0].searchable_text
|
||||
assert "expression: Math expression" in items[0].searchable_text
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats(mocker):
|
||||
"""Test BlockHandler returns correct stats."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks - each block class returns an instance with disabled=False
|
||||
def make_mock_block_class():
|
||||
mock_class = MagicMock()
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.disabled = False
|
||||
mock_class.return_value = mock_instance
|
||||
return mock_class
|
||||
|
||||
mock_blocks = {
|
||||
"block-1": make_mock_block_class(),
|
||||
"block-2": make_mock_block_class(),
|
||||
"block-3": make_mock_block_class(),
|
||||
}
|
||||
|
||||
# Mock embedded count query (2 blocks have embeddings)
|
||||
mock_embedded = [{"count": 2}]
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 2
|
||||
assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
"""Test DocumentationHandler discovers docs without embeddings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory with test files
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
|
||||
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
|
||||
|
||||
# Mock _get_docs_root to return temp dir
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md (content_id format: doc_path::section_index)
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md::0"), None
|
||||
)
|
||||
assert guide_item is not None
|
||||
assert guide_item.content_type == ContentType.DOCUMENTATION
|
||||
assert "Getting Started" in guide_item.searchable_text
|
||||
assert "This is a guide" in guide_item.searchable_text
|
||||
assert guide_item.metadata["doc_title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx (content_id format: doc_path::section_index)
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx::0"), None
|
||||
)
|
||||
assert api_item is not None
|
||||
assert "API Reference" in api_item.searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_stats(tmp_path, mocker):
|
||||
"""Test DocumentationHandler returns correct stats."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
(docs_root / "doc1.md").write_text("# Doc 1")
|
||||
(docs_root / "doc2.md").write_text("# Doc 2")
|
||||
(docs_root / "doc3.mdx").write_text("# Doc 3")
|
||||
|
||||
# Mock embedded count query (1 doc has embedding)
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 1
|
||||
assert stats["without_embeddings"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_title_extraction(tmp_path):
|
||||
"""Test DocumentationHandler extracts title from markdown heading."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title = handler._extract_doc_title(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title = handler._extract_doc_title(doc_without_heading)
|
||||
assert title == "No Heading" # Uses filename
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
"""Test DocumentationHandler chunks markdown by headings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test document with multiple sections
|
||||
doc_with_sections = tmp_path / "sections.md"
|
||||
doc_with_sections.write_text(
|
||||
"# Document Title\n\n"
|
||||
"Intro paragraph.\n\n"
|
||||
"## Section One\n\n"
|
||||
"Content for section one.\n\n"
|
||||
"## Section Two\n\n"
|
||||
"Content for section two.\n"
|
||||
)
|
||||
sections = handler._chunk_markdown_by_headings(doc_with_sections)
|
||||
|
||||
# Should have 3 sections: intro (with doc title), section one, section two
|
||||
assert len(sections) == 3
|
||||
assert sections[0].title == "Document Title"
|
||||
assert sections[0].index == 0
|
||||
assert "Intro paragraph" in sections[0].content
|
||||
|
||||
assert sections[1].title == "Section One"
|
||||
assert sections[1].index == 1
|
||||
assert "Content for section one" in sections[1].content
|
||||
|
||||
assert sections[2].title == "Section Two"
|
||||
assert sections[2].index == 2
|
||||
assert "Content for section two" in sections[2].content
|
||||
|
||||
# Test document without headings
|
||||
doc_no_sections = tmp_path / "no-sections.md"
|
||||
doc_no_sections.write_text("Just plain content without any headings.")
|
||||
sections = handler._chunk_markdown_by_headings(doc_no_sections)
|
||||
assert len(sections) == 1
|
||||
assert sections[0].index == 0
|
||||
assert "Just plain content" in sections[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_section_content_ids():
|
||||
"""Test DocumentationHandler creates and parses section content IDs."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test making content ID
|
||||
content_id = handler._make_section_content_id("docs/guide.md", 2)
|
||||
assert content_id == "docs/guide.md::2"
|
||||
|
||||
# Test parsing content ID
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
|
||||
assert doc_path == "docs/guide.md"
|
||||
assert section_index == 2
|
||||
|
||||
# Test parsing legacy format (no section index)
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
|
||||
assert doc_path == "docs/old-format.md"
|
||||
assert section_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handlers_registry():
|
||||
"""Test all content types are registered."""
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_missing_attributes():
|
||||
"""Test BlockHandler gracefully handles blocks with missing attributes."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with minimal attributes
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
mock_block_instance.disabled = False
|
||||
# No description, categories, or schema
|
||||
del mock_block_instance.description
|
||||
del mock_block_instance.categories
|
||||
del mock_block_instance.input_schema
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock one good block and one bad block
|
||||
good_block = MagicMock()
|
||||
good_instance = MagicMock()
|
||||
good_instance.name = "Good Block"
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_instance.disabled = False
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
bad_block.side_effect = Exception("Instantiation failed")
|
||||
|
||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
# Should only get the good block
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Mock _get_docs_root to return non-existent path
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
@@ -1,965 +0,0 @@
|
||||
"""
|
||||
Unified Content Embeddings Service
|
||||
|
||||
Handles generation and storage of OpenAI embeddings for all content types
|
||||
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
|
||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.json import dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
# Embedding dimension for the model above
|
||||
# text-embedding-3-small: 1536, text-embedding-3-large: 3072
|
||||
EMBEDDING_DIM = 1536
|
||||
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
|
||||
EMBEDDING_MAX_TOKENS = 8191
|
||||
|
||||
|
||||
def build_searchable_text(
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Build searchable text from listing version fields.
|
||||
|
||||
Combines relevant fields into a single string for embedding.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Name is important - include it
|
||||
if name:
|
||||
parts.append(name)
|
||||
|
||||
# Sub-heading provides context
|
||||
if sub_heading:
|
||||
parts.append(sub_heading)
|
||||
|
||||
# Description is the main content
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
# Categories help with semantic matching
|
||||
if categories:
|
||||
parts.append(" ".join(categories))
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding for text using OpenAI API.
|
||||
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
client = get_openai_client()
|
||||
if not client:
|
||||
raise RuntimeError("openai_internal_api_key not set, cannot generate embedding")
|
||||
|
||||
# Truncate text to token limit using tiktoken
|
||||
# Character-based truncation is insufficient because token ratios vary by content type
|
||||
enc = encoding_for_model(EMBEDDING_MODEL)
|
||||
tokens = enc.encode(text)
|
||||
if len(tokens) > EMBEDDING_MAX_TOKENS:
|
||||
tokens = tokens[:EMBEDDING_MAX_TOKENS]
|
||||
truncated_text = enc.decode(tokens)
|
||||
logger.info(
|
||||
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
|
||||
)
|
||||
else:
|
||||
truncated_text = text
|
||||
|
||||
start_time = time.time()
|
||||
response = await client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=truncated_text,
|
||||
)
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
embedding = response.data[0].embedding
|
||||
logger.info(
|
||||
f"Generated embedding: {len(embedding)} dims, "
|
||||
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
|
||||
)
|
||||
return embedding
|
||||
|
||||
|
||||
async def store_embedding(
|
||||
version_id: str,
|
||||
embedding: list[float],
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the database.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
DEPRECATED: Use ensure_embedding() instead (includes searchable_text).
|
||||
"""
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text="", # Empty for backward compat; ensure_embedding() populates this
|
||||
metadata=None,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
|
||||
async def store_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
embedding: list[float],
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Store embedding in the unified content embeddings table.
|
||||
|
||||
New function for unified content embedding storage.
|
||||
Uses raw SQL since Prisma doesn't natively support pgvector.
|
||||
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
client = tx if tx else prisma.get_client()
|
||||
|
||||
# Convert embedding to PostgreSQL vector format
|
||||
embedding_str = embedding_to_vector_string(embedding)
|
||||
metadata_json = dumps(metadata or {})
|
||||
|
||||
# Upsert the embedding
|
||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||
# Use unqualified ::vector - pgvector is in search_path on all environments
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||
ON CONFLICT ("contentType", "contentId", "userId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $4::vector,
|
||||
"searchableText" = $5,
|
||||
"metadata" = $6::jsonb,
|
||||
"updatedAt" = NOW()
|
||||
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
|
||||
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
embedding_str,
|
||||
searchable_text,
|
||||
metadata_json,
|
||||
client=client,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||
return True
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
|
||||
"""
|
||||
result = await get_content_embedding(
|
||||
ContentType.STORE_AGENT, version_id, user_id=None
|
||||
)
|
||||
if result:
|
||||
# Transform to old format for backward compatibility
|
||||
return {
|
||||
"storeListingVersionId": result["contentId"],
|
||||
"embedding": result["embedding"],
|
||||
"createdAt": result["createdAt"],
|
||||
"updatedAt": result["updatedAt"],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def get_content_embedding(
|
||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve embedding record for any content type.
|
||||
|
||||
New function for unified content embedding retrieval.
|
||||
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
|
||||
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
"contentType",
|
||||
"contentId",
|
||||
"userId",
|
||||
"embedding"::text as "embedding",
|
||||
"searchableText",
|
||||
"metadata",
|
||||
"createdAt",
|
||||
"updatedAt"
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
|
||||
async def ensure_embedding(
|
||||
version_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
sub_heading: str,
|
||||
categories: list[str],
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for the listing version.
|
||||
|
||||
Creates embedding if missing. Use force=True to regenerate.
|
||||
Backward-compatible wrapper for store listings.
|
||||
|
||||
Args:
|
||||
version_id: The StoreListingVersion ID
|
||||
name: Agent name
|
||||
description: Agent description
|
||||
sub_heading: Agent sub-heading
|
||||
categories: Agent categories
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created
|
||||
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_embedding(version_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(f"Embedding for version {version_id} already exists")
|
||||
return True
|
||||
|
||||
# Build searchable text for embedding
|
||||
searchable_text = build_searchable_text(name, description, sub_heading, categories)
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
|
||||
# Store the embedding with metadata using new function
|
||||
metadata = {
|
||||
"name": name,
|
||||
"subHeading": sub_heading,
|
||||
"categories": categories,
|
||||
}
|
||||
return await store_content_embedding(
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
content_id=version_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata,
|
||||
user_id=None, # Store agents are public
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
|
||||
"""
|
||||
Delete embedding for a listing version.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
"""
|
||||
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
|
||||
|
||||
|
||||
async def delete_content_embedding(
|
||||
content_type: ContentType, content_id: str, user_id: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Delete embedding for any content type.
|
||||
|
||||
New function for unified content embedding deletion.
|
||||
Note: This is usually handled automatically by CASCADE delete,
|
||||
but provided for manual cleanup if needed.
|
||||
|
||||
Args:
|
||||
content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
|
||||
content_id: The unique identifier for the content
|
||||
user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
|
||||
For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
|
||||
deleting embeddings belonging to other users.
|
||||
|
||||
Returns:
|
||||
True if deletion succeeded, False otherwise
|
||||
"""
|
||||
try:
|
||||
client = prisma.get_client()
|
||||
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND "contentId" = $2
|
||||
AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
|
||||
""",
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
client=client,
|
||||
)
|
||||
|
||||
user_str = f" (user: {user_id})" if user_id else ""
|
||||
logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding_stats() -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage for all content types.
|
||||
|
||||
Returns stats per content type and overall totals.
|
||||
"""
|
||||
try:
|
||||
stats_by_type = {}
|
||||
total_items = 0
|
||||
total_with_embeddings = 0
|
||||
total_without_embeddings = 0
|
||||
|
||||
# Aggregate stats from all handlers
|
||||
for content_type, handler in CONTENT_HANDLERS.items():
|
||||
try:
|
||||
stats = await handler.get_stats()
|
||||
stats_by_type[content_type.value] = {
|
||||
"total": stats["total"],
|
||||
"with_embeddings": stats["with_embeddings"],
|
||||
"without_embeddings": stats["without_embeddings"],
|
||||
"coverage_percent": (
|
||||
round(stats["with_embeddings"] / stats["total"] * 100, 1)
|
||||
if stats["total"] > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
total_items += stats["total"]
|
||||
total_with_embeddings += stats["with_embeddings"]
|
||||
total_without_embeddings += stats["without_embeddings"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get stats for {content_type.value}: {e}")
|
||||
stats_by_type[content_type.value] = {
|
||||
"total": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
return {
|
||||
"by_type": stats_by_type,
|
||||
"totals": {
|
||||
"total": total_items,
|
||||
"with_embeddings": total_with_embeddings,
|
||||
"without_embeddings": total_without_embeddings,
|
||||
"coverage_percent": (
|
||||
round(total_with_embeddings / total_items * 100, 1)
|
||||
if total_items > 0
|
||||
else 0
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get embedding stats: {e}")
|
||||
return {
|
||||
"by_type": {},
|
||||
"totals": {
|
||||
"total": 0,
|
||||
"with_embeddings": 0,
|
||||
"without_embeddings": 0,
|
||||
"coverage_percent": 0,
|
||||
},
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for approved listings that don't have them.
|
||||
|
||||
BACKWARD COMPATIBILITY: Maintained for existing usage.
|
||||
This now delegates to backfill_all_content_types() to process all content types.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate per content type
|
||||
|
||||
Returns:
|
||||
Dict with success/failure counts aggregated across all content types
|
||||
"""
|
||||
# Delegate to the new generic backfill system
|
||||
result = await backfill_all_content_types(batch_size)
|
||||
|
||||
# Return in the old format for backward compatibility
|
||||
return result["totals"]
|
||||
|
||||
|
||||
async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
|
||||
"""
|
||||
Generate embeddings for all content types using registered handlers.
|
||||
|
||||
Processes content types in order: BLOCK → STORE_AGENT → DOCUMENTATION.
|
||||
This ensures foundational content (blocks) are searchable first.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate per content type
|
||||
|
||||
Returns:
|
||||
Dict with stats per content type and overall totals
|
||||
"""
|
||||
results_by_type = {}
|
||||
total_processed = 0
|
||||
total_success = 0
|
||||
total_failed = 0
|
||||
all_errors: dict[str, int] = {} # Aggregate errors across all content types
|
||||
|
||||
# Process content types in explicit order
|
||||
processing_order = [
|
||||
ContentType.BLOCK,
|
||||
ContentType.STORE_AGENT,
|
||||
ContentType.DOCUMENTATION,
|
||||
]
|
||||
|
||||
for content_type in processing_order:
|
||||
handler = CONTENT_HANDLERS.get(content_type)
|
||||
if not handler:
|
||||
logger.warning(f"No handler registered for {content_type.value}")
|
||||
continue
|
||||
try:
|
||||
logger.info(f"Processing {content_type.value} content type...")
|
||||
|
||||
# Get missing items from handler
|
||||
missing_items = await handler.get_missing_items(batch_size)
|
||||
|
||||
if not missing_items:
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"message": "No missing embeddings",
|
||||
}
|
||||
continue
|
||||
|
||||
# Process embeddings concurrently for better performance
|
||||
embedding_tasks = [
|
||||
ensure_content_embedding(
|
||||
content_type=item.content_type,
|
||||
content_id=item.content_id,
|
||||
searchable_text=item.searchable_text,
|
||||
metadata=item.metadata,
|
||||
user_id=item.user_id,
|
||||
)
|
||||
for item in missing_items
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*embedding_tasks, return_exceptions=True)
|
||||
|
||||
success = sum(1 for result in results if result is True)
|
||||
failed = len(results) - success
|
||||
|
||||
# Aggregate errors across all content types
|
||||
if failed > 0:
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
error_key = f"{type(result).__name__}: {str(result)}"
|
||||
all_errors[error_key] = all_errors.get(error_key, 0) + 1
|
||||
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": len(missing_items),
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"message": f"Backfilled {success} embeddings, {failed} failed",
|
||||
}
|
||||
|
||||
total_processed += len(missing_items)
|
||||
total_success += success
|
||||
total_failed += failed
|
||||
|
||||
logger.info(
|
||||
f"{content_type.value}: processed {len(missing_items)}, "
|
||||
f"success {success}, failed {failed}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process {content_type.value}: {e}")
|
||||
results_by_type[content_type.value] = {
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
# Log aggregated errors once at the end
|
||||
if all_errors:
|
||||
error_details = ", ".join(
|
||||
f"{error} ({count}x)" for error, count in all_errors.items()
|
||||
)
|
||||
logger.error(f"Embedding backfill errors: {error_details}")
|
||||
|
||||
return {
|
||||
"by_type": results_by_type,
|
||||
"totals": {
|
||||
"processed": total_processed,
|
||||
"success": total_success,
|
||||
"failed": total_failed,
|
||||
"message": f"Overall: {total_success} succeeded, {total_failed} failed",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding for a search query.
|
||||
|
||||
Same as generate_embedding but with clearer intent.
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
return await generate_embedding(query)
|
||||
|
||||
|
||||
def embedding_to_vector_string(embedding: list[float]) -> str:
|
||||
"""Convert embedding list to PostgreSQL vector string format."""
|
||||
return "[" + ",".join(str(x) for x in embedding) + "]"
|
||||
|
||||
|
||||
async def ensure_content_embedding(
|
||||
content_type: ContentType,
|
||||
content_id: str,
|
||||
searchable_text: str,
|
||||
metadata: dict | None = None,
|
||||
user_id: str | None = None,
|
||||
force: bool = False,
|
||||
tx: prisma.Prisma | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure an embedding exists for any content type.
|
||||
|
||||
Generic function for creating embeddings for store agents, blocks, docs, etc.
|
||||
|
||||
Args:
|
||||
content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
|
||||
content_id: Unique identifier for the content
|
||||
searchable_text: Combined text for embedding generation
|
||||
metadata: Optional metadata to store with embedding
|
||||
force: Force regeneration even if embedding exists
|
||||
tx: Optional transaction client
|
||||
|
||||
Returns:
|
||||
True if embedding exists/was created
|
||||
|
||||
Raises exceptions on failure - caller should handle.
|
||||
"""
|
||||
# Check if embedding already exists
|
||||
if not force:
|
||||
existing = await get_content_embedding(content_type, content_id, user_id)
|
||||
if existing and existing.get("embedding"):
|
||||
logger.debug(f"Embedding for {content_type}:{content_id} already exists")
|
||||
return True
|
||||
|
||||
# Generate new embedding
|
||||
embedding = await generate_embedding(searchable_text)
|
||||
|
||||
# Store the embedding
|
||||
return await store_content_embedding(
|
||||
content_type=content_type,
|
||||
content_id=content_id,
|
||||
embedding=embedding,
|
||||
searchable_text=searchable_text,
|
||||
metadata=metadata or {},
|
||||
user_id=user_id,
|
||||
tx=tx,
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
||||
"""
|
||||
Clean up embeddings for content that no longer exists or is no longer valid.
|
||||
|
||||
Compares current content with embeddings in database and removes orphaned records:
|
||||
- STORE_AGENT: Removes embeddings for rejected/deleted store listings
|
||||
- BLOCK: Removes embeddings for blocks no longer registered
|
||||
- DOCUMENTATION: Removes embeddings for deleted doc files
|
||||
|
||||
Returns:
|
||||
Dict with cleanup statistics per content type
|
||||
"""
|
||||
results_by_type = {}
|
||||
total_deleted = 0
|
||||
|
||||
# Cleanup orphaned embeddings for all content types
|
||||
cleanup_types = [
|
||||
ContentType.STORE_AGENT,
|
||||
ContentType.BLOCK,
|
||||
ContentType.DOCUMENTATION,
|
||||
]
|
||||
|
||||
for content_type in cleanup_types:
|
||||
try:
|
||||
handler = CONTENT_HANDLERS.get(content_type)
|
||||
if not handler:
|
||||
logger.warning(f"No handler registered for {content_type}")
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": 0,
|
||||
"error": "No handler registered",
|
||||
}
|
||||
continue
|
||||
|
||||
# Get all current content IDs from handler
|
||||
if content_type == ContentType.STORE_AGENT:
|
||||
# Get IDs of approved store listing versions from non-deleted listings
|
||||
valid_agents = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT slv.id
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
JOIN {schema_prefix}"StoreListing" sl ON slv."storeListingId" = sl.id
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND sl."isDeleted" = false
|
||||
""",
|
||||
)
|
||||
current_ids = {row["id"] for row in valid_agents}
|
||||
elif content_type == ContentType.BLOCK:
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
current_ids = set(get_blocks().keys())
|
||||
elif content_type == ContentType.DOCUMENTATION:
|
||||
# Use DocumentationHandler to get section-based content IDs
|
||||
from backend.api.features.store.content_handlers import (
|
||||
DocumentationHandler,
|
||||
)
|
||||
|
||||
doc_handler = CONTENT_HANDLERS.get(ContentType.DOCUMENTATION)
|
||||
if isinstance(doc_handler, DocumentationHandler):
|
||||
docs_root = doc_handler._get_docs_root()
|
||||
if docs_root.exists():
|
||||
current_ids = doc_handler._get_all_section_content_ids(
|
||||
docs_root
|
||||
)
|
||||
else:
|
||||
current_ids = set()
|
||||
else:
|
||||
current_ids = set()
|
||||
else:
|
||||
# Skip unknown content types to avoid accidental deletion
|
||||
logger.warning(
|
||||
f"Skipping cleanup for unknown content type: {content_type}"
|
||||
)
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": 0,
|
||||
"error": "Unknown content type - skipped for safety",
|
||||
}
|
||||
continue
|
||||
|
||||
# Get all embedding IDs from database
|
||||
db_embeddings = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT "contentId"
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType"
|
||||
""",
|
||||
content_type,
|
||||
)
|
||||
|
||||
db_ids = {row["contentId"] for row in db_embeddings}
|
||||
|
||||
# Find orphaned embeddings (in DB but not in current content)
|
||||
orphaned_ids = db_ids - current_ids
|
||||
|
||||
if not orphaned_ids:
|
||||
logger.info(f"{content_type.value}: No orphaned embeddings found")
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": 0,
|
||||
"message": "No orphaned embeddings",
|
||||
}
|
||||
continue
|
||||
|
||||
# Delete orphaned embeddings in batch for better performance
|
||||
orphaned_list = list(orphaned_ids)
|
||||
try:
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = $1::{schema_prefix}"ContentType"
|
||||
AND "contentId" = ANY($2::text[])
|
||||
""",
|
||||
content_type,
|
||||
orphaned_list,
|
||||
)
|
||||
deleted = len(orphaned_list)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to batch delete orphaned embeddings: {e}")
|
||||
deleted = 0
|
||||
|
||||
logger.info(
|
||||
f"{content_type.value}: Deleted {deleted}/{len(orphaned_ids)} orphaned embeddings"
|
||||
)
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": deleted,
|
||||
"orphaned": len(orphaned_ids),
|
||||
"message": f"Deleted {deleted} orphaned embeddings",
|
||||
}
|
||||
|
||||
total_deleted += deleted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup {content_type.value}: {e}")
|
||||
results_by_type[content_type.value] = {
|
||||
"deleted": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
return {
|
||||
"by_type": results_by_type,
|
||||
"totals": {
|
||||
"deleted": total_deleted,
|
||||
"message": f"Deleted {total_deleted} orphaned embeddings",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def semantic_search(
|
||||
query: str,
|
||||
content_types: list[ContentType] | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 20,
|
||||
min_similarity: float = 0.5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Semantic search across content types using embeddings.
|
||||
|
||||
Performs vector similarity search on UnifiedContentEmbedding table.
|
||||
Used directly for blocks/docs/library agents, or as the semantic component
|
||||
within hybrid_search for store agents.
|
||||
|
||||
If embedding generation fails, falls back to lexical search on searchableText.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
content_types: List of ContentType to search. Defaults to [BLOCK, STORE_AGENT, DOCUMENTATION]
|
||||
user_id: Optional user ID for searching private content (library agents)
|
||||
limit: Maximum number of results to return (default: 20)
|
||||
min_similarity: Minimum cosine similarity threshold (0-1, default: 0.5)
|
||||
|
||||
Returns:
|
||||
List of search results with the following structure:
|
||||
[
|
||||
{
|
||||
"content_id": str,
|
||||
"content_type": str, # "BLOCK", "STORE_AGENT", "DOCUMENTATION", or "LIBRARY_AGENT"
|
||||
"searchable_text": str,
|
||||
"metadata": dict,
|
||||
"similarity": float, # Cosine similarity score (0-1)
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
Examples:
|
||||
# Search blocks only
|
||||
results = await semantic_search("calculate", content_types=[ContentType.BLOCK])
|
||||
|
||||
# Search blocks and documentation
|
||||
results = await semantic_search(
|
||||
"how to use API",
|
||||
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION]
|
||||
)
|
||||
|
||||
# Search all public content (default)
|
||||
results = await semantic_search("AI agent")
|
||||
|
||||
# Search user's library agents
|
||||
results = await semantic_search(
|
||||
"my custom agent",
|
||||
content_types=[ContentType.LIBRARY_AGENT],
|
||||
user_id="user123"
|
||||
)
|
||||
"""
|
||||
# Default to searching all public content types
|
||||
if content_types is None:
|
||||
content_types = [
|
||||
ContentType.BLOCK,
|
||||
ContentType.STORE_AGENT,
|
||||
ContentType.DOCUMENTATION,
|
||||
]
|
||||
|
||||
# Validate inputs
|
||||
if not content_types:
|
||||
return [] # Empty content_types would cause invalid SQL (IN ())
|
||||
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return []
|
||||
|
||||
if limit < 1:
|
||||
limit = 1
|
||||
if limit > 100:
|
||||
limit = 100
|
||||
|
||||
# Generate query embedding
|
||||
try:
|
||||
query_embedding = await embed_query(query)
|
||||
# Semantic search with embeddings
|
||||
embedding_str = embedding_to_vector_string(query_embedding)
|
||||
|
||||
# Build params in order: limit, then user_id (if provided), then content types
|
||||
params: list[Any] = [limit]
|
||||
user_filter = ""
|
||||
if user_id is not None:
|
||||
user_filter = 'AND "userId" = ${}'.format(len(params) + 1)
|
||||
params.append(user_id)
|
||||
|
||||
# Add content type parameters and build placeholders dynamically
|
||||
content_type_start_idx = len(params) + 1
|
||||
content_type_placeholders = ", ".join(
|
||||
"$" + str(content_type_start_idx + i) + '::{schema_prefix}"ContentType"'
|
||||
for i in range(len(content_types))
|
||||
)
|
||||
params.extend([ct.value for ct in content_types])
|
||||
|
||||
# Build min_similarity param index before appending
|
||||
min_similarity_idx = len(params) + 1
|
||||
params.append(min_similarity)
|
||||
|
||||
# Use unqualified ::vector and <=> operator - pgvector is in search_path on all environments
|
||||
sql = (
|
||||
"""
|
||||
SELECT
|
||||
"contentId" as content_id,
|
||||
"contentType" as content_type,
|
||||
"searchableText" as searchable_text,
|
||||
metadata,
|
||||
1 - (embedding <=> '"""
|
||||
+ embedding_str
|
||||
+ """'::vector) as similarity
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ("""
|
||||
+ content_type_placeholders
|
||||
+ """)
|
||||
"""
|
||||
+ user_filter
|
||||
+ """
|
||||
AND 1 - (embedding <=> '"""
|
||||
+ embedding_str
|
||||
+ """'::vector) >= $"""
|
||||
+ str(min_similarity_idx)
|
||||
+ """
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $1
|
||||
"""
|
||||
)
|
||||
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
"content_type": row["content_type"],
|
||||
"searchable_text": row["searchable_text"],
|
||||
"metadata": row["metadata"],
|
||||
"similarity": float(row["similarity"]),
|
||||
}
|
||||
for row in results
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"Semantic search failed, falling back to lexical search: {e}")
|
||||
|
||||
# Fallback to lexical search if embeddings unavailable
|
||||
|
||||
params_lexical: list[Any] = [limit]
|
||||
user_filter = ""
|
||||
if user_id is not None:
|
||||
user_filter = 'AND "userId" = ${}'.format(len(params_lexical) + 1)
|
||||
params_lexical.append(user_id)
|
||||
|
||||
# Add content type parameters and build placeholders dynamically
|
||||
content_type_start_idx = len(params_lexical) + 1
|
||||
content_type_placeholders_lexical = ", ".join(
|
||||
"$" + str(content_type_start_idx + i) + '::{schema_prefix}"ContentType"'
|
||||
for i in range(len(content_types))
|
||||
)
|
||||
params_lexical.extend([ct.value for ct in content_types])
|
||||
|
||||
# Build query param index before appending
|
||||
query_param_idx = len(params_lexical) + 1
|
||||
params_lexical.append(f"%{query}%")
|
||||
|
||||
# Use regular string (not f-string) for template to preserve {schema_prefix} placeholders
|
||||
sql_lexical = (
|
||||
"""
|
||||
SELECT
|
||||
"contentId" as content_id,
|
||||
"contentType" as content_type,
|
||||
"searchableText" as searchable_text,
|
||||
metadata,
|
||||
0.0 as similarity
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ("""
|
||||
+ content_type_placeholders_lexical
|
||||
+ """)
|
||||
"""
|
||||
+ user_filter
|
||||
+ """
|
||||
AND "searchableText" ILIKE $"""
|
||||
+ str(query_param_idx)
|
||||
+ """
|
||||
ORDER BY "updatedAt" DESC
|
||||
LIMIT $1
|
||||
"""
|
||||
)
|
||||
|
||||
try:
|
||||
results = await query_raw_with_schema(sql_lexical, *params_lexical)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
"content_type": row["content_type"],
|
||||
"searchable_text": row["searchable_text"],
|
||||
"metadata": row["metadata"],
|
||||
"similarity": 0.0, # Lexical search doesn't provide similarity
|
||||
}
|
||||
for row in results
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Lexical search failed: {e}")
|
||||
return []
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user