mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-20 12:38:10 -05:00
Compare commits
15 Commits
pwuts/open
...
fix/undefi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb5c89d881 | ||
|
|
0a48c49902 | ||
|
|
1fc1102eb4 | ||
|
|
bc75d70e7d | ||
|
|
c1a1767034 | ||
|
|
1b56ff13d9 | ||
|
|
f31c160043 | ||
|
|
06550a87eb | ||
|
|
088b9998dc | ||
|
|
05c89fa5c0 | ||
|
|
8cc8295f14 | ||
|
|
e55f05c7a8 | ||
|
|
4a9b13acb6 | ||
|
|
5ff669e999 | ||
|
|
ec03a13e26 |
2249
.claude/skills/vercel-react-best-practices/AGENTS.md
Normal file
2249
.claude/skills/vercel-react-best-practices/AGENTS.md
Normal file
File diff suppressed because it is too large
Load Diff
125
.claude/skills/vercel-react-best-practices/SKILL.md
Normal file
125
.claude/skills/vercel-react-best-practices/SKILL.md
Normal file
@@ -0,0 +1,125 @@
|
||||
---
|
||||
name: vercel-react-best-practices
|
||||
description: React and Next.js performance optimization guidelines from Vercel Engineering. This skill should be used when writing, reviewing, or refactoring React/Next.js code to ensure optimal performance patterns. Triggers on tasks involving React components, Next.js pages, data fetching, bundle optimization, or performance improvements.
|
||||
license: MIT
|
||||
metadata:
|
||||
author: vercel
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Vercel React Best Practices
|
||||
|
||||
Comprehensive performance optimization guide for React and Next.js applications, maintained by Vercel. Contains 45 rules across 8 categories, prioritized by impact to guide automated refactoring and code generation.
|
||||
|
||||
## When to Apply
|
||||
|
||||
Reference these guidelines when:
|
||||
- Writing new React components or Next.js pages
|
||||
- Implementing data fetching (client or server-side)
|
||||
- Reviewing code for performance issues
|
||||
- Refactoring existing React/Next.js code
|
||||
- Optimizing bundle size or load times
|
||||
|
||||
## Rule Categories by Priority
|
||||
|
||||
| Priority | Category | Impact | Prefix |
|
||||
|----------|----------|--------|--------|
|
||||
| 1 | Eliminating Waterfalls | CRITICAL | `async-` |
|
||||
| 2 | Bundle Size Optimization | CRITICAL | `bundle-` |
|
||||
| 3 | Server-Side Performance | HIGH | `server-` |
|
||||
| 4 | Client-Side Data Fetching | MEDIUM-HIGH | `client-` |
|
||||
| 5 | Re-render Optimization | MEDIUM | `rerender-` |
|
||||
| 6 | Rendering Performance | MEDIUM | `rendering-` |
|
||||
| 7 | JavaScript Performance | LOW-MEDIUM | `js-` |
|
||||
| 8 | Advanced Patterns | LOW | `advanced-` |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### 1. Eliminating Waterfalls (CRITICAL)
|
||||
|
||||
- `async-defer-await` - Move await into branches where actually used
|
||||
- `async-parallel` - Use Promise.all() for independent operations
|
||||
- `async-dependencies` - Use better-all for partial dependencies
|
||||
- `async-api-routes` - Start promises early, await late in API routes
|
||||
- `async-suspense-boundaries` - Use Suspense to stream content
|
||||
|
||||
### 2. Bundle Size Optimization (CRITICAL)
|
||||
|
||||
- `bundle-barrel-imports` - Import directly, avoid barrel files
|
||||
- `bundle-dynamic-imports` - Use next/dynamic for heavy components
|
||||
- `bundle-defer-third-party` - Load analytics/logging after hydration
|
||||
- `bundle-conditional` - Load modules only when feature is activated
|
||||
- `bundle-preload` - Preload on hover/focus for perceived speed
|
||||
|
||||
### 3. Server-Side Performance (HIGH)
|
||||
|
||||
- `server-cache-react` - Use React.cache() for per-request deduplication
|
||||
- `server-cache-lru` - Use LRU cache for cross-request caching
|
||||
- `server-serialization` - Minimize data passed to client components
|
||||
- `server-parallel-fetching` - Restructure components to parallelize fetches
|
||||
- `server-after-nonblocking` - Use after() for non-blocking operations
|
||||
|
||||
### 4. Client-Side Data Fetching (MEDIUM-HIGH)
|
||||
|
||||
- `client-swr-dedup` - Use SWR for automatic request deduplication
|
||||
- `client-event-listeners` - Deduplicate global event listeners
|
||||
|
||||
### 5. Re-render Optimization (MEDIUM)
|
||||
|
||||
- `rerender-defer-reads` - Don't subscribe to state only used in callbacks
|
||||
- `rerender-memo` - Extract expensive work into memoized components
|
||||
- `rerender-dependencies` - Use primitive dependencies in effects
|
||||
- `rerender-derived-state` - Subscribe to derived booleans, not raw values
|
||||
- `rerender-functional-setstate` - Use functional setState for stable callbacks
|
||||
- `rerender-lazy-state-init` - Pass function to useState for expensive values
|
||||
- `rerender-transitions` - Use startTransition for non-urgent updates
|
||||
|
||||
### 6. Rendering Performance (MEDIUM)
|
||||
|
||||
- `rendering-animate-svg-wrapper` - Animate div wrapper, not SVG element
|
||||
- `rendering-content-visibility` - Use content-visibility for long lists
|
||||
- `rendering-hoist-jsx` - Extract static JSX outside components
|
||||
- `rendering-svg-precision` - Reduce SVG coordinate precision
|
||||
- `rendering-hydration-no-flicker` - Use inline script for client-only data
|
||||
- `rendering-activity` - Use Activity component for show/hide
|
||||
- `rendering-conditional-render` - Use ternary, not && for conditionals
|
||||
|
||||
### 7. JavaScript Performance (LOW-MEDIUM)
|
||||
|
||||
- `js-batch-dom-css` - Group CSS changes via classes or cssText
|
||||
- `js-index-maps` - Build Map for repeated lookups
|
||||
- `js-cache-property-access` - Cache object properties in loops
|
||||
- `js-cache-function-results` - Cache function results in module-level Map
|
||||
- `js-cache-storage` - Cache localStorage/sessionStorage reads
|
||||
- `js-combine-iterations` - Combine multiple filter/map into one loop
|
||||
- `js-length-check-first` - Check array length before expensive comparison
|
||||
- `js-early-exit` - Return early from functions
|
||||
- `js-hoist-regexp` - Hoist RegExp creation outside loops
|
||||
- `js-min-max-loop` - Use loop for min/max instead of sort
|
||||
- `js-set-map-lookups` - Use Set/Map for O(1) lookups
|
||||
- `js-tosorted-immutable` - Use toSorted() for immutability
|
||||
|
||||
### 8. Advanced Patterns (LOW)
|
||||
|
||||
- `advanced-event-handler-refs` - Store event handlers in refs
|
||||
- `advanced-use-latest` - useLatest for stable callback refs
|
||||
|
||||
## How to Use
|
||||
|
||||
Read individual rule files for detailed explanations and code examples:
|
||||
|
||||
```
|
||||
rules/async-parallel.md
|
||||
rules/bundle-barrel-imports.md
|
||||
rules/_sections.md
|
||||
```
|
||||
|
||||
Each rule file contains:
|
||||
- Brief explanation of why it matters
|
||||
- Incorrect code example with explanation
|
||||
- Correct code example with explanation
|
||||
- Additional context and references
|
||||
|
||||
## Full Compiled Document
|
||||
|
||||
For the complete guide with all rules expanded: `AGENTS.md`
|
||||
@@ -0,0 +1,55 @@
|
||||
---
|
||||
title: Store Event Handlers in Refs
|
||||
impact: LOW
|
||||
impactDescription: stable subscriptions
|
||||
tags: advanced, hooks, refs, event-handlers, optimization
|
||||
---
|
||||
|
||||
## Store Event Handlers in Refs
|
||||
|
||||
Store callbacks in refs when used in effects that shouldn't re-subscribe on callback changes.
|
||||
|
||||
**Incorrect (re-subscribes on every render):**
|
||||
|
||||
```tsx
|
||||
function useWindowEvent(event: string, handler: () => void) {
|
||||
useEffect(() => {
|
||||
window.addEventListener(event, handler)
|
||||
return () => window.removeEventListener(event, handler)
|
||||
}, [event, handler])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (stable subscription):**
|
||||
|
||||
```tsx
|
||||
function useWindowEvent(event: string, handler: () => void) {
|
||||
const handlerRef = useRef(handler)
|
||||
useEffect(() => {
|
||||
handlerRef.current = handler
|
||||
}, [handler])
|
||||
|
||||
useEffect(() => {
|
||||
const listener = () => handlerRef.current()
|
||||
window.addEventListener(event, listener)
|
||||
return () => window.removeEventListener(event, listener)
|
||||
}, [event])
|
||||
}
|
||||
```
|
||||
|
||||
**Alternative: use `useEffectEvent` if you're on latest React:**
|
||||
|
||||
```tsx
|
||||
import { useEffectEvent } from 'react'
|
||||
|
||||
function useWindowEvent(event: string, handler: () => void) {
|
||||
const onEvent = useEffectEvent(handler)
|
||||
|
||||
useEffect(() => {
|
||||
window.addEventListener(event, onEvent)
|
||||
return () => window.removeEventListener(event, onEvent)
|
||||
}, [event])
|
||||
}
|
||||
```
|
||||
|
||||
`useEffectEvent` provides a cleaner API for the same pattern: it creates a stable function reference that always calls the latest version of the handler.
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
title: useLatest for Stable Callback Refs
|
||||
impact: LOW
|
||||
impactDescription: prevents effect re-runs
|
||||
tags: advanced, hooks, useLatest, refs, optimization
|
||||
---
|
||||
|
||||
## useLatest for Stable Callback Refs
|
||||
|
||||
Access latest values in callbacks without adding them to dependency arrays. Prevents effect re-runs while avoiding stale closures.
|
||||
|
||||
**Implementation:**
|
||||
|
||||
```typescript
|
||||
function useLatest<T>(value: T) {
|
||||
const ref = useRef(value)
|
||||
useEffect(() => {
|
||||
ref.current = value
|
||||
}, [value])
|
||||
return ref
|
||||
}
|
||||
```
|
||||
|
||||
**Incorrect (effect re-runs on every callback change):**
|
||||
|
||||
```tsx
|
||||
function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
|
||||
const [query, setQuery] = useState('')
|
||||
|
||||
useEffect(() => {
|
||||
const timeout = setTimeout(() => onSearch(query), 300)
|
||||
return () => clearTimeout(timeout)
|
||||
}, [query, onSearch])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (stable effect, fresh callback):**
|
||||
|
||||
```tsx
|
||||
function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
|
||||
const [query, setQuery] = useState('')
|
||||
const onSearchRef = useLatest(onSearch)
|
||||
|
||||
useEffect(() => {
|
||||
const timeout = setTimeout(() => onSearchRef.current(query), 300)
|
||||
return () => clearTimeout(timeout)
|
||||
}, [query])
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,38 @@
|
||||
---
|
||||
title: Prevent Waterfall Chains in API Routes
|
||||
impact: CRITICAL
|
||||
impactDescription: 2-10× improvement
|
||||
tags: api-routes, server-actions, waterfalls, parallelization
|
||||
---
|
||||
|
||||
## Prevent Waterfall Chains in API Routes
|
||||
|
||||
In API routes and Server Actions, start independent operations immediately, even if you don't await them yet.
|
||||
|
||||
**Incorrect (config waits for auth, data waits for both):**
|
||||
|
||||
```typescript
|
||||
export async function GET(request: Request) {
|
||||
const session = await auth()
|
||||
const config = await fetchConfig()
|
||||
const data = await fetchData(session.user.id)
|
||||
return Response.json({ data, config })
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (auth and config start immediately):**
|
||||
|
||||
```typescript
|
||||
export async function GET(request: Request) {
|
||||
const sessionPromise = auth()
|
||||
const configPromise = fetchConfig()
|
||||
const session = await sessionPromise
|
||||
const [config, data] = await Promise.all([
|
||||
configPromise,
|
||||
fetchData(session.user.id)
|
||||
])
|
||||
return Response.json({ data, config })
|
||||
}
|
||||
```
|
||||
|
||||
For operations with more complex dependency chains, use `better-all` to automatically maximize parallelism (see Dependency-Based Parallelization).
|
||||
@@ -0,0 +1,80 @@
|
||||
---
|
||||
title: Defer Await Until Needed
|
||||
impact: HIGH
|
||||
impactDescription: avoids blocking unused code paths
|
||||
tags: async, await, conditional, optimization
|
||||
---
|
||||
|
||||
## Defer Await Until Needed
|
||||
|
||||
Move `await` operations into the branches where they're actually used to avoid blocking code paths that don't need them.
|
||||
|
||||
**Incorrect (blocks both branches):**
|
||||
|
||||
```typescript
|
||||
async function handleRequest(userId: string, skipProcessing: boolean) {
|
||||
const userData = await fetchUserData(userId)
|
||||
|
||||
if (skipProcessing) {
|
||||
// Returns immediately but still waited for userData
|
||||
return { skipped: true }
|
||||
}
|
||||
|
||||
// Only this branch uses userData
|
||||
return processUserData(userData)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (only blocks when needed):**
|
||||
|
||||
```typescript
|
||||
async function handleRequest(userId: string, skipProcessing: boolean) {
|
||||
if (skipProcessing) {
|
||||
// Returns immediately without waiting
|
||||
return { skipped: true }
|
||||
}
|
||||
|
||||
// Fetch only when needed
|
||||
const userData = await fetchUserData(userId)
|
||||
return processUserData(userData)
|
||||
}
|
||||
```
|
||||
|
||||
**Another example (early return optimization):**
|
||||
|
||||
```typescript
|
||||
// Incorrect: always fetches permissions
|
||||
async function updateResource(resourceId: string, userId: string) {
|
||||
const permissions = await fetchPermissions(userId)
|
||||
const resource = await getResource(resourceId)
|
||||
|
||||
if (!resource) {
|
||||
return { error: 'Not found' }
|
||||
}
|
||||
|
||||
if (!permissions.canEdit) {
|
||||
return { error: 'Forbidden' }
|
||||
}
|
||||
|
||||
return await updateResourceData(resource, permissions)
|
||||
}
|
||||
|
||||
// Correct: fetches only when needed
|
||||
async function updateResource(resourceId: string, userId: string) {
|
||||
const resource = await getResource(resourceId)
|
||||
|
||||
if (!resource) {
|
||||
return { error: 'Not found' }
|
||||
}
|
||||
|
||||
const permissions = await fetchPermissions(userId)
|
||||
|
||||
if (!permissions.canEdit) {
|
||||
return { error: 'Forbidden' }
|
||||
}
|
||||
|
||||
return await updateResourceData(resource, permissions)
|
||||
}
|
||||
```
|
||||
|
||||
This optimization is especially valuable when the skipped branch is frequently taken, or when the deferred operation is expensive.
|
||||
@@ -0,0 +1,36 @@
|
||||
---
|
||||
title: Dependency-Based Parallelization
|
||||
impact: CRITICAL
|
||||
impactDescription: 2-10× improvement
|
||||
tags: async, parallelization, dependencies, better-all
|
||||
---
|
||||
|
||||
## Dependency-Based Parallelization
|
||||
|
||||
For operations with partial dependencies, use `better-all` to maximize parallelism. It automatically starts each task at the earliest possible moment.
|
||||
|
||||
**Incorrect (profile waits for config unnecessarily):**
|
||||
|
||||
```typescript
|
||||
const [user, config] = await Promise.all([
|
||||
fetchUser(),
|
||||
fetchConfig()
|
||||
])
|
||||
const profile = await fetchProfile(user.id)
|
||||
```
|
||||
|
||||
**Correct (config and profile run in parallel):**
|
||||
|
||||
```typescript
|
||||
import { all } from 'better-all'
|
||||
|
||||
const { user, config, profile } = await all({
|
||||
async user() { return fetchUser() },
|
||||
async config() { return fetchConfig() },
|
||||
async profile() {
|
||||
return fetchProfile((await this.$.user).id)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
Reference: [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
|
||||
@@ -0,0 +1,28 @@
|
||||
---
|
||||
title: Promise.all() for Independent Operations
|
||||
impact: CRITICAL
|
||||
impactDescription: 2-10× improvement
|
||||
tags: async, parallelization, promises, waterfalls
|
||||
---
|
||||
|
||||
## Promise.all() for Independent Operations
|
||||
|
||||
When async operations have no interdependencies, execute them concurrently using `Promise.all()`.
|
||||
|
||||
**Incorrect (sequential execution, 3 round trips):**
|
||||
|
||||
```typescript
|
||||
const user = await fetchUser()
|
||||
const posts = await fetchPosts()
|
||||
const comments = await fetchComments()
|
||||
```
|
||||
|
||||
**Correct (parallel execution, 1 round trip):**
|
||||
|
||||
```typescript
|
||||
const [user, posts, comments] = await Promise.all([
|
||||
fetchUser(),
|
||||
fetchPosts(),
|
||||
fetchComments()
|
||||
])
|
||||
```
|
||||
@@ -0,0 +1,99 @@
|
||||
---
|
||||
title: Strategic Suspense Boundaries
|
||||
impact: HIGH
|
||||
impactDescription: faster initial paint
|
||||
tags: async, suspense, streaming, layout-shift
|
||||
---
|
||||
|
||||
## Strategic Suspense Boundaries
|
||||
|
||||
Instead of awaiting data in async components before returning JSX, use Suspense boundaries to show the wrapper UI faster while data loads.
|
||||
|
||||
**Incorrect (wrapper blocked by data fetching):**
|
||||
|
||||
```tsx
|
||||
async function Page() {
|
||||
const data = await fetchData() // Blocks entire page
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div>Sidebar</div>
|
||||
<div>Header</div>
|
||||
<div>
|
||||
<DataDisplay data={data} />
|
||||
</div>
|
||||
<div>Footer</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
The entire layout waits for data even though only the middle section needs it.
|
||||
|
||||
**Correct (wrapper shows immediately, data streams in):**
|
||||
|
||||
```tsx
|
||||
function Page() {
|
||||
return (
|
||||
<div>
|
||||
<div>Sidebar</div>
|
||||
<div>Header</div>
|
||||
<div>
|
||||
<Suspense fallback={<Skeleton />}>
|
||||
<DataDisplay />
|
||||
</Suspense>
|
||||
</div>
|
||||
<div>Footer</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
async function DataDisplay() {
|
||||
const data = await fetchData() // Only blocks this component
|
||||
return <div>{data.content}</div>
|
||||
}
|
||||
```
|
||||
|
||||
Sidebar, Header, and Footer render immediately. Only DataDisplay waits for data.
|
||||
|
||||
**Alternative (share promise across components):**
|
||||
|
||||
```tsx
|
||||
function Page() {
|
||||
// Start fetch immediately, but don't await
|
||||
const dataPromise = fetchData()
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div>Sidebar</div>
|
||||
<div>Header</div>
|
||||
<Suspense fallback={<Skeleton />}>
|
||||
<DataDisplay dataPromise={dataPromise} />
|
||||
<DataSummary dataPromise={dataPromise} />
|
||||
</Suspense>
|
||||
<div>Footer</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function DataDisplay({ dataPromise }: { dataPromise: Promise<Data> }) {
|
||||
const data = use(dataPromise) // Unwraps the promise
|
||||
return <div>{data.content}</div>
|
||||
}
|
||||
|
||||
function DataSummary({ dataPromise }: { dataPromise: Promise<Data> }) {
|
||||
const data = use(dataPromise) // Reuses the same promise
|
||||
return <div>{data.summary}</div>
|
||||
}
|
||||
```
|
||||
|
||||
Both components share the same promise, so only one fetch occurs. Layout renders immediately while both components wait together.
|
||||
|
||||
**When NOT to use this pattern:**
|
||||
|
||||
- Critical data needed for layout decisions (affects positioning)
|
||||
- SEO-critical content above the fold
|
||||
- Small, fast queries where suspense overhead isn't worth it
|
||||
- When you want to avoid layout shift (loading → content jump)
|
||||
|
||||
**Trade-off:** Faster initial paint vs potential layout shift. Choose based on your UX priorities.
|
||||
@@ -0,0 +1,59 @@
|
||||
---
|
||||
title: Avoid Barrel File Imports
|
||||
impact: CRITICAL
|
||||
impactDescription: 200-800ms import cost, slow builds
|
||||
tags: bundle, imports, tree-shaking, barrel-files, performance
|
||||
---
|
||||
|
||||
## Avoid Barrel File Imports
|
||||
|
||||
Import directly from source files instead of barrel files to avoid loading thousands of unused modules. **Barrel files** are entry points that re-export multiple modules (e.g., `index.js` that does `export * from './module'`).
|
||||
|
||||
Popular icon and component libraries can have **up to 10,000 re-exports** in their entry file. For many React packages, **it takes 200-800ms just to import them**, affecting both development speed and production cold starts.
|
||||
|
||||
**Why tree-shaking doesn't help:** When a library is marked as external (not bundled), the bundler can't optimize it. If you bundle it to enable tree-shaking, builds become substantially slower analyzing the entire module graph.
|
||||
|
||||
**Incorrect (imports entire library):**
|
||||
|
||||
```tsx
|
||||
import { Check, X, Menu } from 'lucide-react'
|
||||
// Loads 1,583 modules, takes ~2.8s extra in dev
|
||||
// Runtime cost: 200-800ms on every cold start
|
||||
|
||||
import { Button, TextField } from '@mui/material'
|
||||
// Loads 2,225 modules, takes ~4.2s extra in dev
|
||||
```
|
||||
|
||||
**Correct (imports only what you need):**
|
||||
|
||||
```tsx
|
||||
import Check from 'lucide-react/dist/esm/icons/check'
|
||||
import X from 'lucide-react/dist/esm/icons/x'
|
||||
import Menu from 'lucide-react/dist/esm/icons/menu'
|
||||
// Loads only 3 modules (~2KB vs ~1MB)
|
||||
|
||||
import Button from '@mui/material/Button'
|
||||
import TextField from '@mui/material/TextField'
|
||||
// Loads only what you use
|
||||
```
|
||||
|
||||
**Alternative (Next.js 13.5+):**
|
||||
|
||||
```js
|
||||
// next.config.js - use optimizePackageImports
|
||||
module.exports = {
|
||||
experimental: {
|
||||
optimizePackageImports: ['lucide-react', '@mui/material']
|
||||
}
|
||||
}
|
||||
|
||||
// Then you can keep the ergonomic barrel imports:
|
||||
import { Check, X, Menu } from 'lucide-react'
|
||||
// Automatically transformed to direct imports at build time
|
||||
```
|
||||
|
||||
Direct imports provide 15-70% faster dev boot, 28% faster builds, 40% faster cold starts, and significantly faster HMR.
|
||||
|
||||
Libraries commonly affected: `lucide-react`, `@mui/material`, `@mui/icons-material`, `@tabler/icons-react`, `react-icons`, `@headlessui/react`, `@radix-ui/react-*`, `lodash`, `ramda`, `date-fns`, `rxjs`, `react-use`.
|
||||
|
||||
Reference: [How we optimized package imports in Next.js](https://vercel.com/blog/how-we-optimized-package-imports-in-next-js)
|
||||
@@ -0,0 +1,31 @@
|
||||
---
|
||||
title: Conditional Module Loading
|
||||
impact: HIGH
|
||||
impactDescription: loads large data only when needed
|
||||
tags: bundle, conditional-loading, lazy-loading
|
||||
---
|
||||
|
||||
## Conditional Module Loading
|
||||
|
||||
Load large data or modules only when a feature is activated.
|
||||
|
||||
**Example (lazy-load animation frames):**
|
||||
|
||||
```tsx
|
||||
function AnimationPlayer({ enabled }: { enabled: boolean }) {
|
||||
const [frames, setFrames] = useState<Frame[] | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (enabled && !frames && typeof window !== 'undefined') {
|
||||
import('./animation-frames.js')
|
||||
.then(mod => setFrames(mod.frames))
|
||||
.catch(() => setEnabled(false))
|
||||
}
|
||||
}, [enabled, frames])
|
||||
|
||||
if (!frames) return <Skeleton />
|
||||
return <Canvas frames={frames} />
|
||||
}
|
||||
```
|
||||
|
||||
The `typeof window !== 'undefined'` check prevents bundling this module for SSR, optimizing server bundle size and build speed.
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
title: Defer Non-Critical Third-Party Libraries
|
||||
impact: MEDIUM
|
||||
impactDescription: loads after hydration
|
||||
tags: bundle, third-party, analytics, defer
|
||||
---
|
||||
|
||||
## Defer Non-Critical Third-Party Libraries
|
||||
|
||||
Analytics, logging, and error tracking don't block user interaction. Load them after hydration.
|
||||
|
||||
**Incorrect (blocks initial bundle):**
|
||||
|
||||
```tsx
|
||||
import { Analytics } from '@vercel/analytics/react'
|
||||
|
||||
export default function RootLayout({ children }) {
|
||||
return (
|
||||
<html>
|
||||
<body>
|
||||
{children}
|
||||
<Analytics />
|
||||
</body>
|
||||
</html>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (loads after hydration):**
|
||||
|
||||
```tsx
|
||||
import dynamic from 'next/dynamic'
|
||||
|
||||
const Analytics = dynamic(
|
||||
() => import('@vercel/analytics/react').then(m => m.Analytics),
|
||||
{ ssr: false }
|
||||
)
|
||||
|
||||
export default function RootLayout({ children }) {
|
||||
return (
|
||||
<html>
|
||||
<body>
|
||||
{children}
|
||||
<Analytics />
|
||||
</body>
|
||||
</html>
|
||||
)
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,35 @@
|
||||
---
|
||||
title: Dynamic Imports for Heavy Components
|
||||
impact: CRITICAL
|
||||
impactDescription: directly affects TTI and LCP
|
||||
tags: bundle, dynamic-import, code-splitting, next-dynamic
|
||||
---
|
||||
|
||||
## Dynamic Imports for Heavy Components
|
||||
|
||||
Use `next/dynamic` to lazy-load large components not needed on initial render.
|
||||
|
||||
**Incorrect (Monaco bundles with main chunk ~300KB):**
|
||||
|
||||
```tsx
|
||||
import { MonacoEditor } from './monaco-editor'
|
||||
|
||||
function CodePanel({ code }: { code: string }) {
|
||||
return <MonacoEditor value={code} />
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (Monaco loads on demand):**
|
||||
|
||||
```tsx
|
||||
import dynamic from 'next/dynamic'
|
||||
|
||||
const MonacoEditor = dynamic(
|
||||
() => import('./monaco-editor').then(m => m.MonacoEditor),
|
||||
{ ssr: false }
|
||||
)
|
||||
|
||||
function CodePanel({ code }: { code: string }) {
|
||||
return <MonacoEditor value={code} />
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,50 @@
|
||||
---
|
||||
title: Preload Based on User Intent
|
||||
impact: MEDIUM
|
||||
impactDescription: reduces perceived latency
|
||||
tags: bundle, preload, user-intent, hover
|
||||
---
|
||||
|
||||
## Preload Based on User Intent
|
||||
|
||||
Preload heavy bundles before they're needed to reduce perceived latency.
|
||||
|
||||
**Example (preload on hover/focus):**
|
||||
|
||||
```tsx
|
||||
function EditorButton({ onClick }: { onClick: () => void }) {
|
||||
const preload = () => {
|
||||
if (typeof window !== 'undefined') {
|
||||
void import('./monaco-editor')
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<button
|
||||
onMouseEnter={preload}
|
||||
onFocus={preload}
|
||||
onClick={onClick}
|
||||
>
|
||||
Open Editor
|
||||
</button>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Example (preload when feature flag is enabled):**
|
||||
|
||||
```tsx
|
||||
function FlagsProvider({ children, flags }: Props) {
|
||||
useEffect(() => {
|
||||
if (flags.editorEnabled && typeof window !== 'undefined') {
|
||||
void import('./monaco-editor').then(mod => mod.init())
|
||||
}
|
||||
}, [flags.editorEnabled])
|
||||
|
||||
return <FlagsContext.Provider value={flags}>
|
||||
{children}
|
||||
</FlagsContext.Provider>
|
||||
}
|
||||
```
|
||||
|
||||
The `typeof window !== 'undefined'` check prevents bundling preloaded modules for SSR, optimizing server bundle size and build speed.
|
||||
@@ -0,0 +1,74 @@
|
||||
---
|
||||
title: Deduplicate Global Event Listeners
|
||||
impact: LOW
|
||||
impactDescription: single listener for N components
|
||||
tags: client, swr, event-listeners, subscription
|
||||
---
|
||||
|
||||
## Deduplicate Global Event Listeners
|
||||
|
||||
Use `useSWRSubscription()` to share global event listeners across component instances.
|
||||
|
||||
**Incorrect (N instances = N listeners):**
|
||||
|
||||
```tsx
|
||||
function useKeyboardShortcut(key: string, callback: () => void) {
|
||||
useEffect(() => {
|
||||
const handler = (e: KeyboardEvent) => {
|
||||
if (e.metaKey && e.key === key) {
|
||||
callback()
|
||||
}
|
||||
}
|
||||
window.addEventListener('keydown', handler)
|
||||
return () => window.removeEventListener('keydown', handler)
|
||||
}, [key, callback])
|
||||
}
|
||||
```
|
||||
|
||||
When using the `useKeyboardShortcut` hook multiple times, each instance will register a new listener.
|
||||
|
||||
**Correct (N instances = 1 listener):**
|
||||
|
||||
```tsx
|
||||
import useSWRSubscription from 'swr/subscription'
|
||||
|
||||
// Module-level Map to track callbacks per key
|
||||
const keyCallbacks = new Map<string, Set<() => void>>()
|
||||
|
||||
function useKeyboardShortcut(key: string, callback: () => void) {
|
||||
// Register this callback in the Map
|
||||
useEffect(() => {
|
||||
if (!keyCallbacks.has(key)) {
|
||||
keyCallbacks.set(key, new Set())
|
||||
}
|
||||
keyCallbacks.get(key)!.add(callback)
|
||||
|
||||
return () => {
|
||||
const set = keyCallbacks.get(key)
|
||||
if (set) {
|
||||
set.delete(callback)
|
||||
if (set.size === 0) {
|
||||
keyCallbacks.delete(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [key, callback])
|
||||
|
||||
useSWRSubscription('global-keydown', () => {
|
||||
const handler = (e: KeyboardEvent) => {
|
||||
if (e.metaKey && keyCallbacks.has(e.key)) {
|
||||
keyCallbacks.get(e.key)!.forEach(cb => cb())
|
||||
}
|
||||
}
|
||||
window.addEventListener('keydown', handler)
|
||||
return () => window.removeEventListener('keydown', handler)
|
||||
})
|
||||
}
|
||||
|
||||
function Profile() {
|
||||
// Multiple shortcuts will share the same listener
|
||||
useKeyboardShortcut('p', () => { /* ... */ })
|
||||
useKeyboardShortcut('k', () => { /* ... */ })
|
||||
// ...
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,56 @@
|
||||
---
|
||||
title: Use SWR for Automatic Deduplication
|
||||
impact: MEDIUM-HIGH
|
||||
impactDescription: automatic deduplication
|
||||
tags: client, swr, deduplication, data-fetching
|
||||
---
|
||||
|
||||
## Use SWR for Automatic Deduplication
|
||||
|
||||
SWR enables request deduplication, caching, and revalidation across component instances.
|
||||
|
||||
**Incorrect (no deduplication, each instance fetches):**
|
||||
|
||||
```tsx
|
||||
function UserList() {
|
||||
const [users, setUsers] = useState([])
|
||||
useEffect(() => {
|
||||
fetch('/api/users')
|
||||
.then(r => r.json())
|
||||
.then(setUsers)
|
||||
}, [])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (multiple instances share one request):**
|
||||
|
||||
```tsx
|
||||
import useSWR from 'swr'
|
||||
|
||||
function UserList() {
|
||||
const { data: users } = useSWR('/api/users', fetcher)
|
||||
}
|
||||
```
|
||||
|
||||
**For immutable data:**
|
||||
|
||||
```tsx
|
||||
import { useImmutableSWR } from '@/lib/swr'
|
||||
|
||||
function StaticContent() {
|
||||
const { data } = useImmutableSWR('/api/config', fetcher)
|
||||
}
|
||||
```
|
||||
|
||||
**For mutations:**
|
||||
|
||||
```tsx
|
||||
import { useSWRMutation } from 'swr/mutation'
|
||||
|
||||
function UpdateButton() {
|
||||
const { trigger } = useSWRMutation('/api/user', updateUser)
|
||||
return <button onClick={() => trigger()}>Update</button>
|
||||
}
|
||||
```
|
||||
|
||||
Reference: [https://swr.vercel.app](https://swr.vercel.app)
|
||||
@@ -0,0 +1,82 @@
|
||||
---
|
||||
title: Batch DOM CSS Changes
|
||||
impact: MEDIUM
|
||||
impactDescription: reduces reflows/repaints
|
||||
tags: javascript, dom, css, performance, reflow
|
||||
---
|
||||
|
||||
## Batch DOM CSS Changes
|
||||
|
||||
Avoid changing styles one property at a time. Group multiple CSS changes together via classes or `cssText` to minimize browser reflows.
|
||||
|
||||
**Incorrect (multiple reflows):**
|
||||
|
||||
```typescript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
// Each line triggers a reflow
|
||||
element.style.width = '100px'
|
||||
element.style.height = '200px'
|
||||
element.style.backgroundColor = 'blue'
|
||||
element.style.border = '1px solid black'
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (add class - single reflow):**
|
||||
|
||||
```typescript
|
||||
// CSS file
|
||||
.highlighted-box {
|
||||
width: 100px;
|
||||
height: 200px;
|
||||
background-color: blue;
|
||||
border: 1px solid black;
|
||||
}
|
||||
|
||||
// JavaScript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
element.classList.add('highlighted-box')
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (change cssText - single reflow):**
|
||||
|
||||
```typescript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
element.style.cssText = `
|
||||
width: 100px;
|
||||
height: 200px;
|
||||
background-color: blue;
|
||||
border: 1px solid black;
|
||||
`
|
||||
}
|
||||
```
|
||||
|
||||
**React example:**
|
||||
|
||||
```tsx
|
||||
// Incorrect: changing styles one by one
|
||||
function Box({ isHighlighted }: { isHighlighted: boolean }) {
|
||||
const ref = useRef<HTMLDivElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (ref.current && isHighlighted) {
|
||||
ref.current.style.width = '100px'
|
||||
ref.current.style.height = '200px'
|
||||
ref.current.style.backgroundColor = 'blue'
|
||||
}
|
||||
}, [isHighlighted])
|
||||
|
||||
return <div ref={ref}>Content</div>
|
||||
}
|
||||
|
||||
// Correct: toggle class
|
||||
function Box({ isHighlighted }: { isHighlighted: boolean }) {
|
||||
return (
|
||||
<div className={isHighlighted ? 'highlighted-box' : ''}>
|
||||
Content
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Prefer CSS classes over inline styles when possible. Classes are cached by the browser and provide better separation of concerns.
|
||||
@@ -0,0 +1,80 @@
|
||||
---
|
||||
title: Cache Repeated Function Calls
|
||||
impact: MEDIUM
|
||||
impactDescription: avoid redundant computation
|
||||
tags: javascript, cache, memoization, performance
|
||||
---
|
||||
|
||||
## Cache Repeated Function Calls
|
||||
|
||||
Use a module-level Map to cache function results when the same function is called repeatedly with the same inputs during render.
|
||||
|
||||
**Incorrect (redundant computation):**
|
||||
|
||||
```typescript
|
||||
function ProjectList({ projects }: { projects: Project[] }) {
|
||||
return (
|
||||
<div>
|
||||
{projects.map(project => {
|
||||
// slugify() called 100+ times for same project names
|
||||
const slug = slugify(project.name)
|
||||
|
||||
return <ProjectCard key={project.id} slug={slug} />
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (cached results):**
|
||||
|
||||
```typescript
|
||||
// Module-level cache
|
||||
const slugifyCache = new Map<string, string>()
|
||||
|
||||
function cachedSlugify(text: string): string {
|
||||
if (slugifyCache.has(text)) {
|
||||
return slugifyCache.get(text)!
|
||||
}
|
||||
const result = slugify(text)
|
||||
slugifyCache.set(text, result)
|
||||
return result
|
||||
}
|
||||
|
||||
function ProjectList({ projects }: { projects: Project[] }) {
|
||||
return (
|
||||
<div>
|
||||
{projects.map(project => {
|
||||
// Computed only once per unique project name
|
||||
const slug = cachedSlugify(project.name)
|
||||
|
||||
return <ProjectCard key={project.id} slug={slug} />
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Simpler pattern for single-value functions:**
|
||||
|
||||
```typescript
|
||||
let isLoggedInCache: boolean | null = null
|
||||
|
||||
function isLoggedIn(): boolean {
|
||||
if (isLoggedInCache !== null) {
|
||||
return isLoggedInCache
|
||||
}
|
||||
|
||||
isLoggedInCache = document.cookie.includes('auth=')
|
||||
return isLoggedInCache
|
||||
}
|
||||
|
||||
// Clear cache when auth changes
|
||||
function onAuthChange() {
|
||||
isLoggedInCache = null
|
||||
}
|
||||
```
|
||||
|
||||
Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
|
||||
|
||||
Reference: [How we made the Vercel Dashboard twice as fast](https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast)
|
||||
@@ -0,0 +1,28 @@
|
||||
---
|
||||
title: Cache Property Access in Loops
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: reduces lookups
|
||||
tags: javascript, loops, optimization, caching
|
||||
---
|
||||
|
||||
## Cache Property Access in Loops
|
||||
|
||||
Cache object property lookups in hot paths.
|
||||
|
||||
**Incorrect (3 lookups × N iterations):**
|
||||
|
||||
```typescript
|
||||
for (let i = 0; i < arr.length; i++) {
|
||||
process(obj.config.settings.value)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (1 lookup total):**
|
||||
|
||||
```typescript
|
||||
const value = obj.config.settings.value
|
||||
const len = arr.length
|
||||
for (let i = 0; i < len; i++) {
|
||||
process(value)
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,70 @@
|
||||
---
|
||||
title: Cache Storage API Calls
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: reduces expensive I/O
|
||||
tags: javascript, localStorage, storage, caching, performance
|
||||
---
|
||||
|
||||
## Cache Storage API Calls
|
||||
|
||||
`localStorage`, `sessionStorage`, and `document.cookie` are synchronous and expensive. Cache reads in memory.
|
||||
|
||||
**Incorrect (reads storage on every call):**
|
||||
|
||||
```typescript
|
||||
function getTheme() {
|
||||
return localStorage.getItem('theme') ?? 'light'
|
||||
}
|
||||
// Called 10 times = 10 storage reads
|
||||
```
|
||||
|
||||
**Correct (Map cache):**
|
||||
|
||||
```typescript
|
||||
const storageCache = new Map<string, string | null>()
|
||||
|
||||
function getLocalStorage(key: string) {
|
||||
if (!storageCache.has(key)) {
|
||||
storageCache.set(key, localStorage.getItem(key))
|
||||
}
|
||||
return storageCache.get(key)
|
||||
}
|
||||
|
||||
function setLocalStorage(key: string, value: string) {
|
||||
localStorage.setItem(key, value)
|
||||
storageCache.set(key, value) // keep cache in sync
|
||||
}
|
||||
```
|
||||
|
||||
Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
|
||||
|
||||
**Cookie caching:**
|
||||
|
||||
```typescript
|
||||
let cookieCache: Record<string, string> | null = null
|
||||
|
||||
function getCookie(name: string) {
|
||||
if (!cookieCache) {
|
||||
cookieCache = Object.fromEntries(
|
||||
document.cookie.split('; ').map(c => c.split('='))
|
||||
)
|
||||
}
|
||||
return cookieCache[name]
|
||||
}
|
||||
```
|
||||
|
||||
**Important (invalidate on external changes):**
|
||||
|
||||
If storage can change externally (another tab, server-set cookies), invalidate cache:
|
||||
|
||||
```typescript
|
||||
window.addEventListener('storage', (e) => {
|
||||
if (e.key) storageCache.delete(e.key)
|
||||
})
|
||||
|
||||
document.addEventListener('visibilitychange', () => {
|
||||
if (document.visibilityState === 'visible') {
|
||||
storageCache.clear()
|
||||
}
|
||||
})
|
||||
```
|
||||
@@ -0,0 +1,32 @@
|
||||
---
|
||||
title: Combine Multiple Array Iterations
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: reduces iterations
|
||||
tags: javascript, arrays, loops, performance
|
||||
---
|
||||
|
||||
## Combine Multiple Array Iterations
|
||||
|
||||
Multiple `.filter()` or `.map()` calls iterate the array multiple times. Combine into one loop.
|
||||
|
||||
**Incorrect (3 iterations):**
|
||||
|
||||
```typescript
|
||||
const admins = users.filter(u => u.isAdmin)
|
||||
const testers = users.filter(u => u.isTester)
|
||||
const inactive = users.filter(u => !u.isActive)
|
||||
```
|
||||
|
||||
**Correct (1 iteration):**
|
||||
|
||||
```typescript
|
||||
const admins: User[] = []
|
||||
const testers: User[] = []
|
||||
const inactive: User[] = []
|
||||
|
||||
for (const user of users) {
|
||||
if (user.isAdmin) admins.push(user)
|
||||
if (user.isTester) testers.push(user)
|
||||
if (!user.isActive) inactive.push(user)
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,50 @@
|
||||
---
|
||||
title: Early Return from Functions
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: avoids unnecessary computation
|
||||
tags: javascript, functions, optimization, early-return
|
||||
---
|
||||
|
||||
## Early Return from Functions
|
||||
|
||||
Return early when result is determined to skip unnecessary processing.
|
||||
|
||||
**Incorrect (processes all items even after finding answer):**
|
||||
|
||||
```typescript
|
||||
function validateUsers(users: User[]) {
|
||||
let hasError = false
|
||||
let errorMessage = ''
|
||||
|
||||
for (const user of users) {
|
||||
if (!user.email) {
|
||||
hasError = true
|
||||
errorMessage = 'Email required'
|
||||
}
|
||||
if (!user.name) {
|
||||
hasError = true
|
||||
errorMessage = 'Name required'
|
||||
}
|
||||
// Continues checking all users even after error found
|
||||
}
|
||||
|
||||
return hasError ? { valid: false, error: errorMessage } : { valid: true }
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (returns immediately on first error):**
|
||||
|
||||
```typescript
|
||||
function validateUsers(users: User[]) {
|
||||
for (const user of users) {
|
||||
if (!user.email) {
|
||||
return { valid: false, error: 'Email required' }
|
||||
}
|
||||
if (!user.name) {
|
||||
return { valid: false, error: 'Name required' }
|
||||
}
|
||||
}
|
||||
|
||||
return { valid: true }
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,45 @@
|
||||
---
|
||||
title: Hoist RegExp Creation
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: avoids recreation
|
||||
tags: javascript, regexp, optimization, memoization
|
||||
---
|
||||
|
||||
## Hoist RegExp Creation
|
||||
|
||||
Don't create RegExp inside render. Hoist to module scope or memoize with `useMemo()`.
|
||||
|
||||
**Incorrect (new RegExp every render):**
|
||||
|
||||
```tsx
|
||||
function Highlighter({ text, query }: Props) {
|
||||
const regex = new RegExp(`(${query})`, 'gi')
|
||||
const parts = text.split(regex)
|
||||
return <>{parts.map((part, i) => ...)}</>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (memoize or hoist):**
|
||||
|
||||
```tsx
|
||||
const EMAIL_REGEX = /^[^\s@]+@[^\s@]+\.[^\s@]+$/
|
||||
|
||||
function Highlighter({ text, query }: Props) {
|
||||
const regex = useMemo(
|
||||
() => new RegExp(`(${escapeRegex(query)})`, 'gi'),
|
||||
[query]
|
||||
)
|
||||
const parts = text.split(regex)
|
||||
return <>{parts.map((part, i) => ...)}</>
|
||||
}
|
||||
```
|
||||
|
||||
**Warning (global regex has mutable state):**
|
||||
|
||||
Global regex (`/g`) has mutable `lastIndex` state:
|
||||
|
||||
```typescript
|
||||
const regex = /foo/g
|
||||
regex.test('foo') // true, lastIndex = 3
|
||||
regex.test('foo') // false, lastIndex = 0
|
||||
```
|
||||
@@ -0,0 +1,37 @@
|
||||
---
|
||||
title: Build Index Maps for Repeated Lookups
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: 1M ops to 2K ops
|
||||
tags: javascript, map, indexing, optimization, performance
|
||||
---
|
||||
|
||||
## Build Index Maps for Repeated Lookups
|
||||
|
||||
Multiple `.find()` calls by the same key should use a Map.
|
||||
|
||||
**Incorrect (O(n) per lookup):**
|
||||
|
||||
```typescript
|
||||
function processOrders(orders: Order[], users: User[]) {
|
||||
return orders.map(order => ({
|
||||
...order,
|
||||
user: users.find(u => u.id === order.userId)
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (O(1) per lookup):**
|
||||
|
||||
```typescript
|
||||
function processOrders(orders: Order[], users: User[]) {
|
||||
const userById = new Map(users.map(u => [u.id, u]))
|
||||
|
||||
return orders.map(order => ({
|
||||
...order,
|
||||
user: userById.get(order.userId)
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
Build map once (O(n)), then all lookups are O(1).
|
||||
For 1000 orders × 1000 users: 1M ops → 2K ops.
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
title: Early Length Check for Array Comparisons
|
||||
impact: MEDIUM-HIGH
|
||||
impactDescription: avoids expensive operations when lengths differ
|
||||
tags: javascript, arrays, performance, optimization, comparison
|
||||
---
|
||||
|
||||
## Early Length Check for Array Comparisons
|
||||
|
||||
When comparing arrays with expensive operations (sorting, deep equality, serialization), check lengths first. If lengths differ, the arrays cannot be equal.
|
||||
|
||||
In real-world applications, this optimization is especially valuable when the comparison runs in hot paths (event handlers, render loops).
|
||||
|
||||
**Incorrect (always runs expensive comparison):**
|
||||
|
||||
```typescript
|
||||
function hasChanges(current: string[], original: string[]) {
|
||||
// Always sorts and joins, even when lengths differ
|
||||
return current.sort().join() !== original.sort().join()
|
||||
}
|
||||
```
|
||||
|
||||
Two O(n log n) sorts run even when `current.length` is 5 and `original.length` is 100. There is also overhead of joining the arrays and comparing the strings.
|
||||
|
||||
**Correct (O(1) length check first):**
|
||||
|
||||
```typescript
|
||||
function hasChanges(current: string[], original: string[]) {
|
||||
// Early return if lengths differ
|
||||
if (current.length !== original.length) {
|
||||
return true
|
||||
}
|
||||
// Only sort/join when lengths match
|
||||
const currentSorted = current.toSorted()
|
||||
const originalSorted = original.toSorted()
|
||||
for (let i = 0; i < currentSorted.length; i++) {
|
||||
if (currentSorted[i] !== originalSorted[i]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
```
|
||||
|
||||
This new approach is more efficient because:
|
||||
- It avoids the overhead of sorting and joining the arrays when lengths differ
|
||||
- It avoids consuming memory for the joined strings (especially important for large arrays)
|
||||
- It avoids mutating the original arrays
|
||||
- It returns early when a difference is found
|
||||
@@ -0,0 +1,82 @@
|
||||
---
|
||||
title: Use Loop for Min/Max Instead of Sort
|
||||
impact: LOW
|
||||
impactDescription: O(n) instead of O(n log n)
|
||||
tags: javascript, arrays, performance, sorting, algorithms
|
||||
---
|
||||
|
||||
## Use Loop for Min/Max Instead of Sort
|
||||
|
||||
Finding the smallest or largest element only requires a single pass through the array. Sorting is wasteful and slower.
|
||||
|
||||
**Incorrect (O(n log n) - sort to find latest):**
|
||||
|
||||
```typescript
|
||||
interface Project {
|
||||
id: string
|
||||
name: string
|
||||
updatedAt: number
|
||||
}
|
||||
|
||||
function getLatestProject(projects: Project[]) {
|
||||
const sorted = [...projects].sort((a, b) => b.updatedAt - a.updatedAt)
|
||||
return sorted[0]
|
||||
}
|
||||
```
|
||||
|
||||
Sorts the entire array just to find the maximum value.
|
||||
|
||||
**Incorrect (O(n log n) - sort for oldest and newest):**
|
||||
|
||||
```typescript
|
||||
function getOldestAndNewest(projects: Project[]) {
|
||||
const sorted = [...projects].sort((a, b) => a.updatedAt - b.updatedAt)
|
||||
return { oldest: sorted[0], newest: sorted[sorted.length - 1] }
|
||||
}
|
||||
```
|
||||
|
||||
Still sorts unnecessarily when only min/max are needed.
|
||||
|
||||
**Correct (O(n) - single loop):**
|
||||
|
||||
```typescript
|
||||
function getLatestProject(projects: Project[]) {
|
||||
if (projects.length === 0) return null
|
||||
|
||||
let latest = projects[0]
|
||||
|
||||
for (let i = 1; i < projects.length; i++) {
|
||||
if (projects[i].updatedAt > latest.updatedAt) {
|
||||
latest = projects[i]
|
||||
}
|
||||
}
|
||||
|
||||
return latest
|
||||
}
|
||||
|
||||
function getOldestAndNewest(projects: Project[]) {
|
||||
if (projects.length === 0) return { oldest: null, newest: null }
|
||||
|
||||
let oldest = projects[0]
|
||||
let newest = projects[0]
|
||||
|
||||
for (let i = 1; i < projects.length; i++) {
|
||||
if (projects[i].updatedAt < oldest.updatedAt) oldest = projects[i]
|
||||
if (projects[i].updatedAt > newest.updatedAt) newest = projects[i]
|
||||
}
|
||||
|
||||
return { oldest, newest }
|
||||
}
|
||||
```
|
||||
|
||||
Single pass through the array, no copying, no sorting.
|
||||
|
||||
**Alternative (Math.min/Math.max for small arrays):**
|
||||
|
||||
```typescript
|
||||
const numbers = [5, 2, 8, 1, 9]
|
||||
const min = Math.min(...numbers)
|
||||
const max = Math.max(...numbers)
|
||||
```
|
||||
|
||||
This works for small arrays but can be slower for very large arrays due to spread operator limitations. Use the loop approach for reliability.
|
||||
@@ -0,0 +1,24 @@
|
||||
---
|
||||
title: Use Set/Map for O(1) Lookups
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: O(n) to O(1)
|
||||
tags: javascript, set, map, data-structures, performance
|
||||
---
|
||||
|
||||
## Use Set/Map for O(1) Lookups
|
||||
|
||||
Convert arrays to Set/Map for repeated membership checks.
|
||||
|
||||
**Incorrect (O(n) per check):**
|
||||
|
||||
```typescript
|
||||
const allowedIds = ['a', 'b', 'c', ...]
|
||||
items.filter(item => allowedIds.includes(item.id))
|
||||
```
|
||||
|
||||
**Correct (O(1) per check):**
|
||||
|
||||
```typescript
|
||||
const allowedIds = new Set(['a', 'b', 'c', ...])
|
||||
items.filter(item => allowedIds.has(item.id))
|
||||
```
|
||||
@@ -0,0 +1,57 @@
|
||||
---
|
||||
title: Use toSorted() Instead of sort() for Immutability
|
||||
impact: MEDIUM-HIGH
|
||||
impactDescription: prevents mutation bugs in React state
|
||||
tags: javascript, arrays, immutability, react, state, mutation
|
||||
---
|
||||
|
||||
## Use toSorted() Instead of sort() for Immutability
|
||||
|
||||
`.sort()` mutates the array in place, which can cause bugs with React state and props. Use `.toSorted()` to create a new sorted array without mutation.
|
||||
|
||||
**Incorrect (mutates original array):**
|
||||
|
||||
```typescript
|
||||
function UserList({ users }: { users: User[] }) {
|
||||
// Mutates the users prop array!
|
||||
const sorted = useMemo(
|
||||
() => users.sort((a, b) => a.name.localeCompare(b.name)),
|
||||
[users]
|
||||
)
|
||||
return <div>{sorted.map(renderUser)}</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (creates new array):**
|
||||
|
||||
```typescript
|
||||
function UserList({ users }: { users: User[] }) {
|
||||
// Creates new sorted array, original unchanged
|
||||
const sorted = useMemo(
|
||||
() => users.toSorted((a, b) => a.name.localeCompare(b.name)),
|
||||
[users]
|
||||
)
|
||||
return <div>{sorted.map(renderUser)}</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Why this matters in React:**
|
||||
|
||||
1. Props/state mutations break React's immutability model - React expects props and state to be treated as read-only
|
||||
2. Causes stale closure bugs - Mutating arrays inside closures (callbacks, effects) can lead to unexpected behavior
|
||||
|
||||
**Browser support (fallback for older browsers):**
|
||||
|
||||
`.toSorted()` is available in all modern browsers (Chrome 110+, Safari 16+, Firefox 115+, Node.js 20+). For older environments, use spread operator:
|
||||
|
||||
```typescript
|
||||
// Fallback for older browsers
|
||||
const sorted = [...items].sort((a, b) => a.value - b.value)
|
||||
```
|
||||
|
||||
**Other immutable array methods:**
|
||||
|
||||
- `.toSorted()` - immutable sort
|
||||
- `.toReversed()` - immutable reverse
|
||||
- `.toSpliced()` - immutable splice
|
||||
- `.with()` - immutable element replacement
|
||||
@@ -0,0 +1,26 @@
|
||||
---
|
||||
title: Use Activity Component for Show/Hide
|
||||
impact: MEDIUM
|
||||
impactDescription: preserves state/DOM
|
||||
tags: rendering, activity, visibility, state-preservation
|
||||
---
|
||||
|
||||
## Use Activity Component for Show/Hide
|
||||
|
||||
Use React's `<Activity>` to preserve state/DOM for expensive components that frequently toggle visibility.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```tsx
|
||||
import { Activity } from 'react'
|
||||
|
||||
function Dropdown({ isOpen }: Props) {
|
||||
return (
|
||||
<Activity mode={isOpen ? 'visible' : 'hidden'}>
|
||||
<ExpensiveMenu />
|
||||
</Activity>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Avoids expensive re-renders and state loss.
|
||||
@@ -0,0 +1,47 @@
|
||||
---
|
||||
title: Animate SVG Wrapper Instead of SVG Element
|
||||
impact: LOW
|
||||
impactDescription: enables hardware acceleration
|
||||
tags: rendering, svg, css, animation, performance
|
||||
---
|
||||
|
||||
## Animate SVG Wrapper Instead of SVG Element
|
||||
|
||||
Many browsers don't have hardware acceleration for CSS3 animations on SVG elements. Wrap SVG in a `<div>` and animate the wrapper instead.
|
||||
|
||||
**Incorrect (animating SVG directly - no hardware acceleration):**
|
||||
|
||||
```tsx
|
||||
function LoadingSpinner() {
|
||||
return (
|
||||
<svg
|
||||
className="animate-spin"
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<circle cx="12" cy="12" r="10" stroke="currentColor" />
|
||||
</svg>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (animating wrapper div - hardware accelerated):**
|
||||
|
||||
```tsx
|
||||
function LoadingSpinner() {
|
||||
return (
|
||||
<div className="animate-spin">
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<circle cx="12" cy="12" r="10" stroke="currentColor" />
|
||||
</svg>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
This applies to all CSS transforms and transitions (`transform`, `opacity`, `translate`, `scale`, `rotate`). The wrapper div allows browsers to use GPU acceleration for smoother animations.
|
||||
@@ -0,0 +1,40 @@
|
||||
---
|
||||
title: Use Explicit Conditional Rendering
|
||||
impact: LOW
|
||||
impactDescription: prevents rendering 0 or NaN
|
||||
tags: rendering, conditional, jsx, falsy-values
|
||||
---
|
||||
|
||||
## Use Explicit Conditional Rendering
|
||||
|
||||
Use explicit ternary operators (`? :`) instead of `&&` for conditional rendering when the condition can be `0`, `NaN`, or other falsy values that render.
|
||||
|
||||
**Incorrect (renders "0" when count is 0):**
|
||||
|
||||
```tsx
|
||||
function Badge({ count }: { count: number }) {
|
||||
return (
|
||||
<div>
|
||||
{count && <span className="badge">{count}</span>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// When count = 0, renders: <div>0</div>
|
||||
// When count = 5, renders: <div><span class="badge">5</span></div>
|
||||
```
|
||||
|
||||
**Correct (renders nothing when count is 0):**
|
||||
|
||||
```tsx
|
||||
function Badge({ count }: { count: number }) {
|
||||
return (
|
||||
<div>
|
||||
{count > 0 ? <span className="badge">{count}</span> : null}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// When count = 0, renders: <div></div>
|
||||
// When count = 5, renders: <div><span class="badge">5</span></div>
|
||||
```
|
||||
@@ -0,0 +1,38 @@
|
||||
---
|
||||
title: CSS content-visibility for Long Lists
|
||||
impact: HIGH
|
||||
impactDescription: faster initial render
|
||||
tags: rendering, css, content-visibility, long-lists
|
||||
---
|
||||
|
||||
## CSS content-visibility for Long Lists
|
||||
|
||||
Apply `content-visibility: auto` to defer off-screen rendering.
|
||||
|
||||
**CSS:**
|
||||
|
||||
```css
|
||||
.message-item {
|
||||
content-visibility: auto;
|
||||
contain-intrinsic-size: 0 80px;
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```tsx
|
||||
function MessageList({ messages }: { messages: Message[] }) {
|
||||
return (
|
||||
<div className="overflow-y-auto h-screen">
|
||||
{messages.map(msg => (
|
||||
<div key={msg.id} className="message-item">
|
||||
<Avatar user={msg.author} />
|
||||
<div>{msg.content}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
For 1000 messages, browser skips layout/paint for ~990 off-screen items (10× faster initial render).
|
||||
@@ -0,0 +1,46 @@
|
||||
---
|
||||
title: Hoist Static JSX Elements
|
||||
impact: LOW
|
||||
impactDescription: avoids re-creation
|
||||
tags: rendering, jsx, static, optimization
|
||||
---
|
||||
|
||||
## Hoist Static JSX Elements
|
||||
|
||||
Extract static JSX outside components to avoid re-creation.
|
||||
|
||||
**Incorrect (recreates element every render):**
|
||||
|
||||
```tsx
|
||||
function LoadingSkeleton() {
|
||||
return <div className="animate-pulse h-20 bg-gray-200" />
|
||||
}
|
||||
|
||||
function Container() {
|
||||
return (
|
||||
<div>
|
||||
{loading && <LoadingSkeleton />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (reuses same element):**
|
||||
|
||||
```tsx
|
||||
const loadingSkeleton = (
|
||||
<div className="animate-pulse h-20 bg-gray-200" />
|
||||
)
|
||||
|
||||
function Container() {
|
||||
return (
|
||||
<div>
|
||||
{loading && loadingSkeleton}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
This is especially helpful for large and static SVG nodes, which can be expensive to recreate on every render.
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler automatically hoists static JSX elements and optimizes component re-renders, making manual hoisting unnecessary.
|
||||
@@ -0,0 +1,82 @@
|
||||
---
|
||||
title: Prevent Hydration Mismatch Without Flickering
|
||||
impact: MEDIUM
|
||||
impactDescription: avoids visual flicker and hydration errors
|
||||
tags: rendering, ssr, hydration, localStorage, flicker
|
||||
---
|
||||
|
||||
## Prevent Hydration Mismatch Without Flickering
|
||||
|
||||
When rendering content that depends on client-side storage (localStorage, cookies), avoid both SSR breakage and post-hydration flickering by injecting a synchronous script that updates the DOM before React hydrates.
|
||||
|
||||
**Incorrect (breaks SSR):**
|
||||
|
||||
```tsx
|
||||
function ThemeWrapper({ children }: { children: ReactNode }) {
|
||||
// localStorage is not available on server - throws error
|
||||
const theme = localStorage.getItem('theme') || 'light'
|
||||
|
||||
return (
|
||||
<div className={theme}>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Server-side rendering will fail because `localStorage` is undefined.
|
||||
|
||||
**Incorrect (visual flickering):**
|
||||
|
||||
```tsx
|
||||
function ThemeWrapper({ children }: { children: ReactNode }) {
|
||||
const [theme, setTheme] = useState('light')
|
||||
|
||||
useEffect(() => {
|
||||
// Runs after hydration - causes visible flash
|
||||
const stored = localStorage.getItem('theme')
|
||||
if (stored) {
|
||||
setTheme(stored)
|
||||
}
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div className={theme}>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Component first renders with default value (`light`), then updates after hydration, causing a visible flash of incorrect content.
|
||||
|
||||
**Correct (no flicker, no hydration mismatch):**
|
||||
|
||||
```tsx
|
||||
function ThemeWrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<>
|
||||
<div id="theme-wrapper">
|
||||
{children}
|
||||
</div>
|
||||
<script
|
||||
dangerouslySetInnerHTML={{
|
||||
__html: `
|
||||
(function() {
|
||||
try {
|
||||
var theme = localStorage.getItem('theme') || 'light';
|
||||
var el = document.getElementById('theme-wrapper');
|
||||
if (el) el.className = theme;
|
||||
} catch (e) {}
|
||||
})();
|
||||
`,
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
The inline script executes synchronously before showing the element, ensuring the DOM already has the correct value. No flickering, no hydration mismatch.
|
||||
|
||||
This pattern is especially useful for theme toggles, user preferences, authentication states, and any client-only data that should render immediately without flashing default values.
|
||||
@@ -0,0 +1,28 @@
|
||||
---
|
||||
title: Optimize SVG Precision
|
||||
impact: LOW
|
||||
impactDescription: reduces file size
|
||||
tags: rendering, svg, optimization, svgo
|
||||
---
|
||||
|
||||
## Optimize SVG Precision
|
||||
|
||||
Reduce SVG coordinate precision to decrease file size. The optimal precision depends on the viewBox size, but in general reducing precision should be considered.
|
||||
|
||||
**Incorrect (excessive precision):**
|
||||
|
||||
```svg
|
||||
<path d="M 10.293847 20.847362 L 30.938472 40.192837" />
|
||||
```
|
||||
|
||||
**Correct (1 decimal place):**
|
||||
|
||||
```svg
|
||||
<path d="M 10.3 20.8 L 30.9 40.2" />
|
||||
```
|
||||
|
||||
**Automate with SVGO:**
|
||||
|
||||
```bash
|
||||
npx svgo --precision=1 --multipass icon.svg
|
||||
```
|
||||
@@ -0,0 +1,39 @@
|
||||
---
|
||||
title: Defer State Reads to Usage Point
|
||||
impact: MEDIUM
|
||||
impactDescription: avoids unnecessary subscriptions
|
||||
tags: rerender, searchParams, localStorage, optimization
|
||||
---
|
||||
|
||||
## Defer State Reads to Usage Point
|
||||
|
||||
Don't subscribe to dynamic state (searchParams, localStorage) if you only read it inside callbacks.
|
||||
|
||||
**Incorrect (subscribes to all searchParams changes):**
|
||||
|
||||
```tsx
|
||||
function ShareButton({ chatId }: { chatId: string }) {
|
||||
const searchParams = useSearchParams()
|
||||
|
||||
const handleShare = () => {
|
||||
const ref = searchParams.get('ref')
|
||||
shareChat(chatId, { ref })
|
||||
}
|
||||
|
||||
return <button onClick={handleShare}>Share</button>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (reads on demand, no subscription):**
|
||||
|
||||
```tsx
|
||||
function ShareButton({ chatId }: { chatId: string }) {
|
||||
const handleShare = () => {
|
||||
const params = new URLSearchParams(window.location.search)
|
||||
const ref = params.get('ref')
|
||||
shareChat(chatId, { ref })
|
||||
}
|
||||
|
||||
return <button onClick={handleShare}>Share</button>
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,45 @@
|
||||
---
|
||||
title: Narrow Effect Dependencies
|
||||
impact: LOW
|
||||
impactDescription: minimizes effect re-runs
|
||||
tags: rerender, useEffect, dependencies, optimization
|
||||
---
|
||||
|
||||
## Narrow Effect Dependencies
|
||||
|
||||
Specify primitive dependencies instead of objects to minimize effect re-runs.
|
||||
|
||||
**Incorrect (re-runs on any user field change):**
|
||||
|
||||
```tsx
|
||||
useEffect(() => {
|
||||
console.log(user.id)
|
||||
}, [user])
|
||||
```
|
||||
|
||||
**Correct (re-runs only when id changes):**
|
||||
|
||||
```tsx
|
||||
useEffect(() => {
|
||||
console.log(user.id)
|
||||
}, [user.id])
|
||||
```
|
||||
|
||||
**For derived state, compute outside effect:**
|
||||
|
||||
```tsx
|
||||
// Incorrect: runs on width=767, 766, 765...
|
||||
useEffect(() => {
|
||||
if (width < 768) {
|
||||
enableMobileMode()
|
||||
}
|
||||
}, [width])
|
||||
|
||||
// Correct: runs only on boolean transition
|
||||
const isMobile = width < 768
|
||||
useEffect(() => {
|
||||
if (isMobile) {
|
||||
enableMobileMode()
|
||||
}
|
||||
}, [isMobile])
|
||||
```
|
||||
@@ -0,0 +1,29 @@
|
||||
---
|
||||
title: Subscribe to Derived State
|
||||
impact: MEDIUM
|
||||
impactDescription: reduces re-render frequency
|
||||
tags: rerender, derived-state, media-query, optimization
|
||||
---
|
||||
|
||||
## Subscribe to Derived State
|
||||
|
||||
Subscribe to derived boolean state instead of continuous values to reduce re-render frequency.
|
||||
|
||||
**Incorrect (re-renders on every pixel change):**
|
||||
|
||||
```tsx
|
||||
function Sidebar() {
|
||||
const width = useWindowWidth() // updates continuously
|
||||
const isMobile = width < 768
|
||||
return <nav className={isMobile ? 'mobile' : 'desktop'}>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (re-renders only when boolean changes):**
|
||||
|
||||
```tsx
|
||||
function Sidebar() {
|
||||
const isMobile = useMediaQuery('(max-width: 767px)')
|
||||
return <nav className={isMobile ? 'mobile' : 'desktop'}>
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,74 @@
|
||||
---
|
||||
title: Use Functional setState Updates
|
||||
impact: MEDIUM
|
||||
impactDescription: prevents stale closures and unnecessary callback recreations
|
||||
tags: react, hooks, useState, useCallback, callbacks, closures
|
||||
---
|
||||
|
||||
## Use Functional setState Updates
|
||||
|
||||
When updating state based on the current state value, use the functional update form of setState instead of directly referencing the state variable. This prevents stale closures, eliminates unnecessary dependencies, and creates stable callback references.
|
||||
|
||||
**Incorrect (requires state as dependency):**
|
||||
|
||||
```tsx
|
||||
function TodoList() {
|
||||
const [items, setItems] = useState(initialItems)
|
||||
|
||||
// Callback must depend on items, recreated on every items change
|
||||
const addItems = useCallback((newItems: Item[]) => {
|
||||
setItems([...items, ...newItems])
|
||||
}, [items]) // ❌ items dependency causes recreations
|
||||
|
||||
// Risk of stale closure if dependency is forgotten
|
||||
const removeItem = useCallback((id: string) => {
|
||||
setItems(items.filter(item => item.id !== id))
|
||||
}, []) // ❌ Missing items dependency - will use stale items!
|
||||
|
||||
return <ItemsEditor items={items} onAdd={addItems} onRemove={removeItem} />
|
||||
}
|
||||
```
|
||||
|
||||
The first callback is recreated every time `items` changes, which can cause child components to re-render unnecessarily. The second callback has a stale closure bug—it will always reference the initial `items` value.
|
||||
|
||||
**Correct (stable callbacks, no stale closures):**
|
||||
|
||||
```tsx
|
||||
function TodoList() {
|
||||
const [items, setItems] = useState(initialItems)
|
||||
|
||||
// Stable callback, never recreated
|
||||
const addItems = useCallback((newItems: Item[]) => {
|
||||
setItems(curr => [...curr, ...newItems])
|
||||
}, []) // ✅ No dependencies needed
|
||||
|
||||
// Always uses latest state, no stale closure risk
|
||||
const removeItem = useCallback((id: string) => {
|
||||
setItems(curr => curr.filter(item => item.id !== id))
|
||||
}, []) // ✅ Safe and stable
|
||||
|
||||
return <ItemsEditor items={items} onAdd={addItems} onRemove={removeItem} />
|
||||
}
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
|
||||
1. **Stable callback references** - Callbacks don't need to be recreated when state changes
|
||||
2. **No stale closures** - Always operates on the latest state value
|
||||
3. **Fewer dependencies** - Simplifies dependency arrays and reduces memory leaks
|
||||
4. **Prevents bugs** - Eliminates the most common source of React closure bugs
|
||||
|
||||
**When to use functional updates:**
|
||||
|
||||
- Any setState that depends on the current state value
|
||||
- Inside useCallback/useMemo when state is needed
|
||||
- Event handlers that reference state
|
||||
- Async operations that update state
|
||||
|
||||
**When direct updates are fine:**
|
||||
|
||||
- Setting state to a static value: `setCount(0)`
|
||||
- Setting state from props/arguments only: `setName(newName)`
|
||||
- State doesn't depend on previous value
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler can automatically optimize some cases, but functional updates are still recommended for correctness and to prevent stale closure bugs.
|
||||
@@ -0,0 +1,58 @@
|
||||
---
|
||||
title: Use Lazy State Initialization
|
||||
impact: MEDIUM
|
||||
impactDescription: wasted computation on every render
|
||||
tags: react, hooks, useState, performance, initialization
|
||||
---
|
||||
|
||||
## Use Lazy State Initialization
|
||||
|
||||
Pass a function to `useState` for expensive initial values. Without the function form, the initializer runs on every render even though the value is only used once.
|
||||
|
||||
**Incorrect (runs on every render):**
|
||||
|
||||
```tsx
|
||||
function FilteredList({ items }: { items: Item[] }) {
|
||||
// buildSearchIndex() runs on EVERY render, even after initialization
|
||||
const [searchIndex, setSearchIndex] = useState(buildSearchIndex(items))
|
||||
const [query, setQuery] = useState('')
|
||||
|
||||
// When query changes, buildSearchIndex runs again unnecessarily
|
||||
return <SearchResults index={searchIndex} query={query} />
|
||||
}
|
||||
|
||||
function UserProfile() {
|
||||
// JSON.parse runs on every render
|
||||
const [settings, setSettings] = useState(
|
||||
JSON.parse(localStorage.getItem('settings') || '{}')
|
||||
)
|
||||
|
||||
return <SettingsForm settings={settings} onChange={setSettings} />
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (runs only once):**
|
||||
|
||||
```tsx
|
||||
function FilteredList({ items }: { items: Item[] }) {
|
||||
// buildSearchIndex() runs ONLY on initial render
|
||||
const [searchIndex, setSearchIndex] = useState(() => buildSearchIndex(items))
|
||||
const [query, setQuery] = useState('')
|
||||
|
||||
return <SearchResults index={searchIndex} query={query} />
|
||||
}
|
||||
|
||||
function UserProfile() {
|
||||
// JSON.parse runs only on initial render
|
||||
const [settings, setSettings] = useState(() => {
|
||||
const stored = localStorage.getItem('settings')
|
||||
return stored ? JSON.parse(stored) : {}
|
||||
})
|
||||
|
||||
return <SettingsForm settings={settings} onChange={setSettings} />
|
||||
}
|
||||
```
|
||||
|
||||
Use lazy initialization when computing initial values from localStorage/sessionStorage, building data structures (indexes, maps), reading from the DOM, or performing heavy transformations.
|
||||
|
||||
For simple primitives (`useState(0)`), direct references (`useState(props.value)`), or cheap literals (`useState({})`), the function form is unnecessary.
|
||||
@@ -0,0 +1,44 @@
|
||||
---
|
||||
title: Extract to Memoized Components
|
||||
impact: MEDIUM
|
||||
impactDescription: enables early returns
|
||||
tags: rerender, memo, useMemo, optimization
|
||||
---
|
||||
|
||||
## Extract to Memoized Components
|
||||
|
||||
Extract expensive work into memoized components to enable early returns before computation.
|
||||
|
||||
**Incorrect (computes avatar even when loading):**
|
||||
|
||||
```tsx
|
||||
function Profile({ user, loading }: Props) {
|
||||
const avatar = useMemo(() => {
|
||||
const id = computeAvatarId(user)
|
||||
return <Avatar id={id} />
|
||||
}, [user])
|
||||
|
||||
if (loading) return <Skeleton />
|
||||
return <div>{avatar}</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (skips computation when loading):**
|
||||
|
||||
```tsx
|
||||
const UserAvatar = memo(function UserAvatar({ user }: { user: User }) {
|
||||
const id = useMemo(() => computeAvatarId(user), [user])
|
||||
return <Avatar id={id} />
|
||||
})
|
||||
|
||||
function Profile({ user, loading }: Props) {
|
||||
if (loading) return <Skeleton />
|
||||
return (
|
||||
<div>
|
||||
<UserAvatar user={user} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, manual memoization with `memo()` and `useMemo()` is not necessary. The compiler automatically optimizes re-renders.
|
||||
@@ -0,0 +1,40 @@
|
||||
---
|
||||
title: Use Transitions for Non-Urgent Updates
|
||||
impact: MEDIUM
|
||||
impactDescription: maintains UI responsiveness
|
||||
tags: rerender, transitions, startTransition, performance
|
||||
---
|
||||
|
||||
## Use Transitions for Non-Urgent Updates
|
||||
|
||||
Mark frequent, non-urgent state updates as transitions to maintain UI responsiveness.
|
||||
|
||||
**Incorrect (blocks UI on every scroll):**
|
||||
|
||||
```tsx
|
||||
function ScrollTracker() {
|
||||
const [scrollY, setScrollY] = useState(0)
|
||||
useEffect(() => {
|
||||
const handler = () => setScrollY(window.scrollY)
|
||||
window.addEventListener('scroll', handler, { passive: true })
|
||||
return () => window.removeEventListener('scroll', handler)
|
||||
}, [])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (non-blocking updates):**
|
||||
|
||||
```tsx
|
||||
import { startTransition } from 'react'
|
||||
|
||||
function ScrollTracker() {
|
||||
const [scrollY, setScrollY] = useState(0)
|
||||
useEffect(() => {
|
||||
const handler = () => {
|
||||
startTransition(() => setScrollY(window.scrollY))
|
||||
}
|
||||
window.addEventListener('scroll', handler, { passive: true })
|
||||
return () => window.removeEventListener('scroll', handler)
|
||||
}, [])
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,73 @@
|
||||
---
|
||||
title: Use after() for Non-Blocking Operations
|
||||
impact: MEDIUM
|
||||
impactDescription: faster response times
|
||||
tags: server, async, logging, analytics, side-effects
|
||||
---
|
||||
|
||||
## Use after() for Non-Blocking Operations
|
||||
|
||||
Use Next.js's `after()` to schedule work that should execute after a response is sent. This prevents logging, analytics, and other side effects from blocking the response.
|
||||
|
||||
**Incorrect (blocks response):**
|
||||
|
||||
```tsx
|
||||
import { logUserAction } from '@/app/utils'
|
||||
|
||||
export async function POST(request: Request) {
|
||||
// Perform mutation
|
||||
await updateDatabase(request)
|
||||
|
||||
// Logging blocks the response
|
||||
const userAgent = request.headers.get('user-agent') || 'unknown'
|
||||
await logUserAction({ userAgent })
|
||||
|
||||
return new Response(JSON.stringify({ status: 'success' }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (non-blocking):**
|
||||
|
||||
```tsx
|
||||
import { after } from 'next/server'
|
||||
import { headers, cookies } from 'next/headers'
|
||||
import { logUserAction } from '@/app/utils'
|
||||
|
||||
export async function POST(request: Request) {
|
||||
// Perform mutation
|
||||
await updateDatabase(request)
|
||||
|
||||
// Log after response is sent
|
||||
after(async () => {
|
||||
const userAgent = (await headers()).get('user-agent') || 'unknown'
|
||||
const sessionCookie = (await cookies()).get('session-id')?.value || 'anonymous'
|
||||
|
||||
logUserAction({ sessionCookie, userAgent })
|
||||
})
|
||||
|
||||
return new Response(JSON.stringify({ status: 'success' }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
The response is sent immediately while logging happens in the background.
|
||||
|
||||
**Common use cases:**
|
||||
|
||||
- Analytics tracking
|
||||
- Audit logging
|
||||
- Sending notifications
|
||||
- Cache invalidation
|
||||
- Cleanup tasks
|
||||
|
||||
**Important notes:**
|
||||
|
||||
- `after()` runs even if the response fails or redirects
|
||||
- Works in Server Actions, Route Handlers, and Server Components
|
||||
|
||||
Reference: [https://nextjs.org/docs/app/api-reference/functions/after](https://nextjs.org/docs/app/api-reference/functions/after)
|
||||
@@ -0,0 +1,41 @@
|
||||
---
|
||||
title: Cross-Request LRU Caching
|
||||
impact: HIGH
|
||||
impactDescription: caches across requests
|
||||
tags: server, cache, lru, cross-request
|
||||
---
|
||||
|
||||
## Cross-Request LRU Caching
|
||||
|
||||
`React.cache()` only works within one request. For data shared across sequential requests (user clicks button A then button B), use an LRU cache.
|
||||
|
||||
**Implementation:**
|
||||
|
||||
```typescript
|
||||
import { LRUCache } from 'lru-cache'
|
||||
|
||||
const cache = new LRUCache<string, any>({
|
||||
max: 1000,
|
||||
ttl: 5 * 60 * 1000 // 5 minutes
|
||||
})
|
||||
|
||||
export async function getUser(id: string) {
|
||||
const cached = cache.get(id)
|
||||
if (cached) return cached
|
||||
|
||||
const user = await db.user.findUnique({ where: { id } })
|
||||
cache.set(id, user)
|
||||
return user
|
||||
}
|
||||
|
||||
// Request 1: DB query, result cached
|
||||
// Request 2: cache hit, no DB query
|
||||
```
|
||||
|
||||
Use when sequential user actions hit multiple endpoints needing the same data within seconds.
|
||||
|
||||
**With Vercel's [Fluid Compute](https://vercel.com/docs/fluid-compute):** LRU caching is especially effective because multiple concurrent requests can share the same function instance and cache. This means the cache persists across requests without needing external storage like Redis.
|
||||
|
||||
**In traditional serverless:** Each invocation runs in isolation, so consider Redis for cross-process caching.
|
||||
|
||||
Reference: [https://github.com/isaacs/node-lru-cache](https://github.com/isaacs/node-lru-cache)
|
||||
@@ -0,0 +1,26 @@
|
||||
---
|
||||
title: Per-Request Deduplication with React.cache()
|
||||
impact: MEDIUM
|
||||
impactDescription: deduplicates within request
|
||||
tags: server, cache, react-cache, deduplication
|
||||
---
|
||||
|
||||
## Per-Request Deduplication with React.cache()
|
||||
|
||||
Use `React.cache()` for server-side request deduplication. Authentication and database queries benefit most.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```typescript
|
||||
import { cache } from 'react'
|
||||
|
||||
export const getCurrentUser = cache(async () => {
|
||||
const session = await auth()
|
||||
if (!session?.user?.id) return null
|
||||
return await db.user.findUnique({
|
||||
where: { id: session.user.id }
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
Within a single request, multiple calls to `getCurrentUser()` execute the query only once.
|
||||
@@ -0,0 +1,79 @@
|
||||
---
|
||||
title: Parallel Data Fetching with Component Composition
|
||||
impact: CRITICAL
|
||||
impactDescription: eliminates server-side waterfalls
|
||||
tags: server, rsc, parallel-fetching, composition
|
||||
---
|
||||
|
||||
## Parallel Data Fetching with Component Composition
|
||||
|
||||
React Server Components execute sequentially within a tree. Restructure with composition to parallelize data fetching.
|
||||
|
||||
**Incorrect (Sidebar waits for Page's fetch to complete):**
|
||||
|
||||
```tsx
|
||||
export default async function Page() {
|
||||
const header = await fetchHeader()
|
||||
return (
|
||||
<div>
|
||||
<div>{header}</div>
|
||||
<Sidebar />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
async function Sidebar() {
|
||||
const items = await fetchSidebarItems()
|
||||
return <nav>{items.map(renderItem)}</nav>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (both fetch simultaneously):**
|
||||
|
||||
```tsx
|
||||
async function Header() {
|
||||
const data = await fetchHeader()
|
||||
return <div>{data}</div>
|
||||
}
|
||||
|
||||
async function Sidebar() {
|
||||
const items = await fetchSidebarItems()
|
||||
return <nav>{items.map(renderItem)}</nav>
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<div>
|
||||
<Header />
|
||||
<Sidebar />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Alternative with children prop:**
|
||||
|
||||
```tsx
|
||||
async function Layout({ children }: { children: ReactNode }) {
|
||||
const header = await fetchHeader()
|
||||
return (
|
||||
<div>
|
||||
<div>{header}</div>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
async function Sidebar() {
|
||||
const items = await fetchSidebarItems()
|
||||
return <nav>{items.map(renderItem)}</nav>
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<Layout>
|
||||
<Sidebar />
|
||||
</Layout>
|
||||
)
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,38 @@
|
||||
---
|
||||
title: Minimize Serialization at RSC Boundaries
|
||||
impact: HIGH
|
||||
impactDescription: reduces data transfer size
|
||||
tags: server, rsc, serialization, props
|
||||
---
|
||||
|
||||
## Minimize Serialization at RSC Boundaries
|
||||
|
||||
The React Server/Client boundary serializes all object properties into strings and embeds them in the HTML response and subsequent RSC requests. This serialized data directly impacts page weight and load time, so **size matters a lot**. Only pass fields that the client actually uses.
|
||||
|
||||
**Incorrect (serializes all 50 fields):**
|
||||
|
||||
```tsx
|
||||
async function Page() {
|
||||
const user = await fetchUser() // 50 fields
|
||||
return <Profile user={user} />
|
||||
}
|
||||
|
||||
'use client'
|
||||
function Profile({ user }: { user: User }) {
|
||||
return <div>{user.name}</div> // uses 1 field
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (serializes only 1 field):**
|
||||
|
||||
```tsx
|
||||
async function Page() {
|
||||
const user = await fetchUser()
|
||||
return <Profile name={user.name} />
|
||||
}
|
||||
|
||||
'use client'
|
||||
function Profile({ name }: { name: string }) {
|
||||
return <div>{name}</div>
|
||||
}
|
||||
```
|
||||
@@ -93,5 +93,5 @@ jobs:
|
||||
|
||||
Error logs:
|
||||
${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"
|
||||
|
||||
4
.github/workflows/claude-dependabot.yml
vendored
4
.github/workflows/claude-dependabot.yml
vendored
@@ -7,7 +7,7 @@
|
||||
# - Provide actionable recommendations for the development team
|
||||
#
|
||||
# Triggered on: Dependabot PRs (opened, synchronize)
|
||||
# Requirements: ANTHROPIC_API_KEY secret must be configured
|
||||
# Requirements: CLAUDE_CODE_OAUTH_TOKEN secret must be configured
|
||||
|
||||
name: Claude Dependabot PR Review
|
||||
|
||||
@@ -308,7 +308,7 @@ jobs:
|
||||
id: claude_review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
|
||||
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
@@ -323,7 +323,7 @@ jobs:
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
|
||||
--model opus
|
||||
|
||||
78
.github/workflows/docs-block-sync.yml
vendored
Normal file
78
.github/workflows/docs-block-sync.yml
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
name: Block Documentation Sync Check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- "autogpt_platform/backend/backend/blocks/**"
|
||||
- "docs/integrations/**"
|
||||
- "autogpt_platform/backend/scripts/generate_block_docs.py"
|
||||
- ".github/workflows/docs-block-sync.yml"
|
||||
pull_request:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- "autogpt_platform/backend/backend/blocks/**"
|
||||
- "docs/integrations/**"
|
||||
- "autogpt_platform/backend/scripts/generate_block_docs.py"
|
||||
- ".github/workflows/docs-block-sync.yml"
|
||||
|
||||
jobs:
|
||||
check-docs-sync:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
poetry-${{ runner.os }}-
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry install --only main
|
||||
poetry run prisma generate
|
||||
|
||||
- name: Check block documentation is in sync
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
echo "Checking if block documentation is in sync with code..."
|
||||
poetry run python scripts/generate_block_docs.py --check
|
||||
|
||||
- name: Show diff if out of sync
|
||||
if: failure()
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
echo "::error::Block documentation is out of sync with code!"
|
||||
echo ""
|
||||
echo "To fix this, run the following command locally:"
|
||||
echo " cd autogpt_platform/backend && poetry run python scripts/generate_block_docs.py"
|
||||
echo ""
|
||||
echo "Then commit the updated documentation files."
|
||||
echo ""
|
||||
echo "Regenerating docs to show diff..."
|
||||
poetry run python scripts/generate_block_docs.py
|
||||
echo ""
|
||||
echo "Changes detected:"
|
||||
git diff ../../docs/integrations/ || true
|
||||
95
.github/workflows/docs-claude-review.yml
vendored
Normal file
95
.github/workflows/docs-claude-review.yml
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
name: Claude Block Docs Review
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
paths:
|
||||
- "docs/integrations/**"
|
||||
- "autogpt_platform/backend/backend/blocks/**"
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
# Only run for PRs from members/collaborators
|
||||
if: |
|
||||
github.event.pull_request.author_association == 'OWNER' ||
|
||||
github.event.pull_request.author_association == 'MEMBER' ||
|
||||
github.event.pull_request.author_association == 'COLLABORATOR'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
poetry-${{ runner.os }}-
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry install --only main
|
||||
poetry run prisma generate
|
||||
|
||||
- name: Run Claude Code Review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
claude_args: |
|
||||
--allowedTools "Read,Glob,Grep,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
You are reviewing a PR that modifies block documentation or block code for AutoGPT.
|
||||
|
||||
## Your Task
|
||||
Review the changes in this PR and provide constructive feedback. Focus on:
|
||||
|
||||
1. **Documentation Accuracy**: For any block code changes, verify that:
|
||||
- Input/output tables in docs match the actual block schemas
|
||||
- Description text accurately reflects what the block does
|
||||
- Any new blocks have corresponding documentation
|
||||
|
||||
2. **Manual Content Quality**: Check manual sections (marked with `<!-- MANUAL: -->` markers):
|
||||
- "How it works" sections should have clear technical explanations
|
||||
- "Possible use case" sections should have practical, real-world examples
|
||||
- Content should be helpful for users trying to understand the blocks
|
||||
|
||||
3. **Template Compliance**: Ensure docs follow the standard template:
|
||||
- What it is (brief intro)
|
||||
- What it does (description)
|
||||
- How it works (technical explanation)
|
||||
- Inputs table
|
||||
- Outputs table
|
||||
- Possible use case
|
||||
|
||||
4. **Cross-references**: Check that links and anchors are correct
|
||||
|
||||
## Review Process
|
||||
1. First, get the PR diff to see what changed: `gh pr diff ${{ github.event.pull_request.number }}`
|
||||
2. Read any modified block files to understand the implementation
|
||||
3. Read corresponding documentation files to verify accuracy
|
||||
4. Provide your feedback as a PR comment
|
||||
|
||||
Be constructive and specific. If everything looks good, say so!
|
||||
If there are issues, explain what's wrong and suggest how to fix it.
|
||||
194
.github/workflows/docs-enhance.yml
vendored
Normal file
194
.github/workflows/docs-enhance.yml
vendored
Normal file
@@ -0,0 +1,194 @@
|
||||
name: Enhance Block Documentation
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
block_pattern:
|
||||
description: 'Block file pattern to enhance (e.g., "google/*.md" or "*" for all blocks)'
|
||||
required: true
|
||||
default: '*'
|
||||
type: string
|
||||
dry_run:
|
||||
description: 'Dry run mode - show proposed changes without committing'
|
||||
type: boolean
|
||||
default: true
|
||||
max_blocks:
|
||||
description: 'Maximum number of blocks to process (0 for unlimited)'
|
||||
type: number
|
||||
default: 10
|
||||
|
||||
jobs:
|
||||
enhance-docs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
poetry-${{ runner.os }}-
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry install --only main
|
||||
poetry run prisma generate
|
||||
|
||||
- name: Run Claude Enhancement
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||
claude_args: |
|
||||
--allowedTools "Read,Edit,Glob,Grep,Write,Bash(git:*),Bash(gh:*),Bash(find:*),Bash(ls:*)"
|
||||
prompt: |
|
||||
You are enhancing block documentation for AutoGPT. Your task is to improve the MANUAL sections
|
||||
of block documentation files by reading the actual block implementations and writing helpful content.
|
||||
|
||||
## Configuration
|
||||
- Block pattern: ${{ inputs.block_pattern }}
|
||||
- Dry run: ${{ inputs.dry_run }}
|
||||
- Max blocks to process: ${{ inputs.max_blocks }}
|
||||
|
||||
## Your Task
|
||||
|
||||
1. **Find Documentation Files**
|
||||
Find block documentation files matching the pattern in `docs/integrations/`
|
||||
Pattern: ${{ inputs.block_pattern }}
|
||||
|
||||
Use: `find docs/integrations -name "*.md" -type f`
|
||||
|
||||
2. **For Each Documentation File** (up to ${{ inputs.max_blocks }} files):
|
||||
|
||||
a. Read the documentation file
|
||||
|
||||
b. Identify which block(s) it documents (look for the block class name)
|
||||
|
||||
c. Find and read the corresponding block implementation in `autogpt_platform/backend/backend/blocks/`
|
||||
|
||||
d. Improve the MANUAL sections:
|
||||
|
||||
**"How it works" section** (within `<!-- MANUAL: how_it_works -->` markers):
|
||||
- Explain the technical flow of the block
|
||||
- Describe what APIs or services it connects to
|
||||
- Note any important configuration or prerequisites
|
||||
- Keep it concise but informative (2-4 paragraphs)
|
||||
|
||||
**"Possible use case" section** (within `<!-- MANUAL: use_case -->` markers):
|
||||
- Provide 2-3 practical, real-world examples
|
||||
- Make them specific and actionable
|
||||
- Show how this block could be used in an automation workflow
|
||||
|
||||
3. **Important Rules**
|
||||
- ONLY modify content within `<!-- MANUAL: -->` and `<!-- END MANUAL -->` markers
|
||||
- Do NOT modify auto-generated sections (inputs/outputs tables, descriptions)
|
||||
- Keep content accurate based on the actual block implementation
|
||||
- Write for users who may not be technical experts
|
||||
|
||||
4. **Output**
|
||||
${{ inputs.dry_run == true && 'DRY RUN MODE: Show proposed changes for each file but do NOT actually edit the files. Describe what you would change.' || 'LIVE MODE: Actually edit the files to improve the documentation.' }}
|
||||
|
||||
## Example Improvements
|
||||
|
||||
**Before (How it works):**
|
||||
```
|
||||
_Add technical explanation here._
|
||||
```
|
||||
|
||||
**After (How it works):**
|
||||
```
|
||||
This block connects to the GitHub API to retrieve issue information. When executed,
|
||||
it authenticates using your GitHub credentials and fetches issue details including
|
||||
title, body, labels, and assignees.
|
||||
|
||||
The block requires a valid GitHub OAuth connection with repository access permissions.
|
||||
It supports both public and private repositories you have access to.
|
||||
```
|
||||
|
||||
**Before (Possible use case):**
|
||||
```
|
||||
_Add practical use case examples here._
|
||||
```
|
||||
|
||||
**After (Possible use case):**
|
||||
```
|
||||
**Customer Support Automation**: Monitor a GitHub repository for new issues with
|
||||
the "bug" label, then automatically create a ticket in your support system and
|
||||
notify the on-call engineer via Slack.
|
||||
|
||||
**Release Notes Generation**: When a new release is published, gather all closed
|
||||
issues since the last release and generate a summary for your changelog.
|
||||
```
|
||||
|
||||
Begin by finding and listing the documentation files to process.
|
||||
|
||||
- name: Create PR with enhanced documentation
|
||||
if: ${{ inputs.dry_run == false }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
# Check if there are changes
|
||||
if git diff --quiet docs/integrations/; then
|
||||
echo "No changes to commit"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Configure git
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Create branch and commit
|
||||
BRANCH_NAME="docs/enhance-blocks-$(date +%Y%m%d-%H%M%S)"
|
||||
git checkout -b "$BRANCH_NAME"
|
||||
git add docs/integrations/
|
||||
git commit -m "docs: enhance block documentation with LLM-generated content
|
||||
|
||||
Pattern: ${{ inputs.block_pattern }}
|
||||
Max blocks: ${{ inputs.max_blocks }}
|
||||
|
||||
🤖 Generated with [Claude Code](https://claude.com/claude-code)
|
||||
|
||||
Co-Authored-By: Claude <noreply@anthropic.com>"
|
||||
|
||||
# Push and create PR
|
||||
git push -u origin "$BRANCH_NAME"
|
||||
gh pr create \
|
||||
--title "docs: LLM-enhanced block documentation" \
|
||||
--body "## Summary
|
||||
This PR contains LLM-enhanced documentation for block files matching pattern: \`${{ inputs.block_pattern }}\`
|
||||
|
||||
The following manual sections were improved:
|
||||
- **How it works**: Technical explanations based on block implementations
|
||||
- **Possible use case**: Practical, real-world examples
|
||||
|
||||
## Review Checklist
|
||||
- [ ] Content is accurate based on block implementations
|
||||
- [ ] Examples are practical and helpful
|
||||
- [ ] No auto-generated sections were modified
|
||||
|
||||
---
|
||||
🤖 Generated with [Claude Code](https://claude.com/claude-code)" \
|
||||
--base dev
|
||||
@@ -1,57 +1,21 @@
|
||||
"""
|
||||
External API Application
|
||||
|
||||
This module defines the main FastAPI application for the external API,
|
||||
which mounts the v1 and v2 sub-applications.
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
|
||||
from .v1.app import v1_app
|
||||
from .v2.app import v2_app
|
||||
|
||||
DESCRIPTION = """
|
||||
The external API provides programmatic access to the AutoGPT Platform for building
|
||||
integrations, automations, and custom applications.
|
||||
|
||||
### API Versions
|
||||
|
||||
| Version | End of Life | Path | Documentation |
|
||||
|---------------------|-------------|------------------------|---------------|
|
||||
| **v2** | | `/external-api/v2/...` | [v2 docs](v2/docs) |
|
||||
| **v1** (deprecated) | 2025-05-01 | `/external-api/v1/...` | [v1 docs](v1/docs) |
|
||||
|
||||
**Recommendation**: New integrations should use v2.
|
||||
|
||||
For authentication details and usage examples, see the
|
||||
[API Integration Guide](https://docs.agpt.co/platform/integrating/api-guide/).
|
||||
"""
|
||||
from .v1.routes import v1_router
|
||||
|
||||
external_api = FastAPI(
|
||||
title="AutoGPT Platform API",
|
||||
summary="External API for AutoGPT Platform integrations",
|
||||
description=DESCRIPTION,
|
||||
version="2.0.0",
|
||||
title="AutoGPT External API",
|
||||
description="External API for AutoGPT integrations",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_api.add_middleware(SecurityHeadersMiddleware)
|
||||
external_api.include_router(v1_router, prefix="/v1")
|
||||
|
||||
@external_api.get("/", include_in_schema=False)
|
||||
async def root_redirect() -> RedirectResponse:
|
||||
"""Redirect root to API documentation."""
|
||||
return RedirectResponse(url="/docs")
|
||||
|
||||
|
||||
# Mount versioned sub-applications
|
||||
# Each sub-app has its own /docs page at /v1/docs and /v2/docs
|
||||
external_api.mount("/v1", v1_app)
|
||||
external_api.mount("/v2", v2_app)
|
||||
|
||||
# Add Prometheus instrumentation to the main app
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_api,
|
||||
service_name="external-api",
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
"""
|
||||
V1 External API Application
|
||||
|
||||
This module defines the FastAPI application for the v1 external API.
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes import v1_router
|
||||
|
||||
DESCRIPTION = """
|
||||
The v1 API provides access to core AutoGPT functionality for external integrations.
|
||||
|
||||
For authentication details and usage examples, see the
|
||||
[API Integration Guide](https://docs.agpt.co/platform/integrating/api-guide/).
|
||||
"""
|
||||
|
||||
v1_app = FastAPI(
|
||||
title="AutoGPT Platform API",
|
||||
summary="External API for AutoGPT Platform integrations (v1)",
|
||||
description=DESCRIPTION,
|
||||
version="1.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
openapi_tags=[
|
||||
{"name": "user", "description": "User information"},
|
||||
{"name": "blocks", "description": "Block operations"},
|
||||
{"name": "graphs", "description": "Graph execution"},
|
||||
{"name": "store", "description": "Marketplace agents and creators"},
|
||||
{"name": "integrations", "description": "OAuth credential management"},
|
||||
{"name": "tools", "description": "AI assistant tools"},
|
||||
],
|
||||
)
|
||||
|
||||
v1_app.add_middleware(SecurityHeadersMiddleware)
|
||||
v1_app.include_router(v1_router)
|
||||
@@ -1,9 +0,0 @@
|
||||
"""
|
||||
V2 External API
|
||||
|
||||
This module provides the v2 external API for programmatic access to the AutoGPT Platform.
|
||||
"""
|
||||
|
||||
from .routes import v2_router
|
||||
|
||||
__all__ = ["v2_router"]
|
||||
@@ -1,82 +0,0 @@
|
||||
"""
|
||||
V2 External API Application
|
||||
|
||||
This module defines the FastAPI application for the v2 external API.
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes import v2_router
|
||||
|
||||
DESCRIPTION = """
|
||||
The v2 API provides comprehensive access to the AutoGPT Platform for building
|
||||
integrations, automations, and custom applications.
|
||||
|
||||
### Key Improvements over v1
|
||||
|
||||
- **Consistent naming**: Uses `graph_id`/`graph_version` consistently
|
||||
- **Better pagination**: All list endpoints support pagination
|
||||
- **Comprehensive coverage**: Access to library, runs, schedules, credits, and more
|
||||
- **Human-in-the-loop**: Review and approve agent decisions via the API
|
||||
|
||||
For authentication details and usage examples, see the
|
||||
[API Integration Guide](https://docs.agpt.co/platform/integrating/api-guide/).
|
||||
|
||||
### Pagination
|
||||
|
||||
List endpoints return paginated responses. Use `page` and `page_size` query
|
||||
parameters to navigate results. Maximum page size is 100 items.
|
||||
"""
|
||||
|
||||
v2_app = FastAPI(
|
||||
title="AutoGPT Platform External API",
|
||||
summary="External API for AutoGPT Platform integrations (v2)",
|
||||
description=DESCRIPTION,
|
||||
version="2.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
openapi_tags=[
|
||||
{
|
||||
"name": "graphs",
|
||||
"description": "Create, update, and manage agent graphs",
|
||||
},
|
||||
{
|
||||
"name": "schedules",
|
||||
"description": "Manage scheduled graph executions",
|
||||
},
|
||||
{
|
||||
"name": "blocks",
|
||||
"description": "Discover available building blocks",
|
||||
},
|
||||
{
|
||||
"name": "marketplace",
|
||||
"description": "Browse agents and creators, manage submissions",
|
||||
},
|
||||
{
|
||||
"name": "library",
|
||||
"description": "Access your agent library and execute agents",
|
||||
},
|
||||
{
|
||||
"name": "runs",
|
||||
"description": "Monitor execution runs and human-in-the-loop reviews",
|
||||
},
|
||||
{
|
||||
"name": "credits",
|
||||
"description": "Check balance and view transaction history",
|
||||
},
|
||||
{
|
||||
"name": "integrations",
|
||||
"description": "Manage OAuth credentials for external services",
|
||||
},
|
||||
{
|
||||
"name": "files",
|
||||
"description": "Upload files for agent input",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
v2_app.add_middleware(SecurityHeadersMiddleware)
|
||||
v2_app.include_router(v2_router)
|
||||
@@ -1,140 +0,0 @@
|
||||
"""
|
||||
V2 External API - Blocks Endpoints
|
||||
|
||||
Provides read-only access to available building blocks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Response, Security
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.block import get_blocks
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
blocks_router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
"""Cost information for a block."""
|
||||
|
||||
cost_type: str = Field(description="Type of cost (e.g., 'per_call', 'per_token')")
|
||||
cost_filter: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Conditions for this cost"
|
||||
)
|
||||
cost_amount: int = Field(description="Cost amount in credits")
|
||||
|
||||
|
||||
class Block(BaseModel):
|
||||
"""A building block that can be used in graphs."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
costs: list[BlockCost] = Field(default_factory=list)
|
||||
disabled: bool = Field(default=False)
|
||||
|
||||
|
||||
class BlocksListResponse(BaseModel):
|
||||
"""Response for listing blocks."""
|
||||
|
||||
blocks: list[Block]
|
||||
total_count: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Internal Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _compute_blocks_sync() -> str:
|
||||
"""
|
||||
Synchronous function to compute blocks data.
|
||||
This does the heavy lifting: instantiate 226+ blocks, compute costs, serialize.
|
||||
"""
|
||||
from backend.data.credit import get_block_cost
|
||||
|
||||
block_classes = get_blocks()
|
||||
result = []
|
||||
|
||||
for block_class in block_classes.values():
|
||||
block_instance = block_class()
|
||||
if not block_instance.disabled:
|
||||
costs = get_block_cost(block_instance)
|
||||
# Convert BlockCost BaseModel objects to dictionaries
|
||||
costs_dict = [
|
||||
cost.model_dump() if isinstance(cost, BaseModel) else cost
|
||||
for cost in costs
|
||||
]
|
||||
result.append({**block_instance.to_dict(), "costs": costs_dict})
|
||||
|
||||
return dumps(result)
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
|
||||
async def _get_cached_blocks() -> str:
|
||||
"""
|
||||
Async cached function with thundering herd protection.
|
||||
On cache miss: runs heavy work in thread pool
|
||||
On cache hit: returns cached string immediately
|
||||
"""
|
||||
return await run_in_threadpool(_compute_blocks_sync)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@blocks_router.get(
|
||||
path="",
|
||||
summary="List available blocks",
|
||||
responses={
|
||||
200: {
|
||||
"description": "List of available building blocks",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"items": {"additionalProperties": True, "type": "object"},
|
||||
"type": "array",
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async def list_blocks(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_BLOCK)
|
||||
),
|
||||
) -> Response:
|
||||
"""
|
||||
List all available building blocks that can be used in graphs.
|
||||
|
||||
Each block represents a specific capability (e.g., HTTP request, text processing,
|
||||
AI completion, etc.) that can be connected in a graph to create an agent.
|
||||
|
||||
The response includes input/output schemas for each block, as well as
|
||||
cost information for blocks that consume credits.
|
||||
"""
|
||||
content = await _get_cached_blocks()
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="application/json",
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
Common utilities for V2 External API
|
||||
"""
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Constants for pagination
|
||||
MAX_PAGE_SIZE = 100
|
||||
DEFAULT_PAGE_SIZE = 20
|
||||
|
||||
|
||||
class PaginationParams(BaseModel):
|
||||
"""Common pagination parameters."""
|
||||
|
||||
page: int = Field(default=1, ge=1, description="Page number (1-indexed)")
|
||||
page_size: int = Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Number of items per page (max {MAX_PAGE_SIZE})",
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
"""Generic paginated response wrapper."""
|
||||
|
||||
items: list
|
||||
total_count: int = Field(description="Total number of items across all pages")
|
||||
page: int = Field(description="Current page number (1-indexed)")
|
||||
page_size: int = Field(description="Number of items per page")
|
||||
total_pages: int = Field(description="Total number of pages")
|
||||
@@ -1,141 +0,0 @@
|
||||
"""
|
||||
V2 External API - Credits Endpoints
|
||||
|
||||
Provides access to credit balance and transaction history.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Query, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.credit import get_user_credit_model
|
||||
|
||||
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
credits_router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CreditBalance(BaseModel):
|
||||
"""User's credit balance."""
|
||||
|
||||
balance: int = Field(description="Current credit balance")
|
||||
|
||||
|
||||
class CreditTransaction(BaseModel):
|
||||
"""A credit transaction."""
|
||||
|
||||
transaction_key: str
|
||||
amount: int = Field(description="Transaction amount (positive or negative)")
|
||||
type: str = Field(description="One of: TOP_UP, USAGE, GRANT, REFUND")
|
||||
transaction_time: datetime
|
||||
running_balance: Optional[int] = Field(
|
||||
default=None, description="Balance after this transaction"
|
||||
)
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class CreditTransactionsResponse(BaseModel):
|
||||
"""Response for listing credit transactions."""
|
||||
|
||||
transactions: list[CreditTransaction]
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@credits_router.get(
|
||||
path="",
|
||||
summary="Get credit balance",
|
||||
response_model=CreditBalance,
|
||||
)
|
||||
async def get_balance(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_CREDITS)
|
||||
),
|
||||
) -> CreditBalance:
|
||||
"""
|
||||
Get the current credit balance for the authenticated user.
|
||||
"""
|
||||
user_credit_model = await get_user_credit_model(auth.user_id)
|
||||
balance = await user_credit_model.get_credits(auth.user_id)
|
||||
|
||||
return CreditBalance(balance=balance)
|
||||
|
||||
|
||||
@credits_router.get(
|
||||
path="/transactions",
|
||||
summary="Get transaction history",
|
||||
response_model=CreditTransactionsResponse,
|
||||
)
|
||||
async def get_transactions(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_CREDITS)
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
transaction_type: Optional[str] = Query(
|
||||
default=None,
|
||||
description="Filter by transaction type (TOP_UP, USAGE, GRANT, REFUND)",
|
||||
),
|
||||
) -> CreditTransactionsResponse:
|
||||
"""
|
||||
Get credit transaction history for the authenticated user.
|
||||
|
||||
Returns transactions sorted by most recent first.
|
||||
"""
|
||||
user_credit_model = await get_user_credit_model(auth.user_id)
|
||||
|
||||
history = await user_credit_model.get_transaction_history(
|
||||
user_id=auth.user_id,
|
||||
transaction_count_limit=page_size,
|
||||
transaction_type=transaction_type,
|
||||
)
|
||||
|
||||
transactions = [
|
||||
CreditTransaction(
|
||||
transaction_key=t.transaction_key,
|
||||
amount=t.amount,
|
||||
type=t.transaction_type.value,
|
||||
transaction_time=t.transaction_time,
|
||||
running_balance=t.running_balance,
|
||||
description=t.description,
|
||||
)
|
||||
for t in history.transactions
|
||||
]
|
||||
|
||||
# Note: The current credit module doesn't support true pagination,
|
||||
# so we're returning what we have
|
||||
total_count = len(transactions)
|
||||
total_pages = 1 # Without true pagination support
|
||||
|
||||
return CreditTransactionsResponse(
|
||||
transactions=transactions,
|
||||
total_count=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
@@ -1,132 +0,0 @@
|
||||
"""
|
||||
V2 External API - Files Endpoints
|
||||
|
||||
Provides file upload functionality for agent inputs.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, Query, Security, UploadFile
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
files_router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class UploadFileResponse(BaseModel):
|
||||
"""Response after uploading a file."""
|
||||
|
||||
file_uri: str = Field(description="URI to reference the uploaded file in agents")
|
||||
file_name: str
|
||||
size: int = Field(description="File size in bytes")
|
||||
content_type: str
|
||||
expires_in_hours: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
"""Create standardized file size error response."""
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File size ({size_bytes} bytes) exceeds the maximum allowed size of {max_size_mb}MB",
|
||||
)
|
||||
|
||||
|
||||
@files_router.post(
|
||||
path="/upload",
|
||||
summary="Upload a file",
|
||||
response_model=UploadFileResponse,
|
||||
)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.UPLOAD_FILES)
|
||||
),
|
||||
provider: str = Query(
|
||||
default="gcs", description="Storage provider (gcs, s3, azure)"
|
||||
),
|
||||
expiration_hours: int = Query(
|
||||
default=24, ge=1, le=48, description="Hours until file expires (1-48)"
|
||||
),
|
||||
) -> UploadFileResponse:
|
||||
"""
|
||||
Upload a file to cloud storage for use with agents.
|
||||
|
||||
The returned `file_uri` can be used as input to agents that accept file inputs
|
||||
(e.g., FileStoreBlock, AgentFileInputBlock).
|
||||
|
||||
Files are automatically scanned for viruses before storage.
|
||||
"""
|
||||
# Check file size limit
|
||||
max_size_mb = settings.config.upload_file_size_limit_mb
|
||||
max_size_bytes = max_size_mb * 1024 * 1024
|
||||
|
||||
# Try to get file size from headers first
|
||||
if hasattr(file, "size") and file.size is not None and file.size > max_size_bytes:
|
||||
raise _create_file_size_error(file.size, max_size_mb)
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
content_size = len(content)
|
||||
|
||||
# Double-check file size after reading
|
||||
if content_size > max_size_bytes:
|
||||
raise _create_file_size_error(content_size, max_size_mb)
|
||||
|
||||
# Extract file info
|
||||
file_name = file.filename or "uploaded_file"
|
||||
content_type = file.content_type or "application/octet-stream"
|
||||
|
||||
# Virus scan the content
|
||||
await scan_content_safe(content, filename=file_name)
|
||||
|
||||
# Check if cloud storage is configured
|
||||
cloud_storage = await get_cloud_storage_handler()
|
||||
if not cloud_storage.config.gcs_bucket_name:
|
||||
# Fallback to base64 data URI when GCS is not configured
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
data_uri = f"data:{content_type};base64,{base64_content}"
|
||||
|
||||
return UploadFileResponse(
|
||||
file_uri=data_uri,
|
||||
file_name=file_name,
|
||||
size=content_size,
|
||||
content_type=content_type,
|
||||
expires_in_hours=expiration_hours,
|
||||
)
|
||||
|
||||
# Store in cloud storage
|
||||
storage_path = await cloud_storage.store_file(
|
||||
content=content,
|
||||
filename=file_name,
|
||||
provider=provider,
|
||||
expiration_hours=expiration_hours,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
|
||||
return UploadFileResponse(
|
||||
file_uri=storage_path,
|
||||
file_name=file_name,
|
||||
size=content_size,
|
||||
content_type=content_type,
|
||||
expires_in_hours=expiration_hours,
|
||||
)
|
||||
@@ -1,445 +0,0 @@
|
||||
"""
|
||||
V2 External API - Graphs Endpoints
|
||||
|
||||
Provides endpoints for managing agent graphs (CRUD operations).
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
on_graph_deactivate,
|
||||
)
|
||||
|
||||
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from .models import (
|
||||
CreateGraphRequest,
|
||||
DeleteGraphResponse,
|
||||
GraphDetails,
|
||||
GraphLink,
|
||||
GraphMeta,
|
||||
GraphNode,
|
||||
GraphSettings,
|
||||
GraphsListResponse,
|
||||
SetActiveVersionRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
graphs_router = APIRouter()
|
||||
|
||||
|
||||
def _convert_graph_meta(graph: graph_db.GraphMeta) -> GraphMeta:
|
||||
"""Convert internal GraphMeta to v2 API model."""
|
||||
return GraphMeta(
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.is_active,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
created_at=graph.created_at,
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
)
|
||||
|
||||
|
||||
def _convert_graph_details(graph: graph_db.GraphModel) -> GraphDetails:
|
||||
"""Convert internal GraphModel to v2 API GraphDetails model."""
|
||||
return GraphDetails(
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.is_active,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
created_at=graph.created_at,
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
nodes=[
|
||||
GraphNode(
|
||||
id=node.id,
|
||||
block_id=node.block_id,
|
||||
input_default=node.input_default,
|
||||
metadata=node.metadata,
|
||||
)
|
||||
for node in graph.nodes
|
||||
],
|
||||
links=[
|
||||
GraphLink(
|
||||
id=link.id,
|
||||
source_id=link.source_id,
|
||||
sink_id=link.sink_id,
|
||||
source_name=link.source_name,
|
||||
sink_name=link.sink_name,
|
||||
is_static=link.is_static,
|
||||
)
|
||||
for link in graph.links
|
||||
],
|
||||
credentials_input_schema=graph.credentials_input_schema,
|
||||
)
|
||||
|
||||
|
||||
@graphs_router.get(
|
||||
path="",
|
||||
summary="List user's graphs",
|
||||
response_model=GraphsListResponse,
|
||||
)
|
||||
async def list_graphs(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_GRAPH)
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> GraphsListResponse:
|
||||
"""
|
||||
List all graphs owned by the authenticated user.
|
||||
|
||||
Returns a paginated list of graph metadata (not full graph details).
|
||||
"""
|
||||
graphs, pagination_info = await graph_db.list_graphs_paginated(
|
||||
user_id=auth.user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filter_by="active",
|
||||
)
|
||||
return GraphsListResponse(
|
||||
graphs=[_convert_graph_meta(g) for g in graphs],
|
||||
total_count=pagination_info.total_items,
|
||||
page=pagination_info.current_page,
|
||||
page_size=pagination_info.page_size,
|
||||
total_pages=pagination_info.total_pages,
|
||||
)
|
||||
|
||||
|
||||
@graphs_router.post(
|
||||
path="",
|
||||
summary="Create a new graph",
|
||||
response_model=GraphDetails,
|
||||
)
|
||||
async def create_graph(
|
||||
create_graph_request: CreateGraphRequest,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_GRAPH)
|
||||
),
|
||||
) -> GraphDetails:
|
||||
"""
|
||||
Create a new agent graph.
|
||||
|
||||
The graph will be validated and assigned a new ID. It will automatically
|
||||
be added to the user's library.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.api.features.library import db as library_db
|
||||
|
||||
# Convert v2 API Graph model to internal Graph model
|
||||
internal_graph = graph_db.Graph(
|
||||
id=create_graph_request.graph.id or "",
|
||||
version=create_graph_request.graph.version,
|
||||
is_active=create_graph_request.graph.is_active,
|
||||
name=create_graph_request.graph.name,
|
||||
description=create_graph_request.graph.description,
|
||||
nodes=[
|
||||
graph_db.Node(
|
||||
id=node.id,
|
||||
block_id=node.block_id,
|
||||
input_default=node.input_default,
|
||||
metadata=node.metadata,
|
||||
)
|
||||
for node in create_graph_request.graph.nodes
|
||||
],
|
||||
links=[
|
||||
graph_db.Link(
|
||||
id=link.id,
|
||||
source_id=link.source_id,
|
||||
sink_id=link.sink_id,
|
||||
source_name=link.source_name,
|
||||
sink_name=link.sink_name,
|
||||
is_static=link.is_static,
|
||||
)
|
||||
for link in create_graph_request.graph.links
|
||||
],
|
||||
)
|
||||
|
||||
graph = graph_db.make_graph_model(internal_graph, auth.user_id)
|
||||
graph.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
await graph_db.create_graph(graph, user_id=auth.user_id)
|
||||
await library_db.create_library_agent(graph, user_id=auth.user_id)
|
||||
activated_graph = await on_graph_activate(graph, user_id=auth.user_id)
|
||||
|
||||
return _convert_graph_details(activated_graph)
|
||||
|
||||
|
||||
@graphs_router.get(
|
||||
path="/{graph_id}",
|
||||
summary="Get graph details",
|
||||
response_model=GraphDetails,
|
||||
)
|
||||
async def get_graph(
|
||||
graph_id: str,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_GRAPH)
|
||||
),
|
||||
version: int | None = Query(
|
||||
default=None,
|
||||
description="Specific version to retrieve (default: active version)",
|
||||
),
|
||||
) -> GraphDetails:
|
||||
"""
|
||||
Get detailed information about a specific graph.
|
||||
|
||||
By default returns the active version. Use the `version` query parameter
|
||||
to retrieve a specific version.
|
||||
"""
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id,
|
||||
version,
|
||||
user_id=auth.user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
return _convert_graph_details(graph)
|
||||
|
||||
|
||||
@graphs_router.put(
|
||||
path="/{graph_id}",
|
||||
summary="Update graph (creates new version)",
|
||||
response_model=GraphDetails,
|
||||
)
|
||||
async def update_graph(
|
||||
graph_id: str,
|
||||
graph_request: CreateGraphRequest,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_GRAPH)
|
||||
),
|
||||
) -> GraphDetails:
|
||||
"""
|
||||
Update a graph by creating a new version.
|
||||
|
||||
This does not modify existing versions - it creates a new version with
|
||||
the provided content. The new version becomes the active version.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.api.features.library import db as library_db
|
||||
|
||||
graph_data = graph_request.graph
|
||||
if graph_data.id and graph_data.id != graph_id:
|
||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||
|
||||
existing_versions = await graph_db.get_graph_all_versions(
|
||||
graph_id, user_id=auth.user_id
|
||||
)
|
||||
if not existing_versions:
|
||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||
|
||||
latest_version_number = max(g.version for g in existing_versions)
|
||||
|
||||
# Convert v2 API Graph model to internal Graph model
|
||||
internal_graph = graph_db.Graph(
|
||||
id=graph_id,
|
||||
version=latest_version_number + 1,
|
||||
is_active=graph_data.is_active,
|
||||
name=graph_data.name,
|
||||
description=graph_data.description,
|
||||
nodes=[
|
||||
graph_db.Node(
|
||||
id=node.id,
|
||||
block_id=node.block_id,
|
||||
input_default=node.input_default,
|
||||
metadata=node.metadata,
|
||||
)
|
||||
for node in graph_data.nodes
|
||||
],
|
||||
links=[
|
||||
graph_db.Link(
|
||||
id=link.id,
|
||||
source_id=link.source_id,
|
||||
sink_id=link.sink_id,
|
||||
source_name=link.source_name,
|
||||
sink_name=link.sink_name,
|
||||
is_static=link.is_static,
|
||||
)
|
||||
for link in graph_data.links
|
||||
],
|
||||
)
|
||||
|
||||
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
||||
graph = graph_db.make_graph_model(internal_graph, auth.user_id)
|
||||
graph.reassign_ids(user_id=auth.user_id, reassign_graph_id=False)
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
new_graph_version = await graph_db.create_graph(graph, user_id=auth.user_id)
|
||||
|
||||
if new_graph_version.is_active:
|
||||
await library_db.update_agent_version_in_library(
|
||||
auth.user_id, new_graph_version.id, new_graph_version.version
|
||||
)
|
||||
new_graph_version = await on_graph_activate(
|
||||
new_graph_version, user_id=auth.user_id
|
||||
)
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id, version=new_graph_version.version, user_id=auth.user_id
|
||||
)
|
||||
if current_active_version:
|
||||
await on_graph_deactivate(current_active_version, user_id=auth.user_id)
|
||||
|
||||
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
||||
graph_id,
|
||||
new_graph_version.version,
|
||||
user_id=auth.user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
assert new_graph_version_with_subgraphs
|
||||
return _convert_graph_details(new_graph_version_with_subgraphs)
|
||||
|
||||
|
||||
@graphs_router.delete(
|
||||
path="/{graph_id}",
|
||||
summary="Delete graph permanently",
|
||||
response_model=DeleteGraphResponse,
|
||||
)
|
||||
async def delete_graph(
|
||||
graph_id: str,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_GRAPH)
|
||||
),
|
||||
) -> DeleteGraphResponse:
|
||||
"""
|
||||
Permanently delete a graph and all its versions.
|
||||
|
||||
This action cannot be undone. All associated executions will remain
|
||||
but will reference a deleted graph.
|
||||
"""
|
||||
if active_version := await graph_db.get_graph(
|
||||
graph_id=graph_id, version=None, user_id=auth.user_id
|
||||
):
|
||||
await on_graph_deactivate(active_version, user_id=auth.user_id)
|
||||
|
||||
version_count = await graph_db.delete_graph(graph_id, user_id=auth.user_id)
|
||||
return DeleteGraphResponse(version_count=version_count)
|
||||
|
||||
|
||||
@graphs_router.get(
|
||||
path="/{graph_id}/versions",
|
||||
summary="List all graph versions",
|
||||
response_model=list[GraphDetails],
|
||||
)
|
||||
async def list_graph_versions(
|
||||
graph_id: str,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_GRAPH)
|
||||
),
|
||||
) -> list[GraphDetails]:
|
||||
"""
|
||||
Get all versions of a specific graph.
|
||||
|
||||
Returns a list of all versions, with the active version marked.
|
||||
"""
|
||||
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=auth.user_id)
|
||||
if not graphs:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
return [_convert_graph_details(g) for g in graphs]
|
||||
|
||||
|
||||
@graphs_router.put(
|
||||
path="/{graph_id}/versions/active",
|
||||
summary="Set active graph version",
|
||||
)
|
||||
async def set_active_version(
|
||||
graph_id: str,
|
||||
request_body: SetActiveVersionRequest,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_GRAPH)
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Set which version of a graph is the active version.
|
||||
|
||||
The active version is used when executing the graph without specifying
|
||||
a version number.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.api.features.library import db as library_db
|
||||
|
||||
new_active_version = request_body.active_graph_version
|
||||
new_active_graph = await graph_db.get_graph(
|
||||
graph_id, new_active_version, user_id=auth.user_id
|
||||
)
|
||||
if not new_active_graph:
|
||||
raise HTTPException(404, f"Graph #{graph_id} v{new_active_version} not found")
|
||||
|
||||
current_active_graph = await graph_db.get_graph(
|
||||
graph_id=graph_id,
|
||||
version=None,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
|
||||
await on_graph_activate(new_active_graph, user_id=auth.user_id)
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id,
|
||||
version=new_active_version,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
|
||||
await library_db.update_agent_version_in_library(
|
||||
auth.user_id, new_active_graph.id, new_active_graph.version
|
||||
)
|
||||
|
||||
if current_active_graph and current_active_graph.version != new_active_version:
|
||||
await on_graph_deactivate(current_active_graph, user_id=auth.user_id)
|
||||
|
||||
|
||||
@graphs_router.patch(
|
||||
path="/{graph_id}/settings",
|
||||
summary="Update graph settings",
|
||||
response_model=GraphSettings,
|
||||
)
|
||||
async def update_graph_settings(
|
||||
graph_id: str,
|
||||
settings: GraphSettings,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_GRAPH)
|
||||
),
|
||||
) -> GraphSettings:
|
||||
"""
|
||||
Update settings for a graph.
|
||||
|
||||
Currently supports:
|
||||
- human_in_the_loop_safe_mode: Enable/disable safe mode for human-in-the-loop blocks
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import GraphSettings as InternalGraphSettings
|
||||
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
graph_id=graph_id, user_id=auth.user_id
|
||||
)
|
||||
if not library_agent:
|
||||
raise HTTPException(404, f"Graph #{graph_id} not found in user's library")
|
||||
|
||||
# Convert to internal model
|
||||
internal_settings = InternalGraphSettings(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode
|
||||
)
|
||||
|
||||
updated_agent = await library_db.update_library_agent_settings(
|
||||
user_id=auth.user_id,
|
||||
agent_id=library_agent.id,
|
||||
settings=internal_settings,
|
||||
)
|
||||
|
||||
return GraphSettings(
|
||||
human_in_the_loop_safe_mode=updated_agent.settings.human_in_the_loop_safe_mode
|
||||
)
|
||||
@@ -1,271 +0,0 @@
|
||||
"""
|
||||
V2 External API - Integrations Endpoints
|
||||
|
||||
Provides access to user's integration credentials.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Path, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.model import Credentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
integrations_router = APIRouter()
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Credential(BaseModel):
|
||||
"""A user's credential for an integration."""
|
||||
|
||||
id: str
|
||||
provider: str = Field(description="Integration provider name")
|
||||
title: Optional[str] = Field(
|
||||
default=None, description="User-assigned title for this credential"
|
||||
)
|
||||
scopes: list[str] = Field(default_factory=list, description="Granted scopes")
|
||||
|
||||
|
||||
class CredentialsListResponse(BaseModel):
|
||||
"""Response for listing credentials."""
|
||||
|
||||
credentials: list[Credential]
|
||||
|
||||
|
||||
class CredentialRequirement(BaseModel):
|
||||
"""A credential requirement for a graph or agent."""
|
||||
|
||||
provider: str = Field(description="Required provider name")
|
||||
required_scopes: list[str] = Field(
|
||||
default_factory=list, description="Required scopes"
|
||||
)
|
||||
matching_credentials: list[Credential] = Field(
|
||||
default_factory=list,
|
||||
description="User's credentials that match this requirement",
|
||||
)
|
||||
|
||||
|
||||
class CredentialRequirementsResponse(BaseModel):
|
||||
"""Response for listing credential requirements."""
|
||||
|
||||
requirements: list[CredentialRequirement]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Conversion Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _convert_credential(cred: Credentials) -> Credential:
|
||||
"""Convert internal credential to v2 API model."""
|
||||
scopes: list[str] = []
|
||||
if isinstance(cred, OAuth2Credentials):
|
||||
scopes = cred.scopes or []
|
||||
|
||||
return Credential(
|
||||
id=cred.id,
|
||||
provider=cred.provider,
|
||||
title=cred.title,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
path="/credentials",
|
||||
summary="List all credentials",
|
||||
response_model=CredentialsListResponse,
|
||||
)
|
||||
async def list_credentials(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> CredentialsListResponse:
|
||||
"""
|
||||
List all integration credentials for the authenticated user.
|
||||
|
||||
This returns all OAuth credentials the user has connected, across
|
||||
all integration providers.
|
||||
"""
|
||||
credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
|
||||
return CredentialsListResponse(
|
||||
credentials=[_convert_credential(c) for c in credentials]
|
||||
)
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
path="/credentials/{provider}",
|
||||
summary="List credentials by provider",
|
||||
response_model=CredentialsListResponse,
|
||||
)
|
||||
async def list_credentials_by_provider(
|
||||
provider: str = Path(description="Provider name (e.g., 'github', 'google')"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> CredentialsListResponse:
|
||||
"""
|
||||
List integration credentials for a specific provider.
|
||||
"""
|
||||
all_credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
|
||||
# Filter by provider
|
||||
filtered = [c for c in all_credentials if c.provider.lower() == provider.lower()]
|
||||
|
||||
return CredentialsListResponse(
|
||||
credentials=[_convert_credential(c) for c in filtered]
|
||||
)
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
path="/graphs/{graph_id}/credentials",
|
||||
summary="List credentials matching graph requirements",
|
||||
response_model=CredentialRequirementsResponse,
|
||||
)
|
||||
async def list_graph_credential_requirements(
|
||||
graph_id: str = Path(description="Graph ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> CredentialRequirementsResponse:
|
||||
"""
|
||||
List credential requirements for a graph and matching user credentials.
|
||||
|
||||
This helps identify which credentials the user needs to provide
|
||||
when executing a graph.
|
||||
"""
|
||||
# Get the graph
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=graph_id,
|
||||
version=None, # Active version
|
||||
user_id=auth.user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found")
|
||||
|
||||
# Get the credentials input schema which contains provider requirements
|
||||
creds_schema = graph.credentials_input_schema
|
||||
all_credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
|
||||
requirements = []
|
||||
for field_name, field_schema in creds_schema.get("properties", {}).items():
|
||||
# Extract provider from schema
|
||||
# The schema structure varies, but typically has provider info
|
||||
providers = []
|
||||
if "anyOf" in field_schema:
|
||||
for option in field_schema["anyOf"]:
|
||||
if "provider" in option:
|
||||
providers.append(option["provider"])
|
||||
elif "provider" in field_schema:
|
||||
providers.append(field_schema["provider"])
|
||||
|
||||
for provider in providers:
|
||||
# Find matching credentials
|
||||
matching = [
|
||||
_convert_credential(c)
|
||||
for c in all_credentials
|
||||
if c.provider.lower() == provider.lower()
|
||||
]
|
||||
|
||||
requirements.append(
|
||||
CredentialRequirement(
|
||||
provider=provider,
|
||||
required_scopes=[], # Would need to extract from schema
|
||||
matching_credentials=matching,
|
||||
)
|
||||
)
|
||||
|
||||
return CredentialRequirementsResponse(requirements=requirements)
|
||||
|
||||
|
||||
@integrations_router.get(
|
||||
path="/library/{agent_id}/credentials",
|
||||
summary="List credentials matching library agent requirements",
|
||||
response_model=CredentialRequirementsResponse,
|
||||
)
|
||||
async def list_library_agent_credential_requirements(
|
||||
agent_id: str = Path(description="Library agent ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_INTEGRATIONS)
|
||||
),
|
||||
) -> CredentialRequirementsResponse:
|
||||
"""
|
||||
List credential requirements for a library agent and matching user credentials.
|
||||
|
||||
This helps identify which credentials the user needs to provide
|
||||
when executing an agent from their library.
|
||||
"""
|
||||
# Get the library agent
|
||||
try:
|
||||
library_agent = await library_db.get_library_agent(
|
||||
id=agent_id,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found")
|
||||
|
||||
# Get the underlying graph
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=library_agent.graph_id,
|
||||
version=library_agent.graph_version,
|
||||
user_id=auth.user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Graph for agent #{agent_id} not found",
|
||||
)
|
||||
|
||||
# Get the credentials input schema
|
||||
creds_schema = graph.credentials_input_schema
|
||||
all_credentials = await creds_manager.store.get_all_creds(auth.user_id)
|
||||
|
||||
requirements = []
|
||||
for field_name, field_schema in creds_schema.get("properties", {}).items():
|
||||
# Extract provider from schema
|
||||
providers = []
|
||||
if "anyOf" in field_schema:
|
||||
for option in field_schema["anyOf"]:
|
||||
if "provider" in option:
|
||||
providers.append(option["provider"])
|
||||
elif "provider" in field_schema:
|
||||
providers.append(field_schema["provider"])
|
||||
|
||||
for provider in providers:
|
||||
# Find matching credentials
|
||||
matching = [
|
||||
_convert_credential(c)
|
||||
for c in all_credentials
|
||||
if c.provider.lower() == provider.lower()
|
||||
]
|
||||
|
||||
requirements.append(
|
||||
CredentialRequirement(
|
||||
provider=provider,
|
||||
required_scopes=[],
|
||||
matching_credentials=matching,
|
||||
)
|
||||
)
|
||||
|
||||
return CredentialRequirementsResponse(requirements=requirements)
|
||||
@@ -1,247 +0,0 @@
|
||||
"""
|
||||
V2 External API - Library Endpoints
|
||||
|
||||
Provides access to the user's agent library and agent execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Path, Query, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library import model as library_model
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.executor import utils as execution_utils
|
||||
|
||||
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from .models import (
|
||||
ExecuteAgentRequest,
|
||||
LibraryAgent,
|
||||
LibraryAgentsResponse,
|
||||
Run,
|
||||
RunsListResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
library_router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Conversion Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _convert_library_agent(agent: library_model.LibraryAgent) -> LibraryAgent:
|
||||
"""Convert internal LibraryAgent to v2 API model."""
|
||||
return LibraryAgent(
|
||||
id=agent.id,
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
is_favorite=agent.is_favorite,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
is_latest_version=agent.is_latest_version,
|
||||
image_url=agent.image_url,
|
||||
creator_name=agent.creator_name,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
created_at=agent.created_at,
|
||||
updated_at=agent.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _convert_execution_to_run(exec: execution_db.GraphExecutionMeta) -> Run:
|
||||
"""Convert internal execution to v2 API Run model."""
|
||||
return Run(
|
||||
id=exec.id,
|
||||
graph_id=exec.graph_id,
|
||||
graph_version=exec.graph_version,
|
||||
status=exec.status.value,
|
||||
started_at=exec.started_at,
|
||||
ended_at=exec.ended_at,
|
||||
inputs=exec.inputs,
|
||||
cost=exec.stats.cost if exec.stats else 0,
|
||||
duration=exec.stats.duration if exec.stats else 0,
|
||||
node_count=exec.stats.node_exec_count if exec.stats else 0,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@library_router.get(
|
||||
path="/agents",
|
||||
summary="List library agents",
|
||||
response_model=LibraryAgentsResponse,
|
||||
)
|
||||
async def list_library_agents(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_LIBRARY)
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> LibraryAgentsResponse:
|
||||
"""
|
||||
List agents in the user's library.
|
||||
|
||||
The library contains agents the user has created or added from the marketplace.
|
||||
"""
|
||||
result = await library_db.list_library_agents(
|
||||
user_id=auth.user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return LibraryAgentsResponse(
|
||||
agents=[_convert_library_agent(a) for a in result.agents],
|
||||
total_count=result.pagination.total_items,
|
||||
page=result.pagination.current_page,
|
||||
page_size=result.pagination.page_size,
|
||||
total_pages=result.pagination.total_pages,
|
||||
)
|
||||
|
||||
|
||||
@library_router.get(
|
||||
path="/agents/favorites",
|
||||
summary="List favorite agents",
|
||||
response_model=LibraryAgentsResponse,
|
||||
)
|
||||
async def list_favorite_agents(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_LIBRARY)
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> LibraryAgentsResponse:
|
||||
"""
|
||||
List favorite agents in the user's library.
|
||||
"""
|
||||
result = await library_db.list_favorite_library_agents(
|
||||
user_id=auth.user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return LibraryAgentsResponse(
|
||||
agents=[_convert_library_agent(a) for a in result.agents],
|
||||
total_count=result.pagination.total_items,
|
||||
page=result.pagination.current_page,
|
||||
page_size=result.pagination.page_size,
|
||||
total_pages=result.pagination.total_pages,
|
||||
)
|
||||
|
||||
|
||||
@library_router.post(
|
||||
path="/agents/{agent_id}/runs",
|
||||
summary="Execute an agent",
|
||||
response_model=Run,
|
||||
)
|
||||
async def execute_agent(
|
||||
request: ExecuteAgentRequest,
|
||||
agent_id: str = Path(description="Library agent ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.RUN_AGENT)
|
||||
),
|
||||
) -> Run:
|
||||
"""
|
||||
Execute an agent from the library.
|
||||
|
||||
This creates a new run with the provided inputs. The run executes
|
||||
asynchronously and you can poll the run status using GET /runs/{run_id}.
|
||||
"""
|
||||
# Check credit balance
|
||||
user_credit_model = await get_user_credit_model(auth.user_id)
|
||||
current_balance = await user_credit_model.get_credits(auth.user_id)
|
||||
if current_balance <= 0:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail="Insufficient balance to execute the agent. Please top up your account.",
|
||||
)
|
||||
|
||||
# Get the library agent to find the graph ID and version
|
||||
try:
|
||||
library_agent = await library_db.get_library_agent(
|
||||
id=agent_id,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found")
|
||||
|
||||
try:
|
||||
result = await execution_utils.add_graph_execution(
|
||||
graph_id=library_agent.graph_id,
|
||||
user_id=auth.user_id,
|
||||
inputs=request.inputs,
|
||||
graph_version=library_agent.graph_version,
|
||||
graph_credentials_inputs=request.credentials_inputs,
|
||||
)
|
||||
|
||||
return _convert_execution_to_run(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute agent: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@library_router.get(
|
||||
path="/agents/{agent_id}/runs",
|
||||
summary="List runs for an agent",
|
||||
response_model=RunsListResponse,
|
||||
)
|
||||
async def list_agent_runs(
|
||||
agent_id: str = Path(description="Library agent ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_LIBRARY)
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> RunsListResponse:
|
||||
"""
|
||||
List execution runs for a specific agent.
|
||||
"""
|
||||
# Get the library agent to find the graph ID
|
||||
try:
|
||||
library_agent = await library_db.get_library_agent(
|
||||
id=agent_id,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found")
|
||||
|
||||
result = await execution_db.get_graph_executions_paginated(
|
||||
graph_id=library_agent.graph_id,
|
||||
user_id=auth.user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return RunsListResponse(
|
||||
runs=[_convert_execution_to_run(e) for e in result.executions],
|
||||
total_count=result.pagination.total_items,
|
||||
page=result.pagination.current_page,
|
||||
page_size=result.pagination.page_size,
|
||||
total_pages=result.pagination.total_pages,
|
||||
)
|
||||
@@ -1,510 +0,0 @@
|
||||
"""
|
||||
V2 External API - Marketplace Endpoints
|
||||
|
||||
Provides access to the agent marketplace (store).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import urllib.parse
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Path, Query, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.store import cache as store_cache
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.api.features.store import model as store_model
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
|
||||
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
marketplace_router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class MarketplaceAgent(BaseModel):
|
||||
"""An agent available in the marketplace."""
|
||||
|
||||
slug: str
|
||||
name: str
|
||||
description: str
|
||||
sub_heading: str
|
||||
creator: str
|
||||
creator_avatar: str
|
||||
runs: int = Field(default=0, description="Number of times this agent has been run")
|
||||
rating: float = Field(default=0.0, description="Average rating")
|
||||
image_url: str = Field(default="")
|
||||
|
||||
|
||||
class MarketplaceAgentDetails(BaseModel):
|
||||
"""Detailed information about a marketplace agent."""
|
||||
|
||||
store_listing_version_id: str
|
||||
slug: str
|
||||
name: str
|
||||
description: str
|
||||
sub_heading: str
|
||||
instructions: Optional[str] = None
|
||||
creator: str
|
||||
creator_avatar: str
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
runs: int = Field(default=0)
|
||||
rating: float = Field(default=0.0)
|
||||
image_urls: list[str] = Field(default_factory=list)
|
||||
video_url: str = Field(default="")
|
||||
versions: list[str] = Field(default_factory=list, description="Available versions")
|
||||
agent_graph_versions: list[str] = Field(default_factory=list)
|
||||
agent_graph_id: str
|
||||
last_updated: datetime
|
||||
|
||||
|
||||
class MarketplaceAgentsResponse(BaseModel):
|
||||
"""Response for listing marketplace agents."""
|
||||
|
||||
agents: list[MarketplaceAgent]
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
class MarketplaceCreator(BaseModel):
|
||||
"""A creator on the marketplace."""
|
||||
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
avatar_url: str
|
||||
num_agents: int
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
is_featured: bool = False
|
||||
|
||||
|
||||
class MarketplaceCreatorDetails(BaseModel):
|
||||
"""Detailed information about a marketplace creator."""
|
||||
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
avatar_url: str
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
top_categories: list[str] = Field(default_factory=list)
|
||||
links: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MarketplaceCreatorsResponse(BaseModel):
|
||||
"""Response for listing marketplace creators."""
|
||||
|
||||
creators: list[MarketplaceCreator]
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
class MarketplaceSubmission(BaseModel):
|
||||
"""A marketplace submission."""
|
||||
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
name: str
|
||||
sub_heading: str
|
||||
slug: str
|
||||
description: str
|
||||
instructions: Optional[str] = None
|
||||
image_urls: list[str] = Field(default_factory=list)
|
||||
date_submitted: datetime
|
||||
status: str = Field(description="One of: DRAFT, PENDING, APPROVED, REJECTED")
|
||||
runs: int = Field(default=0)
|
||||
rating: float = Field(default=0.0)
|
||||
store_listing_version_id: Optional[str] = None
|
||||
version: Optional[int] = None
|
||||
review_comments: Optional[str] = None
|
||||
reviewed_at: Optional[datetime] = None
|
||||
video_url: Optional[str] = None
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SubmissionsListResponse(BaseModel):
|
||||
"""Response for listing submissions."""
|
||||
|
||||
submissions: list[MarketplaceSubmission]
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
class CreateSubmissionRequest(BaseModel):
|
||||
"""Request to create a marketplace submission."""
|
||||
|
||||
graph_id: str = Field(description="ID of the graph to submit")
|
||||
graph_version: int = Field(description="Version of the graph to submit")
|
||||
name: str = Field(description="Display name for the agent")
|
||||
slug: str = Field(description="URL-friendly identifier")
|
||||
description: str = Field(description="Full description")
|
||||
sub_heading: str = Field(description="Short tagline")
|
||||
image_urls: list[str] = Field(default_factory=list)
|
||||
video_url: Optional[str] = None
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Conversion Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _convert_store_agent(agent: store_model.StoreAgent) -> MarketplaceAgent:
|
||||
"""Convert internal StoreAgent to v2 API model."""
|
||||
return MarketplaceAgent(
|
||||
slug=agent.slug,
|
||||
name=agent.agent_name,
|
||||
description=agent.description,
|
||||
sub_heading=agent.sub_heading,
|
||||
creator=agent.creator,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
image_url=agent.agent_image,
|
||||
)
|
||||
|
||||
|
||||
def _convert_store_agent_details(
|
||||
agent: store_model.StoreAgentDetails,
|
||||
) -> MarketplaceAgentDetails:
|
||||
"""Convert internal StoreAgentDetails to v2 API model."""
|
||||
return MarketplaceAgentDetails(
|
||||
store_listing_version_id=agent.store_listing_version_id,
|
||||
slug=agent.slug,
|
||||
name=agent.agent_name,
|
||||
description=agent.description,
|
||||
sub_heading=agent.sub_heading,
|
||||
instructions=agent.instructions,
|
||||
creator=agent.creator,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
image_urls=agent.agent_image,
|
||||
video_url=agent.agent_video,
|
||||
versions=agent.versions,
|
||||
agent_graph_versions=agent.agentGraphVersions,
|
||||
agent_graph_id=agent.agentGraphId,
|
||||
last_updated=agent.last_updated,
|
||||
)
|
||||
|
||||
|
||||
def _convert_creator(creator: store_model.Creator) -> MarketplaceCreator:
|
||||
"""Convert internal Creator to v2 API model."""
|
||||
return MarketplaceCreator(
|
||||
name=creator.name,
|
||||
username=creator.username,
|
||||
description=creator.description,
|
||||
avatar_url=creator.avatar_url,
|
||||
num_agents=creator.num_agents,
|
||||
agent_rating=creator.agent_rating,
|
||||
agent_runs=creator.agent_runs,
|
||||
is_featured=creator.is_featured,
|
||||
)
|
||||
|
||||
|
||||
def _convert_creator_details(
|
||||
creator: store_model.CreatorDetails,
|
||||
) -> MarketplaceCreatorDetails:
|
||||
"""Convert internal CreatorDetails to v2 API model."""
|
||||
return MarketplaceCreatorDetails(
|
||||
name=creator.name,
|
||||
username=creator.username,
|
||||
description=creator.description,
|
||||
avatar_url=creator.avatar_url,
|
||||
agent_rating=creator.agent_rating,
|
||||
agent_runs=creator.agent_runs,
|
||||
top_categories=creator.top_categories,
|
||||
links=creator.links,
|
||||
)
|
||||
|
||||
|
||||
def _convert_submission(sub: store_model.StoreSubmission) -> MarketplaceSubmission:
|
||||
"""Convert internal StoreSubmission to v2 API model."""
|
||||
return MarketplaceSubmission(
|
||||
graph_id=sub.agent_id,
|
||||
graph_version=sub.agent_version,
|
||||
name=sub.name,
|
||||
sub_heading=sub.sub_heading,
|
||||
slug=sub.slug,
|
||||
description=sub.description,
|
||||
instructions=sub.instructions,
|
||||
image_urls=sub.image_urls,
|
||||
date_submitted=sub.date_submitted,
|
||||
status=sub.status.value,
|
||||
runs=sub.runs,
|
||||
rating=sub.rating,
|
||||
store_listing_version_id=sub.store_listing_version_id,
|
||||
version=sub.version,
|
||||
review_comments=sub.review_comments,
|
||||
reviewed_at=sub.reviewed_at,
|
||||
video_url=sub.video_url,
|
||||
categories=sub.categories,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints - Read (authenticated)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@marketplace_router.get(
|
||||
path="/agents",
|
||||
summary="List marketplace agents",
|
||||
response_model=MarketplaceAgentsResponse,
|
||||
)
|
||||
async def list_agents(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_STORE)
|
||||
),
|
||||
featured: bool = Query(default=False, description="Filter to featured agents only"),
|
||||
creator: Optional[str] = Query(
|
||||
default=None, description="Filter by creator username"
|
||||
),
|
||||
sorted_by: Optional[Literal["rating", "runs", "name", "updated_at"]] = Query(
|
||||
default=None, description="Sort field"
|
||||
),
|
||||
search_query: Optional[str] = Query(default=None, description="Search query"),
|
||||
category: Optional[str] = Query(default=None, description="Filter by category"),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> MarketplaceAgentsResponse:
|
||||
"""
|
||||
List agents available in the marketplace.
|
||||
|
||||
Supports filtering by featured status, creator, category, and search query.
|
||||
Results can be sorted by rating, runs, name, or update time.
|
||||
"""
|
||||
result = await store_cache._get_cached_store_agents(
|
||||
featured=featured,
|
||||
creator=creator,
|
||||
sorted_by=sorted_by,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return MarketplaceAgentsResponse(
|
||||
agents=[_convert_store_agent(a) for a in result.agents],
|
||||
total_count=result.pagination.total_items,
|
||||
page=result.pagination.current_page,
|
||||
page_size=result.pagination.page_size,
|
||||
total_pages=result.pagination.total_pages,
|
||||
)
|
||||
|
||||
|
||||
@marketplace_router.get(
|
||||
path="/agents/{username}/{agent_name}",
|
||||
summary="Get agent details",
|
||||
response_model=MarketplaceAgentDetails,
|
||||
)
|
||||
async def get_agent_details(
|
||||
username: str = Path(description="Creator username"),
|
||||
agent_name: str = Path(description="Agent slug/name"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_STORE)
|
||||
),
|
||||
) -> MarketplaceAgentDetails:
|
||||
"""
|
||||
Get detailed information about a specific marketplace agent.
|
||||
"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
agent_name = urllib.parse.unquote(agent_name).lower()
|
||||
|
||||
agent = await store_cache._get_cached_agent_details(
|
||||
username=username, agent_name=agent_name
|
||||
)
|
||||
|
||||
return _convert_store_agent_details(agent)
|
||||
|
||||
|
||||
@marketplace_router.get(
|
||||
path="/creators",
|
||||
summary="List marketplace creators",
|
||||
response_model=MarketplaceCreatorsResponse,
|
||||
)
|
||||
async def list_creators(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_STORE)
|
||||
),
|
||||
featured: bool = Query(
|
||||
default=False, description="Filter to featured creators only"
|
||||
),
|
||||
search_query: Optional[str] = Query(default=None, description="Search query"),
|
||||
sorted_by: Optional[Literal["agent_rating", "agent_runs", "num_agents"]] = Query(
|
||||
default=None, description="Sort field"
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> MarketplaceCreatorsResponse:
|
||||
"""
|
||||
List creators on the marketplace.
|
||||
|
||||
Supports filtering by featured status and search query.
|
||||
Results can be sorted by rating, runs, or number of agents.
|
||||
"""
|
||||
result = await store_cache._get_cached_store_creators(
|
||||
featured=featured,
|
||||
search_query=search_query,
|
||||
sorted_by=sorted_by,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return MarketplaceCreatorsResponse(
|
||||
creators=[_convert_creator(c) for c in result.creators],
|
||||
total_count=result.pagination.total_items,
|
||||
page=result.pagination.current_page,
|
||||
page_size=result.pagination.page_size,
|
||||
total_pages=result.pagination.total_pages,
|
||||
)
|
||||
|
||||
|
||||
@marketplace_router.get(
|
||||
path="/creators/{username}",
|
||||
summary="Get creator details",
|
||||
response_model=MarketplaceCreatorDetails,
|
||||
)
|
||||
async def get_creator_details(
|
||||
username: str = Path(description="Creator username"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_STORE)
|
||||
),
|
||||
) -> MarketplaceCreatorDetails:
|
||||
"""
|
||||
Get detailed information about a specific marketplace creator.
|
||||
"""
|
||||
username = urllib.parse.unquote(username).lower()
|
||||
|
||||
creator = await store_cache._get_cached_creator_details(username=username)
|
||||
|
||||
return _convert_creator_details(creator)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints - Submissions (CRUD)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@marketplace_router.get(
|
||||
path="/submissions",
|
||||
summary="List my submissions",
|
||||
response_model=SubmissionsListResponse,
|
||||
)
|
||||
async def list_submissions(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_STORE)
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> SubmissionsListResponse:
|
||||
"""
|
||||
List your marketplace submissions.
|
||||
|
||||
Returns all submissions you've created, including drafts, pending,
|
||||
approved, and rejected submissions.
|
||||
"""
|
||||
result = await store_db.get_store_submissions(
|
||||
user_id=auth.user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return SubmissionsListResponse(
|
||||
submissions=[_convert_submission(s) for s in result.submissions],
|
||||
total_count=result.pagination.total_items,
|
||||
page=result.pagination.current_page,
|
||||
page_size=result.pagination.page_size,
|
||||
total_pages=result.pagination.total_pages,
|
||||
)
|
||||
|
||||
|
||||
@marketplace_router.post(
|
||||
path="/submissions",
|
||||
summary="Create a submission",
|
||||
response_model=MarketplaceSubmission,
|
||||
)
|
||||
async def create_submission(
|
||||
request: CreateSubmissionRequest,
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_STORE)
|
||||
),
|
||||
) -> MarketplaceSubmission:
|
||||
"""
|
||||
Create a new marketplace submission.
|
||||
|
||||
This submits an agent for review to be published in the marketplace.
|
||||
The submission will be in PENDING status until reviewed by the team.
|
||||
"""
|
||||
submission = await store_db.create_store_submission(
|
||||
user_id=auth.user_id,
|
||||
agent_id=request.graph_id,
|
||||
agent_version=request.graph_version,
|
||||
slug=request.slug,
|
||||
name=request.name,
|
||||
sub_heading=request.sub_heading,
|
||||
description=request.description,
|
||||
image_urls=request.image_urls,
|
||||
video_url=request.video_url,
|
||||
categories=request.categories,
|
||||
)
|
||||
|
||||
return _convert_submission(submission)
|
||||
|
||||
|
||||
@marketplace_router.delete(
|
||||
path="/submissions/{submission_id}",
|
||||
summary="Delete a submission",
|
||||
)
|
||||
async def delete_submission(
|
||||
submission_id: str = Path(description="Submission ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_STORE)
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Delete a marketplace submission.
|
||||
|
||||
Only submissions in DRAFT status can be deleted.
|
||||
"""
|
||||
success = await store_db.delete_store_submission(
|
||||
user_id=auth.user_id,
|
||||
submission_id=submission_id,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Submission #{submission_id} not found"
|
||||
)
|
||||
@@ -1,552 +0,0 @@
|
||||
"""
|
||||
V2 External API - Request and Response Models
|
||||
|
||||
This module defines all request and response models for the v2 external API.
|
||||
All models are self-contained and specific to the external API contract.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ============================================================================
|
||||
# Common/Shared Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
|
||||
"""Base class for paginated responses."""
|
||||
|
||||
total_count: int = Field(description="Total number of items across all pages")
|
||||
page: int = Field(description="Current page number (1-indexed)")
|
||||
page_size: int = Field(description="Number of items per page")
|
||||
total_pages: int = Field(description="Total number of pages")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Graph Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class GraphLink(BaseModel):
|
||||
"""A link between two nodes in a graph."""
|
||||
|
||||
id: str
|
||||
source_id: str = Field(description="ID of the source node")
|
||||
sink_id: str = Field(description="ID of the target node")
|
||||
source_name: str = Field(description="Output pin name on source node")
|
||||
sink_name: str = Field(description="Input pin name on target node")
|
||||
is_static: bool = Field(
|
||||
default=False, description="Whether this link provides static data"
|
||||
)
|
||||
|
||||
|
||||
class GraphNode(BaseModel):
|
||||
"""A node in an agent graph."""
|
||||
|
||||
id: str
|
||||
block_id: str = Field(description="ID of the block type")
|
||||
input_default: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Default input values"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Node metadata (e.g., position)"
|
||||
)
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
"""Graph definition for creating or updating an agent."""
|
||||
|
||||
id: Optional[str] = Field(default=None, description="Graph ID (assigned by server)")
|
||||
version: int = Field(default=1, description="Graph version")
|
||||
is_active: bool = Field(default=True, description="Whether this version is active")
|
||||
name: str = Field(description="Graph name")
|
||||
description: str = Field(default="", description="Graph description")
|
||||
nodes: list[GraphNode] = Field(default_factory=list, description="List of nodes")
|
||||
links: list[GraphLink] = Field(
|
||||
default_factory=list, description="Links between nodes"
|
||||
)
|
||||
|
||||
|
||||
class GraphMeta(BaseModel):
|
||||
"""Graph metadata (summary information)."""
|
||||
|
||||
id: str
|
||||
version: int
|
||||
is_active: bool
|
||||
name: str
|
||||
description: str
|
||||
created_at: datetime
|
||||
input_schema: dict[str, Any] = Field(description="Input schema for the graph")
|
||||
output_schema: dict[str, Any] = Field(description="Output schema for the graph")
|
||||
|
||||
|
||||
class GraphDetails(GraphMeta):
|
||||
"""Full graph details including nodes and links."""
|
||||
|
||||
nodes: list[GraphNode]
|
||||
links: list[GraphLink]
|
||||
credentials_input_schema: dict[str, Any] = Field(
|
||||
description="Schema for required credentials"
|
||||
)
|
||||
|
||||
|
||||
class GraphSettings(BaseModel):
|
||||
"""Settings for a graph."""
|
||||
|
||||
human_in_the_loop_safe_mode: Optional[bool] = Field(
|
||||
default=None, description="Enable safe mode for human-in-the-loop blocks"
|
||||
)
|
||||
|
||||
|
||||
class CreateGraphRequest(BaseModel):
|
||||
"""Request to create a new graph."""
|
||||
|
||||
graph: Graph = Field(description="The graph definition")
|
||||
|
||||
|
||||
class SetActiveVersionRequest(BaseModel):
|
||||
"""Request to set the active graph version."""
|
||||
|
||||
active_graph_version: int = Field(description="Version number to set as active")
|
||||
|
||||
|
||||
class GraphsListResponse(PaginatedResponse):
|
||||
"""Response for listing graphs."""
|
||||
|
||||
graphs: list[GraphMeta]
|
||||
|
||||
|
||||
class DeleteGraphResponse(BaseModel):
|
||||
"""Response for deleting a graph."""
|
||||
|
||||
version_count: int = Field(description="Number of versions deleted")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Schedule Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Schedule(BaseModel):
|
||||
"""An execution schedule for a graph."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
cron: str = Field(description="Cron expression for the schedule")
|
||||
input_data: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Input data for scheduled executions"
|
||||
)
|
||||
is_enabled: bool = Field(default=True, description="Whether schedule is enabled")
|
||||
next_run_time: Optional[datetime] = Field(
|
||||
default=None, description="Next scheduled run time"
|
||||
)
|
||||
|
||||
|
||||
class CreateScheduleRequest(BaseModel):
|
||||
"""Request to create a schedule."""
|
||||
|
||||
name: str = Field(description="Display name for the schedule")
|
||||
cron: str = Field(description="Cron expression (e.g., '0 9 * * *' for 9am daily)")
|
||||
input_data: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Input data for scheduled executions"
|
||||
)
|
||||
credentials_inputs: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Credentials for the schedule"
|
||||
)
|
||||
graph_version: Optional[int] = Field(
|
||||
default=None, description="Graph version (default: active version)"
|
||||
)
|
||||
timezone: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Timezone for schedule (e.g., 'America/New_York')",
|
||||
)
|
||||
|
||||
|
||||
class SchedulesListResponse(PaginatedResponse):
|
||||
"""Response for listing schedules."""
|
||||
|
||||
schedules: list[Schedule]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Block Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
"""Cost information for a block."""
|
||||
|
||||
cost_type: str = Field(description="Type of cost (e.g., 'per_call', 'per_token')")
|
||||
cost_filter: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Conditions for this cost"
|
||||
)
|
||||
cost_amount: int = Field(description="Cost amount in credits")
|
||||
|
||||
|
||||
class Block(BaseModel):
|
||||
"""A building block that can be used in graphs."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
costs: list[BlockCost] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BlocksListResponse(BaseModel):
|
||||
"""Response for listing blocks."""
|
||||
|
||||
blocks: list[Block]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Marketplace Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class MarketplaceAgent(BaseModel):
|
||||
"""An agent available in the marketplace."""
|
||||
|
||||
slug: str
|
||||
agent_name: str
|
||||
agent_image: str
|
||||
creator: str
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
runs: int = Field(default=0, description="Number of times this agent has been run")
|
||||
rating: float = Field(default=0.0, description="Average rating")
|
||||
|
||||
|
||||
class MarketplaceAgentDetails(BaseModel):
|
||||
"""Detailed information about a marketplace agent."""
|
||||
|
||||
store_listing_version_id: str
|
||||
slug: str
|
||||
agent_name: str
|
||||
agent_video: str
|
||||
agent_output_demo: str
|
||||
agent_image: list[str]
|
||||
creator: str
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
instructions: Optional[str] = None
|
||||
categories: list[str]
|
||||
runs: int
|
||||
rating: float
|
||||
versions: list[str]
|
||||
agent_graph_versions: list[str]
|
||||
agent_graph_id: str
|
||||
last_updated: datetime
|
||||
recommended_schedule_cron: Optional[str] = None
|
||||
|
||||
|
||||
class MarketplaceCreator(BaseModel):
|
||||
"""A creator on the marketplace."""
|
||||
|
||||
name: str
|
||||
username: str
|
||||
description: str
|
||||
avatar_url: str
|
||||
num_agents: int
|
||||
agent_rating: float
|
||||
agent_runs: int
|
||||
is_featured: bool = False
|
||||
|
||||
|
||||
class MarketplaceAgentsResponse(PaginatedResponse):
|
||||
"""Response for listing marketplace agents."""
|
||||
|
||||
agents: list[MarketplaceAgent]
|
||||
|
||||
|
||||
class MarketplaceCreatorsResponse(PaginatedResponse):
|
||||
"""Response for listing marketplace creators."""
|
||||
|
||||
creators: list[MarketplaceCreator]
|
||||
|
||||
|
||||
# Submission models
|
||||
class MarketplaceSubmission(BaseModel):
|
||||
"""A marketplace submission."""
|
||||
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
name: str
|
||||
sub_heading: str
|
||||
slug: str
|
||||
description: str
|
||||
instructions: Optional[str] = None
|
||||
image_urls: list[str] = Field(default_factory=list)
|
||||
date_submitted: datetime
|
||||
status: str = Field(description="One of: DRAFT, PENDING, APPROVED, REJECTED")
|
||||
runs: int
|
||||
rating: float
|
||||
store_listing_version_id: Optional[str] = None
|
||||
version: Optional[int] = None
|
||||
|
||||
# Review fields
|
||||
review_comments: Optional[str] = None
|
||||
reviewed_at: Optional[datetime] = None
|
||||
|
||||
# Additional optional fields
|
||||
video_url: Optional[str] = None
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CreateSubmissionRequest(BaseModel):
|
||||
"""Request to create a marketplace submission."""
|
||||
|
||||
agent_id: str = Field(description="ID of the graph to submit")
|
||||
agent_version: int = Field(description="Version of the graph to submit")
|
||||
name: str = Field(description="Display name for the agent")
|
||||
slug: str = Field(description="URL-friendly identifier")
|
||||
description: str = Field(description="Full description")
|
||||
sub_heading: str = Field(description="Short tagline")
|
||||
image_urls: list[str] = Field(default_factory=list)
|
||||
video_url: Optional[str] = None
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateSubmissionRequest(BaseModel):
|
||||
"""Request to update a marketplace submission."""
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
sub_heading: Optional[str] = None
|
||||
image_urls: Optional[list[str]] = None
|
||||
video_url: Optional[str] = None
|
||||
categories: Optional[list[str]] = None
|
||||
|
||||
|
||||
class SubmissionsListResponse(PaginatedResponse):
|
||||
"""Response for listing submissions."""
|
||||
|
||||
submissions: list[MarketplaceSubmission]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Library Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class LibraryAgent(BaseModel):
|
||||
"""An agent in the user's library."""
|
||||
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
name: str
|
||||
description: str
|
||||
is_favorite: bool = False
|
||||
can_access_graph: bool = False
|
||||
is_latest_version: bool = False
|
||||
image_url: Optional[str] = None
|
||||
creator_name: str
|
||||
input_schema: dict[str, Any] = Field(description="Input schema for the agent")
|
||||
output_schema: dict[str, Any] = Field(description="Output schema for the agent")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class LibraryAgentsResponse(PaginatedResponse):
|
||||
"""Response for listing library agents."""
|
||||
|
||||
agents: list[LibraryAgent]
|
||||
|
||||
|
||||
class ExecuteAgentRequest(BaseModel):
|
||||
"""Request to execute an agent."""
|
||||
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Input values for the agent"
|
||||
)
|
||||
credentials_inputs: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Credentials for the agent"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Run Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Run(BaseModel):
|
||||
"""An execution run."""
|
||||
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
status: str = Field(
|
||||
description="One of: INCOMPLETE, QUEUED, RUNNING, COMPLETED, TERMINATED, FAILED, REVIEW"
|
||||
)
|
||||
started_at: datetime
|
||||
ended_at: Optional[datetime] = None
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
cost: int = Field(default=0, description="Cost in credits")
|
||||
duration: float = Field(default=0, description="Duration in seconds")
|
||||
node_count: int = Field(default=0, description="Number of nodes executed")
|
||||
|
||||
|
||||
class RunDetails(Run):
|
||||
"""Detailed information about a run including node executions."""
|
||||
|
||||
outputs: Optional[dict[str, list[Any]]] = None
|
||||
node_executions: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Individual node execution results"
|
||||
)
|
||||
|
||||
|
||||
class RunsListResponse(PaginatedResponse):
|
||||
"""Response for listing runs."""
|
||||
|
||||
runs: list[Run]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Run Review Models (Human-in-the-loop)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class PendingReview(BaseModel):
|
||||
"""A pending human-in-the-loop review."""
|
||||
|
||||
id: str # node_exec_id
|
||||
run_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
payload: Any = Field(description="Data to be reviewed")
|
||||
instructions: Optional[str] = Field(
|
||||
default=None, description="Instructions for the reviewer"
|
||||
)
|
||||
editable: bool = Field(
|
||||
default=True, description="Whether the reviewer can edit the data"
|
||||
)
|
||||
status: str = Field(description="One of: WAITING, APPROVED, REJECTED")
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class PendingReviewsResponse(PaginatedResponse):
|
||||
"""Response for listing pending reviews."""
|
||||
|
||||
reviews: list[PendingReview]
|
||||
|
||||
|
||||
class ReviewDecision(BaseModel):
|
||||
"""Decision for a single review item."""
|
||||
|
||||
node_exec_id: str = Field(description="Node execution ID (review ID)")
|
||||
approved: bool = Field(description="Whether to approve the data")
|
||||
edited_payload: Optional[Any] = Field(
|
||||
default=None, description="Modified payload data (if editing)"
|
||||
)
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message from reviewer", max_length=2000
|
||||
)
|
||||
|
||||
|
||||
class SubmitReviewsRequest(BaseModel):
|
||||
"""Request to submit review responses for all pending reviews of an execution."""
|
||||
|
||||
reviews: list[ReviewDecision] = Field(
|
||||
description="All review decisions for the execution"
|
||||
)
|
||||
|
||||
|
||||
class SubmitReviewsResponse(BaseModel):
|
||||
"""Response after submitting reviews."""
|
||||
|
||||
run_id: str
|
||||
approved_count: int = Field(description="Number of reviews approved")
|
||||
rejected_count: int = Field(description="Number of reviews rejected")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Credit Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CreditBalance(BaseModel):
|
||||
"""User's credit balance."""
|
||||
|
||||
balance: int = Field(description="Current credit balance")
|
||||
|
||||
|
||||
class CreditTransaction(BaseModel):
|
||||
"""A credit transaction."""
|
||||
|
||||
transaction_key: str
|
||||
amount: int
|
||||
transaction_type: str = Field(description="Transaction type")
|
||||
transaction_time: datetime
|
||||
running_balance: int
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class CreditTransactionsResponse(PaginatedResponse):
|
||||
"""Response for listing credit transactions."""
|
||||
|
||||
transactions: list[CreditTransaction]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Credential(BaseModel):
|
||||
"""A user's credential for an integration."""
|
||||
|
||||
id: str
|
||||
provider: str = Field(description="Integration provider name")
|
||||
title: Optional[str] = Field(
|
||||
default=None, description="User-assigned title for this credential"
|
||||
)
|
||||
scopes: list[str] = Field(default_factory=list, description="Granted scopes")
|
||||
|
||||
|
||||
class CredentialsListResponse(BaseModel):
|
||||
"""Response for listing credentials."""
|
||||
|
||||
credentials: list[Credential]
|
||||
|
||||
|
||||
class CredentialRequirement(BaseModel):
|
||||
"""A credential requirement for a graph or agent."""
|
||||
|
||||
provider: str = Field(description="Required provider name")
|
||||
required_scopes: list[str] = Field(
|
||||
default_factory=list, description="Required scopes"
|
||||
)
|
||||
matching_credentials: list[Credential] = Field(
|
||||
default_factory=list,
|
||||
description="User's credentials that match this requirement",
|
||||
)
|
||||
|
||||
|
||||
class CredentialRequirementsResponse(BaseModel):
|
||||
"""Response for listing credential requirements."""
|
||||
|
||||
requirements: list[CredentialRequirement]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# File Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class UploadFileResponse(BaseModel):
|
||||
"""Response after uploading a file."""
|
||||
|
||||
file_uri: str = Field(description="URI to reference the uploaded file")
|
||||
file_name: str
|
||||
size: int = Field(description="File size in bytes")
|
||||
content_type: str
|
||||
expires_in_hours: int
|
||||
@@ -1,35 +0,0 @@
|
||||
"""
|
||||
V2 External API Routes
|
||||
|
||||
This module defines the main v2 router that aggregates all v2 API endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .blocks import blocks_router
|
||||
from .credits import credits_router
|
||||
from .files import files_router
|
||||
from .graphs import graphs_router
|
||||
from .integrations import integrations_router
|
||||
from .library import library_router
|
||||
from .marketplace import marketplace_router
|
||||
from .runs import runs_router
|
||||
from .schedules import graph_schedules_router, schedules_router
|
||||
|
||||
v2_router = APIRouter()
|
||||
|
||||
# Include all sub-routers
|
||||
v2_router.include_router(graphs_router, prefix="/graphs", tags=["graphs"])
|
||||
v2_router.include_router(graph_schedules_router, prefix="/graphs", tags=["schedules"])
|
||||
v2_router.include_router(schedules_router, prefix="/schedules", tags=["schedules"])
|
||||
v2_router.include_router(blocks_router, prefix="/blocks", tags=["blocks"])
|
||||
v2_router.include_router(
|
||||
marketplace_router, prefix="/marketplace", tags=["marketplace"]
|
||||
)
|
||||
v2_router.include_router(library_router, prefix="/library", tags=["library"])
|
||||
v2_router.include_router(runs_router, prefix="/runs", tags=["runs"])
|
||||
v2_router.include_router(credits_router, prefix="/credits", tags=["credits"])
|
||||
v2_router.include_router(
|
||||
integrations_router, prefix="/integrations", tags=["integrations"]
|
||||
)
|
||||
v2_router.include_router(files_router, prefix="/files", tags=["files"])
|
||||
@@ -1,451 +0,0 @@
|
||||
"""
|
||||
V2 External API - Runs Endpoints
|
||||
|
||||
Provides access to execution runs and human-in-the-loop reviews.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Path, Query, Security
|
||||
from prisma.enums import APIKeyPermission, ReviewStatus
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.api.features.executions.review.model import (
|
||||
PendingHumanReviewModel,
|
||||
SafeJsonData,
|
||||
)
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import human_review as review_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.executor import utils as execution_utils
|
||||
|
||||
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
runs_router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Run(BaseModel):
|
||||
"""An execution run."""
|
||||
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
status: str = Field(
|
||||
description="One of: INCOMPLETE, QUEUED, RUNNING, COMPLETED, TERMINATED, FAILED, REVIEW"
|
||||
)
|
||||
started_at: datetime
|
||||
ended_at: Optional[datetime] = None
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
cost: int = Field(default=0, description="Cost in credits")
|
||||
duration: float = Field(default=0, description="Duration in seconds")
|
||||
node_count: int = Field(default=0, description="Number of nodes executed")
|
||||
|
||||
|
||||
class RunDetails(Run):
|
||||
"""Detailed information about a run including outputs and node executions."""
|
||||
|
||||
outputs: Optional[dict[str, list[Any]]] = None
|
||||
node_executions: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Individual node execution results"
|
||||
)
|
||||
|
||||
|
||||
class RunsListResponse(BaseModel):
|
||||
"""Response for listing runs."""
|
||||
|
||||
runs: list[Run]
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
class PendingReview(BaseModel):
|
||||
"""A pending human-in-the-loop review."""
|
||||
|
||||
id: str # node_exec_id
|
||||
run_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
payload: SafeJsonData = Field(description="Data to be reviewed")
|
||||
instructions: Optional[str] = Field(
|
||||
default=None, description="Instructions for the reviewer"
|
||||
)
|
||||
editable: bool = Field(
|
||||
default=True, description="Whether the reviewer can edit the data"
|
||||
)
|
||||
status: str = Field(description="One of: WAITING, APPROVED, REJECTED")
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class PendingReviewsResponse(BaseModel):
|
||||
"""Response for listing pending reviews."""
|
||||
|
||||
reviews: list[PendingReview]
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
class ReviewDecision(BaseModel):
|
||||
"""Decision for a single review item."""
|
||||
|
||||
node_exec_id: str = Field(description="Node execution ID (review ID)")
|
||||
approved: bool = Field(description="Whether to approve the data")
|
||||
edited_payload: Optional[SafeJsonData] = Field(
|
||||
default=None, description="Modified payload data (if editing)"
|
||||
)
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message from reviewer", max_length=2000
|
||||
)
|
||||
|
||||
|
||||
class SubmitReviewsRequest(BaseModel):
|
||||
"""Request to submit review responses for all pending reviews of an execution."""
|
||||
|
||||
reviews: list[ReviewDecision] = Field(
|
||||
description="All review decisions for the execution"
|
||||
)
|
||||
|
||||
|
||||
class SubmitReviewsResponse(BaseModel):
|
||||
"""Response after submitting reviews."""
|
||||
|
||||
run_id: str
|
||||
approved_count: int = Field(description="Number of reviews approved")
|
||||
rejected_count: int = Field(description="Number of reviews rejected")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Conversion Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _convert_execution_to_run(exec: execution_db.GraphExecutionMeta) -> Run:
|
||||
"""Convert internal execution to v2 API Run model."""
|
||||
return Run(
|
||||
id=exec.id,
|
||||
graph_id=exec.graph_id,
|
||||
graph_version=exec.graph_version,
|
||||
status=exec.status.value,
|
||||
started_at=exec.started_at,
|
||||
ended_at=exec.ended_at,
|
||||
inputs=exec.inputs,
|
||||
cost=exec.stats.cost if exec.stats else 0,
|
||||
duration=exec.stats.duration if exec.stats else 0,
|
||||
node_count=exec.stats.node_exec_count if exec.stats else 0,
|
||||
)
|
||||
|
||||
|
||||
def _convert_execution_to_run_details(
|
||||
exec: execution_db.GraphExecutionWithNodes,
|
||||
) -> RunDetails:
|
||||
"""Convert internal execution with nodes to v2 API RunDetails model."""
|
||||
return RunDetails(
|
||||
id=exec.id,
|
||||
graph_id=exec.graph_id,
|
||||
graph_version=exec.graph_version,
|
||||
status=exec.status.value,
|
||||
started_at=exec.started_at,
|
||||
ended_at=exec.ended_at,
|
||||
inputs=exec.inputs,
|
||||
outputs=exec.outputs,
|
||||
cost=exec.stats.cost if exec.stats else 0,
|
||||
duration=exec.stats.duration if exec.stats else 0,
|
||||
node_count=exec.stats.node_exec_count if exec.stats else 0,
|
||||
node_executions=[
|
||||
{
|
||||
"node_id": node.node_id,
|
||||
"status": node.status.value,
|
||||
"input_data": node.input_data,
|
||||
"output_data": node.output_data,
|
||||
"started_at": node.start_time,
|
||||
"ended_at": node.end_time,
|
||||
}
|
||||
for node in exec.node_executions
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _convert_pending_review(review: PendingHumanReviewModel) -> PendingReview:
|
||||
"""Convert internal PendingHumanReviewModel to v2 API PendingReview model."""
|
||||
return PendingReview(
|
||||
id=review.node_exec_id,
|
||||
run_id=review.graph_exec_id,
|
||||
graph_id=review.graph_id,
|
||||
graph_version=review.graph_version,
|
||||
payload=review.payload,
|
||||
instructions=review.instructions,
|
||||
editable=review.editable,
|
||||
status=review.status.value,
|
||||
created_at=review.created_at,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints - Runs
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@runs_router.get(
|
||||
path="",
|
||||
summary="List all runs",
|
||||
response_model=RunsListResponse,
|
||||
)
|
||||
async def list_runs(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_RUN)
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> RunsListResponse:
|
||||
"""
|
||||
List all execution runs for the authenticated user.
|
||||
|
||||
Returns runs across all agents, sorted by most recent first.
|
||||
"""
|
||||
result = await execution_db.get_graph_executions_paginated(
|
||||
user_id=auth.user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return RunsListResponse(
|
||||
runs=[_convert_execution_to_run(e) for e in result.executions],
|
||||
total_count=result.pagination.total_items,
|
||||
page=result.pagination.current_page,
|
||||
page_size=result.pagination.page_size,
|
||||
total_pages=result.pagination.total_pages,
|
||||
)
|
||||
|
||||
|
||||
@runs_router.get(
|
||||
path="/{run_id}",
|
||||
summary="Get run details",
|
||||
response_model=RunDetails,
|
||||
)
|
||||
async def get_run(
|
||||
run_id: str = Path(description="Run ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_RUN)
|
||||
),
|
||||
) -> RunDetails:
|
||||
"""
|
||||
Get detailed information about a specific run.
|
||||
|
||||
Includes outputs and individual node execution results.
|
||||
"""
|
||||
result = await execution_db.get_graph_execution(
|
||||
user_id=auth.user_id,
|
||||
execution_id=run_id,
|
||||
include_node_executions=True,
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"Run #{run_id} not found")
|
||||
|
||||
return _convert_execution_to_run_details(result)
|
||||
|
||||
|
||||
@runs_router.post(
|
||||
path="/{run_id}/stop",
|
||||
summary="Stop a run",
|
||||
)
|
||||
async def stop_run(
|
||||
run_id: str = Path(description="Run ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_RUN)
|
||||
),
|
||||
) -> Run:
|
||||
"""
|
||||
Stop a running execution.
|
||||
|
||||
Only runs in QUEUED or RUNNING status can be stopped.
|
||||
"""
|
||||
# Verify the run exists and belongs to the user
|
||||
exec = await execution_db.get_graph_execution(
|
||||
user_id=auth.user_id,
|
||||
execution_id=run_id,
|
||||
)
|
||||
if not exec:
|
||||
raise HTTPException(status_code=404, detail=f"Run #{run_id} not found")
|
||||
|
||||
# Stop the execution
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec_id=run_id,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
|
||||
# Fetch updated execution
|
||||
updated_exec = await execution_db.get_graph_execution(
|
||||
user_id=auth.user_id,
|
||||
execution_id=run_id,
|
||||
)
|
||||
|
||||
if not updated_exec:
|
||||
raise HTTPException(status_code=404, detail=f"Run #{run_id} not found")
|
||||
|
||||
return _convert_execution_to_run(updated_exec)
|
||||
|
||||
|
||||
@runs_router.delete(
|
||||
path="/{run_id}",
|
||||
summary="Delete a run",
|
||||
)
|
||||
async def delete_run(
|
||||
run_id: str = Path(description="Run ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_RUN)
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Delete an execution run.
|
||||
|
||||
This marks the run as deleted. The data may still be retained for
|
||||
some time for recovery purposes.
|
||||
"""
|
||||
await execution_db.delete_graph_execution(
|
||||
graph_exec_id=run_id,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints - Reviews (Human-in-the-loop)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@runs_router.get(
|
||||
path="/reviews",
|
||||
summary="List all pending reviews",
|
||||
response_model=PendingReviewsResponse,
|
||||
)
|
||||
async def list_pending_reviews(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_RUN_REVIEW)
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> PendingReviewsResponse:
|
||||
"""
|
||||
List all pending human-in-the-loop reviews.
|
||||
|
||||
These are blocks that require human approval or input before the
|
||||
agent can continue execution.
|
||||
"""
|
||||
reviews = await review_db.get_pending_reviews_for_user(
|
||||
user_id=auth.user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
# Note: get_pending_reviews_for_user returns list directly, not a paginated result
|
||||
# We compute pagination info based on results
|
||||
total_count = len(reviews)
|
||||
total_pages = max(1, (total_count + page_size - 1) // page_size)
|
||||
|
||||
return PendingReviewsResponse(
|
||||
reviews=[_convert_pending_review(r) for r in reviews],
|
||||
total_count=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@runs_router.get(
|
||||
path="/{run_id}/reviews",
|
||||
summary="List reviews for a run",
|
||||
response_model=list[PendingReview],
|
||||
)
|
||||
async def list_run_reviews(
|
||||
run_id: str = Path(description="Run ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_RUN_REVIEW)
|
||||
),
|
||||
) -> list[PendingReview]:
|
||||
"""
|
||||
List all human-in-the-loop reviews for a specific run.
|
||||
"""
|
||||
reviews = await review_db.get_pending_reviews_for_execution(
|
||||
graph_exec_id=run_id,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
|
||||
return [_convert_pending_review(r) for r in reviews]
|
||||
|
||||
|
||||
@runs_router.post(
|
||||
path="/{run_id}/reviews",
|
||||
summary="Submit review responses for a run",
|
||||
response_model=SubmitReviewsResponse,
|
||||
)
|
||||
async def submit_reviews(
|
||||
request: SubmitReviewsRequest,
|
||||
run_id: str = Path(description="Run ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_RUN_REVIEW)
|
||||
),
|
||||
) -> SubmitReviewsResponse:
|
||||
"""
|
||||
Submit responses to all pending human-in-the-loop reviews for a run.
|
||||
|
||||
All pending reviews for the execution must be included. Approving
|
||||
a review will allow the agent to continue; rejecting will terminate
|
||||
execution at that point.
|
||||
"""
|
||||
# Build review decisions dict for process_all_reviews_for_execution
|
||||
review_decisions: dict[
|
||||
str, tuple[ReviewStatus, SafeJsonData | None, str | None]
|
||||
] = {}
|
||||
|
||||
for decision in request.reviews:
|
||||
status = ReviewStatus.APPROVED if decision.approved else ReviewStatus.REJECTED
|
||||
review_decisions[decision.node_exec_id] = (
|
||||
status,
|
||||
decision.edited_payload,
|
||||
decision.message,
|
||||
)
|
||||
|
||||
try:
|
||||
results = await review_db.process_all_reviews_for_execution(
|
||||
user_id=auth.user_id,
|
||||
review_decisions=review_decisions,
|
||||
)
|
||||
|
||||
approved_count = sum(
|
||||
1 for r in results.values() if r.status == ReviewStatus.APPROVED
|
||||
)
|
||||
rejected_count = sum(
|
||||
1 for r in results.values() if r.status == ReviewStatus.REJECTED
|
||||
)
|
||||
|
||||
return SubmitReviewsResponse(
|
||||
run_id=run_id,
|
||||
approved_count=approved_count,
|
||||
rejected_count=rejected_count,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -1,250 +0,0 @@
|
||||
"""
|
||||
V2 External API - Schedules Endpoints
|
||||
|
||||
Provides endpoints for managing execution schedules.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Path, Query, Security
|
||||
from prisma.enums import APIKeyPermission
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.api.external.middleware import require_permission
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.auth.base import APIAuthorizationInfo
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor import scheduler
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.timezone_utils import get_user_timezone_or_utc
|
||||
|
||||
from .common import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
schedules_router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Models
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Schedule(BaseModel):
|
||||
"""An execution schedule for a graph."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
cron: str = Field(description="Cron expression for the schedule")
|
||||
input_data: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Input data for scheduled executions"
|
||||
)
|
||||
next_run_time: Optional[datetime] = Field(
|
||||
default=None, description="Next scheduled run time"
|
||||
)
|
||||
is_enabled: bool = Field(default=True, description="Whether schedule is enabled")
|
||||
|
||||
|
||||
class SchedulesListResponse(BaseModel):
|
||||
"""Response for listing schedules."""
|
||||
|
||||
schedules: list[Schedule]
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
class CreateScheduleRequest(BaseModel):
|
||||
"""Request to create a schedule."""
|
||||
|
||||
name: str = Field(description="Display name for the schedule")
|
||||
cron: str = Field(description="Cron expression (e.g., '0 9 * * *' for 9am daily)")
|
||||
input_data: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Input data for scheduled executions"
|
||||
)
|
||||
credentials_inputs: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Credentials for the schedule"
|
||||
)
|
||||
graph_version: Optional[int] = Field(
|
||||
default=None, description="Graph version (default: active version)"
|
||||
)
|
||||
timezone: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Timezone for schedule (e.g., 'America/New_York'). "
|
||||
"Defaults to user's timezone."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _convert_schedule(job: scheduler.GraphExecutionJobInfo) -> Schedule:
|
||||
"""Convert internal schedule job info to v2 API model."""
|
||||
# Parse the ISO format string to datetime
|
||||
next_run = datetime.fromisoformat(job.next_run_time) if job.next_run_time else None
|
||||
|
||||
return Schedule(
|
||||
id=job.id,
|
||||
name=job.name or "",
|
||||
graph_id=job.graph_id,
|
||||
graph_version=job.graph_version,
|
||||
cron=job.cron,
|
||||
input_data=job.input_data,
|
||||
next_run_time=next_run,
|
||||
is_enabled=True, # All returned schedules are enabled
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@schedules_router.get(
|
||||
path="",
|
||||
summary="List all user schedules",
|
||||
response_model=SchedulesListResponse,
|
||||
)
|
||||
async def list_all_schedules(
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_SCHEDULE)
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
||||
page_size: int = Query(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
ge=1,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
) -> SchedulesListResponse:
|
||||
"""
|
||||
List all schedules for the authenticated user across all graphs.
|
||||
"""
|
||||
schedules = await get_scheduler_client().get_execution_schedules(
|
||||
user_id=auth.user_id
|
||||
)
|
||||
converted = [_convert_schedule(s) for s in schedules]
|
||||
|
||||
# Manual pagination (scheduler doesn't support pagination natively)
|
||||
total_count = len(converted)
|
||||
total_pages = (total_count + page_size - 1) // page_size if total_count > 0 else 1
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
paginated = converted[start:end]
|
||||
|
||||
return SchedulesListResponse(
|
||||
schedules=paginated,
|
||||
total_count=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@schedules_router.delete(
|
||||
path="/{schedule_id}",
|
||||
summary="Delete a schedule",
|
||||
)
|
||||
async def delete_schedule(
|
||||
schedule_id: str = Path(description="Schedule ID to delete"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_SCHEDULE)
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
Delete an execution schedule.
|
||||
"""
|
||||
try:
|
||||
await get_scheduler_client().delete_schedule(
|
||||
schedule_id=schedule_id,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
if "not found" in str(e).lower():
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Schedule #{schedule_id} not found"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Graph-specific Schedule Endpoints (nested under /graphs)
|
||||
# These are included in the graphs router via include_router
|
||||
# ============================================================================
|
||||
|
||||
graph_schedules_router = APIRouter()
|
||||
|
||||
|
||||
@graph_schedules_router.get(
|
||||
path="/{graph_id}/schedules",
|
||||
summary="List schedules for a graph",
|
||||
response_model=list[Schedule],
|
||||
)
|
||||
async def list_graph_schedules(
|
||||
graph_id: str = Path(description="Graph ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.READ_SCHEDULE)
|
||||
),
|
||||
) -> list[Schedule]:
|
||||
"""
|
||||
List all schedules for a specific graph.
|
||||
"""
|
||||
schedules = await get_scheduler_client().get_execution_schedules(
|
||||
user_id=auth.user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
return [_convert_schedule(s) for s in schedules]
|
||||
|
||||
|
||||
@graph_schedules_router.post(
|
||||
path="/{graph_id}/schedules",
|
||||
summary="Create a schedule for a graph",
|
||||
response_model=Schedule,
|
||||
)
|
||||
async def create_graph_schedule(
|
||||
request: CreateScheduleRequest,
|
||||
graph_id: str = Path(description="Graph ID"),
|
||||
auth: APIAuthorizationInfo = Security(
|
||||
require_permission(APIKeyPermission.WRITE_SCHEDULE)
|
||||
),
|
||||
) -> Schedule:
|
||||
"""
|
||||
Create a new execution schedule for a graph.
|
||||
|
||||
The schedule will execute the graph at times matching the cron expression,
|
||||
using the provided input data.
|
||||
"""
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=graph_id,
|
||||
version=request.graph_version,
|
||||
user_id=auth.user_id,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Graph #{graph_id} v{request.graph_version} not found.",
|
||||
)
|
||||
|
||||
# Determine timezone
|
||||
if request.timezone:
|
||||
user_timezone = request.timezone
|
||||
else:
|
||||
user = await get_user_by_id(auth.user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
result = await get_scheduler_client().add_execution_schedule(
|
||||
user_id=auth.user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
name=request.name,
|
||||
cron=request.cron,
|
||||
input_data=request.input_data,
|
||||
input_credentials=request.credentials_inputs,
|
||||
user_timezone=user_timezone,
|
||||
)
|
||||
|
||||
return _convert_schedule(result)
|
||||
@@ -28,6 +28,7 @@ from backend.executor.manager import get_db_async_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class ExecutionAnalyticsRequest(BaseModel):
|
||||
@@ -63,6 +64,8 @@ class ExecutionAnalyticsResult(BaseModel):
|
||||
score: Optional[float]
|
||||
status: str # "success", "failed", "skipped"
|
||||
error_message: Optional[str] = None
|
||||
started_at: Optional[datetime] = None
|
||||
ended_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class ExecutionAnalyticsResponse(BaseModel):
|
||||
@@ -224,11 +227,6 @@ async def generate_execution_analytics(
|
||||
)
|
||||
|
||||
try:
|
||||
# Validate model configuration
|
||||
settings = Settings()
|
||||
if not settings.secrets.openai_internal_api_key:
|
||||
raise HTTPException(status_code=500, detail="OpenAI API key not configured")
|
||||
|
||||
# Get database client
|
||||
db_client = get_db_async_client()
|
||||
|
||||
@@ -320,6 +318,8 @@ async def generate_execution_analytics(
|
||||
),
|
||||
status="skipped",
|
||||
error_message=None, # Not an error - just already processed
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -349,6 +349,9 @@ async def _process_batch(
|
||||
) -> list[ExecutionAnalyticsResult]:
|
||||
"""Process a batch of executions concurrently."""
|
||||
|
||||
if not settings.secrets.openai_internal_api_key:
|
||||
raise HTTPException(status_code=500, detail="OpenAI API key not configured")
|
||||
|
||||
async def process_single_execution(execution) -> ExecutionAnalyticsResult:
|
||||
try:
|
||||
# Generate activity status and score using the specified model
|
||||
@@ -387,6 +390,8 @@ async def _process_batch(
|
||||
score=None,
|
||||
status="skipped",
|
||||
error_message="Activity generation returned None",
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
)
|
||||
|
||||
# Update the execution stats
|
||||
@@ -416,6 +421,8 @@ async def _process_batch(
|
||||
summary_text=activity_response["activity_status"],
|
||||
score=activity_response["correctness_score"],
|
||||
status="success",
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -429,6 +436,8 @@ async def _process_batch(
|
||||
score=None,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
)
|
||||
|
||||
# Process all executions in the batch concurrently
|
||||
|
||||
@@ -4,14 +4,9 @@ from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from langfuse import Langfuse
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
APIStatusError,
|
||||
AsyncOpenAI,
|
||||
RateLimitError,
|
||||
)
|
||||
from langfuse import get_client, propagate_attributes
|
||||
from langfuse.openai import openai # type: ignore
|
||||
from openai import APIConnectionError, APIError, APIStatusError, RateLimitError
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
||||
|
||||
from backend.data.understanding import (
|
||||
@@ -21,7 +16,6 @@ from backend.data.understanding import (
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
@@ -50,10 +44,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
config = ChatConfig()
|
||||
settings = Settings()
|
||||
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
# Langfuse client (lazy initialization)
|
||||
_langfuse_client: Langfuse | None = None
|
||||
|
||||
langfuse = get_client()
|
||||
|
||||
|
||||
class LangfuseNotConfiguredError(Exception):
|
||||
@@ -69,65 +63,6 @@ def _is_langfuse_configured() -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _get_langfuse_client() -> Langfuse:
|
||||
"""Get or create the Langfuse client for prompt management and tracing."""
|
||||
global _langfuse_client
|
||||
if _langfuse_client is None:
|
||||
if not _is_langfuse_configured():
|
||||
raise LangfuseNotConfiguredError(
|
||||
"Langfuse is not configured. The chat feature requires Langfuse for prompt management. "
|
||||
"Please set the LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables."
|
||||
)
|
||||
_langfuse_client = Langfuse(
|
||||
public_key=settings.secrets.langfuse_public_key,
|
||||
secret_key=settings.secrets.langfuse_secret_key,
|
||||
host=settings.secrets.langfuse_host or "https://cloud.langfuse.com",
|
||||
)
|
||||
return _langfuse_client
|
||||
|
||||
|
||||
def _get_environment() -> str:
|
||||
"""Get the current environment name for Langfuse tagging."""
|
||||
return settings.config.app_env.value
|
||||
|
||||
|
||||
def _get_langfuse_prompt() -> str:
|
||||
"""Fetch the latest production prompt from Langfuse.
|
||||
|
||||
Returns:
|
||||
The compiled prompt text from Langfuse.
|
||||
|
||||
Raises:
|
||||
Exception: If Langfuse is unavailable or prompt fetch fails.
|
||||
"""
|
||||
try:
|
||||
langfuse = _get_langfuse_client()
|
||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||
prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
|
||||
compiled = prompt.compile()
|
||||
logger.info(
|
||||
f"Fetched prompt '{config.langfuse_prompt_name}' from Langfuse "
|
||||
f"(version: {prompt.version})"
|
||||
)
|
||||
return compiled
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch prompt from Langfuse: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _is_first_session(user_id: str) -> bool:
|
||||
"""Check if this is the user's first chat session.
|
||||
|
||||
Returns True if the user has 1 or fewer sessions (meaning this is their first).
|
||||
"""
|
||||
try:
|
||||
session_count = await chat_db.get_user_session_count(user_id)
|
||||
return session_count <= 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check session count for user {user_id}: {e}")
|
||||
return False # Default to non-onboarding if we can't check
|
||||
|
||||
|
||||
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
"""Build the full system prompt including business understanding if available.
|
||||
|
||||
@@ -139,8 +74,6 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
Tuple of (compiled prompt string, Langfuse prompt object for tracing)
|
||||
"""
|
||||
|
||||
langfuse = _get_langfuse_client()
|
||||
|
||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||
prompt = langfuse.get_prompt(config.langfuse_prompt_name, cache_ttl_seconds=0)
|
||||
|
||||
@@ -158,7 +91,7 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||
|
||||
compiled = prompt.compile(users_information=context)
|
||||
return compiled, prompt
|
||||
return compiled, understanding
|
||||
|
||||
|
||||
async def _generate_session_title(message: str) -> str | None:
|
||||
@@ -217,6 +150,7 @@ async def assign_user_to_session(
|
||||
async def stream_chat_completion(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0,
|
||||
@@ -256,11 +190,6 @@ async def stream_chat_completion(
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Langfuse observations will be created after session is loaded (need messages for input)
|
||||
# Initialize to None so finally block can safely check and end them
|
||||
trace = None
|
||||
generation = None
|
||||
|
||||
# Only fetch from Redis if session not provided (initial call)
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -299,9 +228,6 @@ async def stream_chat_completion(
|
||||
f"new message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
if len(session.messages) > config.max_context_messages:
|
||||
raise ValueError(f"Max messages exceeded: {config.max_context_messages}")
|
||||
|
||||
logger.info(
|
||||
f"Upserting session: {session.session_id} with user id {session.user_id}, "
|
||||
f"message_count={len(session.messages)}"
|
||||
@@ -339,297 +265,259 @@ async def stream_chat_completion(
|
||||
asyncio.create_task(_update_title())
|
||||
|
||||
# Build system prompt with business understanding
|
||||
system_prompt, langfuse_prompt = await _build_system_prompt(user_id)
|
||||
|
||||
# Build input messages including system prompt for complete Langfuse logging
|
||||
trace_input_messages = [{"role": "system", "content": system_prompt}] + [
|
||||
m.model_dump() for m in session.messages
|
||||
]
|
||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||
|
||||
# Create Langfuse trace for this LLM call (each call gets its own trace, grouped by session_id)
|
||||
# Using v3 SDK: start_observation creates a root span, update_trace sets trace-level attributes
|
||||
try:
|
||||
langfuse = _get_langfuse_client()
|
||||
env = _get_environment()
|
||||
trace = langfuse.start_observation(
|
||||
name="chat_completion",
|
||||
input={"messages": trace_input_messages},
|
||||
metadata={
|
||||
"environment": env,
|
||||
"model": config.model,
|
||||
"message_count": len(session.messages),
|
||||
"prompt_name": langfuse_prompt.name if langfuse_prompt else None,
|
||||
"prompt_version": langfuse_prompt.version if langfuse_prompt else None,
|
||||
},
|
||||
)
|
||||
# Set trace-level attributes (session_id, user_id, tags)
|
||||
trace.update_trace(
|
||||
input = message
|
||||
if not message and tool_call_response:
|
||||
input = tool_call_response
|
||||
|
||||
langfuse = get_client()
|
||||
with langfuse.start_as_current_observation(
|
||||
as_type="span",
|
||||
name="user-copilot-request",
|
||||
input=input,
|
||||
) as span:
|
||||
with propagate_attributes(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tags=[env, "copilot"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create Langfuse trace: {e}")
|
||||
tags=["copilot"],
|
||||
metadata={
|
||||
"users_information": format_understanding_for_prompt(understanding)[
|
||||
:200
|
||||
] # langfuse only accepts upto to 200 chars
|
||||
},
|
||||
):
|
||||
|
||||
# Initialize variables that will be used in finally block (must be defined before try)
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
)
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
|
||||
try:
|
||||
has_yielded_end = False
|
||||
has_yielded_error = False
|
||||
has_done_tool_call = False
|
||||
has_received_text = False
|
||||
text_streaming_ended = False
|
||||
tool_response_messages: list[ChatMessage] = []
|
||||
should_retry = False
|
||||
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
import uuid as uuid_module
|
||||
|
||||
message_id = str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
# Yield message start
|
||||
yield StreamStart(messageId=message_id)
|
||||
|
||||
# Create Langfuse generation for each LLM call, linked to the prompt
|
||||
# Using v3 SDK: start_observation with as_type="generation"
|
||||
generation = (
|
||||
trace.start_observation(
|
||||
as_type="generation",
|
||||
name="llm_call",
|
||||
model=config.model,
|
||||
input={"messages": trace_input_messages},
|
||||
prompt=langfuse_prompt,
|
||||
# Initialize variables that will be used in finally block (must be defined before try)
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
)
|
||||
if trace
|
||||
else None
|
||||
)
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
try:
|
||||
async for chunk in _stream_chat_chunks(
|
||||
session=session,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
text_block_id=text_block_id,
|
||||
):
|
||||
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
|
||||
has_yielded_end = False
|
||||
has_yielded_error = False
|
||||
has_done_tool_call = False
|
||||
has_received_text = False
|
||||
text_streaming_ended = False
|
||||
tool_response_messages: list[ChatMessage] = []
|
||||
should_retry = False
|
||||
|
||||
if isinstance(chunk, StreamTextStart):
|
||||
# Emit text-start before first text delta
|
||||
if not has_received_text:
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
import uuid as uuid_module
|
||||
|
||||
message_id = str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
# Yield message start
|
||||
yield StreamStart(messageId=message_id)
|
||||
|
||||
try:
|
||||
async for chunk in _stream_chat_chunks(
|
||||
session=session,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
text_block_id=text_block_id,
|
||||
):
|
||||
|
||||
if isinstance(chunk, StreamTextStart):
|
||||
# Emit text-start before first text delta
|
||||
if not has_received_text:
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextDelta):
|
||||
delta = chunk.delta or ""
|
||||
assert assistant_response.content is not None
|
||||
assistant_response.content += delta
|
||||
has_received_text = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextDelta):
|
||||
delta = chunk.delta or ""
|
||||
assert assistant_response.content is not None
|
||||
assistant_response.content += delta
|
||||
has_received_text = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamTextEnd):
|
||||
# Emit text-end after text completes
|
||||
if has_received_text and not text_streaming_ended:
|
||||
text_streaming_ended = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolInputStart):
|
||||
# Emit text-end before first tool call, but only if we've received text
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolInputAvailable):
|
||||
# Accumulate tool calls in OpenAI format
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": chunk.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.toolName,
|
||||
"arguments": orjson.dumps(chunk.input).decode("utf-8"),
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(chunk, StreamToolOutputAvailable):
|
||||
result_content = (
|
||||
chunk.output
|
||||
if isinstance(chunk.output, str)
|
||||
else orjson.dumps(chunk.output).decode("utf-8")
|
||||
)
|
||||
tool_response_messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=chunk.toolCallId,
|
||||
)
|
||||
)
|
||||
has_done_tool_call = True
|
||||
# Track if any tool execution failed
|
||||
if not chunk.success:
|
||||
logger.warning(
|
||||
f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed"
|
||||
)
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
if not has_done_tool_call:
|
||||
# Emit text-end before finish if we received text but haven't closed it
|
||||
elif isinstance(chunk, StreamTextEnd):
|
||||
# Emit text-end after text completes
|
||||
if has_received_text and not text_streaming_ended:
|
||||
text_streaming_ended = True
|
||||
if assistant_response.content:
|
||||
logger.warn(
|
||||
f"StreamTextEnd: Attempting to set output {assistant_response.content}"
|
||||
)
|
||||
span.update_trace(output=assistant_response.content)
|
||||
span.update(output=assistant_response.content)
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolInputStart):
|
||||
# Emit text-end before first tool call, but only if we've received text
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
has_yielded_end = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamError):
|
||||
has_yielded_error = True
|
||||
elif isinstance(chunk, StreamUsage):
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=chunk.promptTokens,
|
||||
completion_tokens=chunk.completionTokens,
|
||||
total_tokens=chunk.totalTokens,
|
||||
elif isinstance(chunk, StreamToolInputAvailable):
|
||||
# Accumulate tool calls in OpenAI format
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": chunk.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.toolName,
|
||||
"arguments": orjson.dumps(chunk.input).decode(
|
||||
"utf-8"
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(chunk, StreamToolOutputAvailable):
|
||||
result_content = (
|
||||
chunk.output
|
||||
if isinstance(chunk.output, str)
|
||||
else orjson.dumps(chunk.output).decode("utf-8")
|
||||
)
|
||||
tool_response_messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=chunk.toolCallId,
|
||||
)
|
||||
)
|
||||
has_done_tool_call = True
|
||||
# Track if any tool execution failed
|
||||
if not chunk.success:
|
||||
logger.warning(
|
||||
f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed"
|
||||
)
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamFinish):
|
||||
if not has_done_tool_call:
|
||||
# Emit text-end before finish if we received text but haven't closed it
|
||||
if has_received_text and not text_streaming_ended:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
text_streaming_ended = True
|
||||
has_yielded_end = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamError):
|
||||
has_yielded_error = True
|
||||
elif isinstance(chunk, StreamUsage):
|
||||
session.usage.append(
|
||||
Usage(
|
||||
prompt_tokens=chunk.promptTokens,
|
||||
completion_tokens=chunk.completionTokens,
|
||||
total_tokens=chunk.totalTokens,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown chunk type: {type(chunk)}", exc_info=True
|
||||
)
|
||||
if assistant_response.content:
|
||||
langfuse.update_current_trace(output=assistant_response.content)
|
||||
langfuse.update_current_span(output=assistant_response.content)
|
||||
elif tool_response_messages:
|
||||
langfuse.update_current_trace(output=str(tool_response_messages))
|
||||
langfuse.update_current_span(output=str(tool_response_messages))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
||||
|
||||
# Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
|
||||
is_retryable = isinstance(
|
||||
e, (orjson.JSONDecodeError, KeyError, TypeError)
|
||||
)
|
||||
|
||||
if is_retryable and retry_count < config.max_retries:
|
||||
logger.info(
|
||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||
)
|
||||
should_retry = True
|
||||
else:
|
||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during stream: {e!s}", exc_info=True)
|
||||
# Non-retryable error or max retries exceeded
|
||||
# Save any partial progress before reporting error
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
|
||||
is_retryable = isinstance(e, (orjson.JSONDecodeError, KeyError, TypeError))
|
||||
# Add assistant message if it has content or tool calls
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
|
||||
if is_retryable and retry_count < config.max_retries:
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
|
||||
session.messages.extend(messages_to_save)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
if not has_yielded_error:
|
||||
error_message = str(e)
|
||||
if not is_retryable:
|
||||
error_message = f"Non-retryable error: {error_message}"
|
||||
elif retry_count >= config.max_retries:
|
||||
error_message = f"Max retries ({config.max_retries}) exceeded: {error_message}"
|
||||
|
||||
error_response = StreamError(errorText=error_message)
|
||||
yield error_response
|
||||
if not has_yielded_end:
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Handle retry outside of exception handler to avoid nesting
|
||||
if should_retry and retry_count < config.max_retries:
|
||||
logger.info(
|
||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||
f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
|
||||
)
|
||||
should_retry = True
|
||||
else:
|
||||
# Non-retryable error or max retries exceeded
|
||||
# Save any partial progress before reporting error
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
retry_count=retry_count + 1,
|
||||
session=session,
|
||||
context=context,
|
||||
):
|
||||
yield chunk
|
||||
return # Exit after retry to avoid double-saving in finally block
|
||||
|
||||
# Add assistant message if it has content or tool calls
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
|
||||
session.messages.extend(messages_to_save)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
if not has_yielded_error:
|
||||
error_message = str(e)
|
||||
if not is_retryable:
|
||||
error_message = f"Non-retryable error: {error_message}"
|
||||
elif retry_count >= config.max_retries:
|
||||
error_message = f"Max retries ({config.max_retries}) exceeded: {error_message}"
|
||||
|
||||
error_response = StreamError(errorText=error_message)
|
||||
yield error_response
|
||||
if not has_yielded_end:
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Handle retry outside of exception handler to avoid nesting
|
||||
if should_retry and retry_count < config.max_retries:
|
||||
# Normal completion path - save session and handle tool call continuation
|
||||
logger.info(
|
||||
f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
|
||||
)
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
retry_count=retry_count + 1,
|
||||
session=session,
|
||||
context=context,
|
||||
):
|
||||
yield chunk
|
||||
return # Exit after retry to avoid double-saving in finally block
|
||||
|
||||
# Normal completion path - save session and handle tool call continuation
|
||||
logger.info(
|
||||
f"Normal completion path: session={session.session_id}, "
|
||||
f"current message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
logger.info(
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
f"Normal completion path: session={session.session_id}, "
|
||||
f"current message_count={len(session.messages)}"
|
||||
)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
logger.info(
|
||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||
f"total_to_save={len(messages_to_save)}"
|
||||
)
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
session.messages.extend(messages_to_save)
|
||||
logger.info(
|
||||
f"Extended session messages, new message_count={len(session.messages)}"
|
||||
)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# If we did a tool call, stream the chat completion again to get the next response
|
||||
if has_done_tool_call:
|
||||
logger.info(
|
||||
"Tool call executed, streaming chat completion again to get assistant response"
|
||||
)
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
session=session, # Pass session object to avoid Redis refetch
|
||||
context=context,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
finally:
|
||||
# Always end Langfuse observations to prevent resource leaks
|
||||
# Guard against None and catch errors to avoid masking original exceptions
|
||||
if generation is not None:
|
||||
try:
|
||||
latest_usage = session.usage[-1] if session.usage else None
|
||||
generation.update(
|
||||
model=config.model,
|
||||
output={
|
||||
"content": assistant_response.content,
|
||||
"tool_calls": accumulated_tool_calls or None,
|
||||
},
|
||||
usage_details=(
|
||||
{
|
||||
"input": latest_usage.prompt_tokens,
|
||||
"output": latest_usage.completion_tokens,
|
||||
"total": latest_usage.total_tokens,
|
||||
}
|
||||
if latest_usage
|
||||
else None
|
||||
),
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
logger.info(
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
)
|
||||
generation.end()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to end Langfuse generation: {e}")
|
||||
|
||||
if trace is not None:
|
||||
try:
|
||||
if accumulated_tool_calls:
|
||||
trace.update_trace(output={"tool_calls": accumulated_tool_calls})
|
||||
else:
|
||||
trace.update_trace(output={"response": assistant_response.content})
|
||||
trace.end()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to end Langfuse trace: {e}")
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
logger.info(
|
||||
f"Saving {len(tool_response_messages)} tool response messages, "
|
||||
f"total_to_save={len(messages_to_save)}"
|
||||
)
|
||||
|
||||
session.messages.extend(messages_to_save)
|
||||
logger.info(
|
||||
f"Extended session messages, new message_count={len(session.messages)}"
|
||||
)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# If we did a tool call, stream the chat completion again to get the next response
|
||||
if has_done_tool_call:
|
||||
logger.info(
|
||||
"Tool call executed, streaming chat completion again to get assistant response"
|
||||
)
|
||||
async for chunk in stream_chat_completion(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
session=session, # Pass session object to avoid Redis refetch
|
||||
context=context,
|
||||
tool_call_response=str(tool_response_messages),
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
# Retry configuration for OpenAI API calls
|
||||
@@ -903,5 +791,4 @@ async def _yield_tool_call(
|
||||
session=session,
|
||||
)
|
||||
|
||||
logger.info(f"Yielding Tool execution response: {tool_execution_response}")
|
||||
yield tool_execution_response
|
||||
|
||||
@@ -7,9 +7,15 @@ from backend.api.features.chat.model import ChatSession
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .search_docs import SearchDocsTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
@@ -17,10 +23,16 @@ if TYPE_CHECKING:
|
||||
# Single source of truth for all tools
|
||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"add_understanding": AddUnderstandingTool(),
|
||||
"create_agent": CreateAgentTool(),
|
||||
"edit_agent": EditAgentTool(),
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_block": FindBlockTool(),
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"agent_output": AgentOutputTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
}
|
||||
|
||||
# Export individual tool instances for backwards compatibility
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
@@ -59,6 +61,7 @@ and automations for the user's specific needs."""
|
||||
"""Requires authentication to store user-specific data."""
|
||||
return True
|
||||
|
||||
@observe(as_type="tool", name="add_understanding")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Agent generator package - Creates agents from natural language."""
|
||||
|
||||
from .core import (
|
||||
apply_agent_patch,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .fixer import apply_all_fixes
|
||||
from .utils import get_blocks_info
|
||||
from .validator import validate_agent
|
||||
|
||||
__all__ = [
|
||||
# Core functions
|
||||
"decompose_goal",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"apply_agent_patch",
|
||||
"save_agent_to_library",
|
||||
"get_agent_as_json",
|
||||
# Fixer
|
||||
"apply_all_fixes",
|
||||
# Validator
|
||||
"validate_agent",
|
||||
# Utils
|
||||
"get_blocks_info",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
"""OpenRouter client configuration for agent generation."""
|
||||
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
|
||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
|
||||
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
|
||||
|
||||
# OpenRouter client (OpenAI-compatible API)
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
|
||||
def get_client() -> AsyncOpenAI:
|
||||
"""Get or create the OpenRouter client."""
|
||||
global _client
|
||||
if _client is None:
|
||||
if not OPENROUTER_API_KEY:
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||
_client = AsyncOpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=OPENROUTER_API_KEY,
|
||||
)
|
||||
return _client
|
||||
@@ -0,0 +1,390 @@
|
||||
"""Core agent generation functions."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
|
||||
from .client import AGENT_GENERATOR_MODEL, get_client
|
||||
from .prompts import DECOMPOSITION_PROMPT, GENERATION_PROMPT, PATCH_PROMPT
|
||||
from .utils import get_block_summaries, parse_json_from_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = DECOMPOSITION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
full_description = description
|
||||
if context:
|
||||
full_description = f"{description}\n\nAdditional context:\n{context}"
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": full_description},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for decomposition")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse decomposition response: {content[:200]}")
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decomposing goal: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = GENERATION_PROMPT.format(block_summaries=get_block_summaries())
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": json.dumps(instructions, indent=2)},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for agent generation")
|
||||
return None
|
||||
|
||||
result = parse_json_from_llm(content)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"Failed to parse agent JSON: {content[:200]}")
|
||||
return None
|
||||
|
||||
# Ensure required fields
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
result["version"] = 1
|
||||
if "is_active" not in result:
|
||||
result["is_active"] = True
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating agent: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
"""Convert agent JSON dict to Graph model.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON with nodes and links
|
||||
|
||||
Returns:
|
||||
Graph ready for saving
|
||||
"""
|
||||
nodes = []
|
||||
for n in agent_json.get("nodes", []):
|
||||
node = Node(
|
||||
id=n.get("id", str(uuid.uuid4())),
|
||||
block_id=n["block_id"],
|
||||
input_default=n.get("input_default", {}),
|
||||
metadata=n.get("metadata", {}),
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
links = []
|
||||
for link_data in agent_json.get("links", []):
|
||||
link = Link(
|
||||
id=link_data.get("id", str(uuid.uuid4())),
|
||||
source_id=link_data["source_id"],
|
||||
sink_id=link_data["sink_id"],
|
||||
source_name=link_data["source_name"],
|
||||
sink_name=link_data["sink_name"],
|
||||
is_static=link_data.get("is_static", False),
|
||||
)
|
||||
links.append(link)
|
||||
|
||||
return Graph(
|
||||
id=agent_json.get("id", str(uuid.uuid4())),
|
||||
version=agent_json.get("version", 1),
|
||||
is_active=agent_json.get("is_active", True),
|
||||
name=agent_json.get("name", "Generated Agent"),
|
||||
description=agent_json.get("description", ""),
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
|
||||
|
||||
def _reassign_node_ids(graph: Graph) -> None:
|
||||
"""Reassign all node and link IDs to new UUIDs.
|
||||
|
||||
This is needed when creating a new version to avoid unique constraint violations.
|
||||
"""
|
||||
# Create mapping from old node IDs to new UUIDs
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
# Reassign node IDs
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
# Update link references to use new node IDs
|
||||
for link in graph.links:
|
||||
link.id = str(uuid.uuid4()) # Also give links new IDs
|
||||
if link.source_id in id_map:
|
||||
link.source_id = id_map[link.source_id]
|
||||
if link.sink_id in id_map:
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
) -> tuple[Graph, Any]:
|
||||
"""Save agent to database and user's library.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON dict
|
||||
user_id: User ID
|
||||
is_update: Whether this is an update to an existing agent
|
||||
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
from backend.data.graph import get_graph_all_versions
|
||||
|
||||
graph = json_to_graph(agent_json)
|
||||
|
||||
if is_update:
|
||||
# For updates, keep the same graph ID but increment version
|
||||
# and reassign node/link IDs to avoid conflicts
|
||||
if graph.id:
|
||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||
if existing_versions:
|
||||
latest_version = max(v.version for v in existing_versions)
|
||||
graph.version = latest_version + 1
|
||||
# Reassign node IDs (but keep graph ID the same)
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||
else:
|
||||
# For new agents, always generate a fresh UUID to avoid collisions
|
||||
graph.id = str(uuid.uuid4())
|
||||
graph.version = 1
|
||||
# Reassign all node IDs as well
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Creating new agent with ID {graph.id}")
|
||||
|
||||
# Save to database
|
||||
created_graph = await create_graph(graph, user_id)
|
||||
|
||||
# Add to user's library (or update existing library agent)
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
)
|
||||
|
||||
return created_graph, library_agents[0]
|
||||
|
||||
|
||||
async def get_agent_as_json(
|
||||
graph_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
|
||||
Args:
|
||||
graph_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
# Try to get the graph (version=None gets the active version)
|
||||
graph = await get_graph(graph_id, version=None, user_id=user_id)
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
# Convert to JSON format
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
nodes.append(
|
||||
{
|
||||
"id": node.id,
|
||||
"block_id": node.block_id,
|
||||
"input_default": node.input_default,
|
||||
"metadata": node.metadata,
|
||||
}
|
||||
)
|
||||
|
||||
links = []
|
||||
for node in graph.nodes:
|
||||
for link in node.output_links:
|
||||
links.append(
|
||||
{
|
||||
"id": link.id,
|
||||
"source_id": link.source_id,
|
||||
"sink_id": link.sink_id,
|
||||
"source_name": link.source_name,
|
||||
"sink_name": link.sink_name,
|
||||
"is_static": link.is_static,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"id": graph.id,
|
||||
"name": graph.name,
|
||||
"description": graph.description,
|
||||
"version": graph.version,
|
||||
"is_active": graph.is_active,
|
||||
"nodes": nodes,
|
||||
"links": links,
|
||||
}
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate a patch to update an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Patch dict or clarifying questions, or None on error
|
||||
"""
|
||||
client = get_client()
|
||||
prompt = PATCH_PROMPT.format(
|
||||
current_agent=json.dumps(current_agent, indent=2),
|
||||
block_summaries=get_block_summaries(),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=AGENT_GENERATOR_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": update_request},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
logger.error("LLM returned empty content for patch generation")
|
||||
return None
|
||||
|
||||
return parse_json_from_llm(content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating patch: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def apply_agent_patch(
|
||||
current_agent: dict[str, Any], patch: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Apply a patch to an existing agent.
|
||||
|
||||
Args:
|
||||
current_agent: Current agent JSON
|
||||
patch: Patch dict with operations
|
||||
|
||||
Returns:
|
||||
Updated agent JSON
|
||||
"""
|
||||
agent = copy.deepcopy(current_agent)
|
||||
patches = patch.get("patches", [])
|
||||
|
||||
for p in patches:
|
||||
patch_type = p.get("type")
|
||||
|
||||
if patch_type == "modify":
|
||||
node_id = p.get("node_id")
|
||||
changes = p.get("changes", {})
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
if node["id"] == node_id:
|
||||
_deep_update(node, changes)
|
||||
logger.debug(f"Modified node {node_id}")
|
||||
break
|
||||
|
||||
elif patch_type == "add":
|
||||
new_nodes = p.get("new_nodes", [])
|
||||
new_links = p.get("new_links", [])
|
||||
|
||||
agent["nodes"] = agent.get("nodes", []) + new_nodes
|
||||
agent["links"] = agent.get("links", []) + new_links
|
||||
logger.debug(f"Added {len(new_nodes)} nodes, {len(new_links)} links")
|
||||
|
||||
elif patch_type == "remove":
|
||||
node_ids_to_remove = set(p.get("node_ids", []))
|
||||
link_ids_to_remove = set(p.get("link_ids", []))
|
||||
|
||||
# Remove nodes
|
||||
agent["nodes"] = [
|
||||
n for n in agent.get("nodes", []) if n["id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
# Remove links (both explicit and those referencing removed nodes)
|
||||
agent["links"] = [
|
||||
link
|
||||
for link in agent.get("links", [])
|
||||
if link["id"] not in link_ids_to_remove
|
||||
and link["source_id"] not in node_ids_to_remove
|
||||
and link["sink_id"] not in node_ids_to_remove
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"Removed {len(node_ids_to_remove)} nodes, {len(link_ids_to_remove)} links"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def _deep_update(target: dict, source: dict) -> None:
|
||||
"""Recursively update a dict with another dict."""
|
||||
for key, value in source.items():
|
||||
if key in target and isinstance(target[key], dict) and isinstance(value, dict):
|
||||
_deep_update(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
@@ -0,0 +1,606 @@
|
||||
"""Agent fixer - Fixes common LLM generation errors."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from .utils import (
|
||||
ADDTODICTIONARY_BLOCK_ID,
|
||||
ADDTOLIST_BLOCK_ID,
|
||||
CODE_EXECUTION_BLOCK_ID,
|
||||
CONDITION_BLOCK_ID,
|
||||
CREATEDICT_BLOCK_ID,
|
||||
CREATELIST_BLOCK_ID,
|
||||
DATA_SAMPLING_BLOCK_ID,
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS,
|
||||
GET_CURRENT_DATE_BLOCK_ID,
|
||||
STORE_VALUE_BLOCK_ID,
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
get_blocks_info,
|
||||
is_valid_uuid,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fix_agent_ids(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix invalid UUIDs in agent and link IDs."""
|
||||
# Fix agent ID
|
||||
if not is_valid_uuid(agent.get("id", "")):
|
||||
agent["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed agent ID: {agent['id']}")
|
||||
|
||||
# Fix node IDs
|
||||
id_mapping = {} # Old ID -> New ID
|
||||
for node in agent.get("nodes", []):
|
||||
if not is_valid_uuid(node.get("id", "")):
|
||||
old_id = node.get("id", "")
|
||||
new_id = str(uuid.uuid4())
|
||||
id_mapping[old_id] = new_id
|
||||
node["id"] = new_id
|
||||
logger.debug(f"Fixed node ID: {old_id} -> {new_id}")
|
||||
|
||||
# Fix link IDs and update references
|
||||
for link in agent.get("links", []):
|
||||
if not is_valid_uuid(link.get("id", "")):
|
||||
link["id"] = str(uuid.uuid4())
|
||||
logger.debug(f"Fixed link ID: {link['id']}")
|
||||
|
||||
# Update source/sink IDs if they were remapped
|
||||
if link.get("source_id") in id_mapping:
|
||||
link["source_id"] = id_mapping[link["source_id"]]
|
||||
if link.get("sink_id") in id_mapping:
|
||||
link["sink_id"] = id_mapping[link["sink_id"]]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_double_curly_braces(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix single curly braces to double in template blocks."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") not in DOUBLE_CURLY_BRACES_BLOCK_IDS:
|
||||
continue
|
||||
|
||||
input_data = node.get("input_default", {})
|
||||
for key in ("prompt", "format"):
|
||||
if key in input_data and isinstance(input_data[key], str):
|
||||
original = input_data[key]
|
||||
# Fix simple variable references: {var} -> {{var}}
|
||||
fixed = re.sub(
|
||||
r"(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})",
|
||||
r"{{\1}}",
|
||||
original,
|
||||
)
|
||||
if fixed != original:
|
||||
input_data[key] = fixed
|
||||
logger.debug(f"Fixed curly braces in {key}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_storevalue_before_condition(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add StoreValueBlock before ConditionBlock if needed for value2."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
# Find all ConditionBlock nodes
|
||||
condition_node_ids = {
|
||||
node["id"] for node in nodes if node.get("block_id") == CONDITION_BLOCK_ID
|
||||
}
|
||||
|
||||
if not condition_node_ids:
|
||||
return agent
|
||||
|
||||
new_nodes = []
|
||||
new_links = []
|
||||
processed_conditions = set()
|
||||
|
||||
for link in links:
|
||||
sink_id = link.get("sink_id")
|
||||
sink_name = link.get("sink_name")
|
||||
|
||||
# Check if this link goes to a ConditionBlock's value2
|
||||
if sink_id in condition_node_ids and sink_name == "value2":
|
||||
source_node = next(
|
||||
(n for n in nodes if n["id"] == link.get("source_id")), None
|
||||
)
|
||||
|
||||
# Skip if source is already a StoreValueBlock
|
||||
if source_node and source_node.get("block_id") == STORE_VALUE_BLOCK_ID:
|
||||
continue
|
||||
|
||||
# Skip if we already processed this condition
|
||||
if sink_id in processed_conditions:
|
||||
continue
|
||||
|
||||
processed_conditions.add(sink_id)
|
||||
|
||||
# Create StoreValueBlock
|
||||
store_node_id = str(uuid.uuid4())
|
||||
store_node = {
|
||||
"id": store_node_id,
|
||||
"block_id": STORE_VALUE_BLOCK_ID,
|
||||
"input_default": {"data": None},
|
||||
"metadata": {"position": {"x": 0, "y": -100}},
|
||||
}
|
||||
new_nodes.append(store_node)
|
||||
|
||||
# Create link: original source -> StoreValueBlock
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": store_node_id,
|
||||
"sink_name": "input",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Update original link: StoreValueBlock -> ConditionBlock
|
||||
link["source_id"] = store_node_id
|
||||
link["source_name"] = "output"
|
||||
|
||||
logger.debug(f"Added StoreValueBlock before ConditionBlock {sink_id}")
|
||||
|
||||
if new_nodes:
|
||||
agent["nodes"] = nodes + new_nodes
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtolist_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToList blocks by adding prerequisite empty AddToList block.
|
||||
|
||||
When an AddToList block is found:
|
||||
1. Checks if there's a CreateListBlock before it
|
||||
2. Removes CreateListBlock if linked directly to AddToList
|
||||
3. Adds an empty AddToList block before the original
|
||||
4. Ensures the original has a self-referencing link
|
||||
"""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
new_nodes = []
|
||||
original_addtolist_ids = set()
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
# First pass: identify CreateListBlock nodes to remove
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATELIST_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateListBlock {source_node.get('id')}")
|
||||
|
||||
# Second pass: process AddToList blocks
|
||||
filtered_nodes = []
|
||||
for node in nodes:
|
||||
if node.get("id") in nodes_to_remove:
|
||||
continue
|
||||
|
||||
if node.get("block_id") == ADDTOLIST_BLOCK_ID:
|
||||
original_addtolist_ids.add(node.get("id"))
|
||||
node_id = node.get("id")
|
||||
pos = node.get("metadata", {}).get("position", {"x": 0, "y": 0})
|
||||
|
||||
# Check if already has prerequisite
|
||||
has_prereq = any(
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_name") == "updated_list"
|
||||
for link in links
|
||||
)
|
||||
|
||||
if not has_prereq:
|
||||
# Remove links to "list" input (except self-reference)
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "list"
|
||||
and link.get("source_id") != node_id
|
||||
and link not in links_to_remove
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Create prerequisite AddToList block
|
||||
prereq_id = str(uuid.uuid4())
|
||||
prereq_node = {
|
||||
"id": prereq_id,
|
||||
"block_id": ADDTOLIST_BLOCK_ID,
|
||||
"input_default": {"list": [], "entry": None, "entries": []},
|
||||
"metadata": {
|
||||
"position": {"x": pos.get("x", 0) - 800, "y": pos.get("y", 0)}
|
||||
},
|
||||
}
|
||||
new_nodes.append(prereq_node)
|
||||
|
||||
# Link prerequisite to original
|
||||
links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": prereq_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added prerequisite AddToList block for {node_id}")
|
||||
|
||||
filtered_nodes.append(node)
|
||||
|
||||
# Remove marked links
|
||||
filtered_links = [link for link in links if link not in links_to_remove]
|
||||
|
||||
# Add self-referencing links for original AddToList blocks
|
||||
for node in filtered_nodes + new_nodes:
|
||||
if (
|
||||
node.get("block_id") == ADDTOLIST_BLOCK_ID
|
||||
and node.get("id") in original_addtolist_ids
|
||||
):
|
||||
node_id = node.get("id")
|
||||
has_self_ref = any(
|
||||
link["source_id"] == node_id
|
||||
and link["sink_id"] == node_id
|
||||
and link["source_name"] == "updated_list"
|
||||
and link["sink_name"] == "list"
|
||||
for link in filtered_links
|
||||
)
|
||||
if not has_self_ref:
|
||||
filtered_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": node_id,
|
||||
"source_name": "updated_list",
|
||||
"sink_id": node_id,
|
||||
"sink_name": "list",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
logger.debug(f"Added self-reference for AddToList {node_id}")
|
||||
|
||||
agent["nodes"] = filtered_nodes + new_nodes
|
||||
agent["links"] = filtered_links
|
||||
return agent
|
||||
|
||||
|
||||
def fix_addtodictionary_blocks(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix AddToDictionary blocks by removing empty CreateDictionary nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
nodes_to_remove = set()
|
||||
links_to_remove = []
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
sink_node = next((n for n in nodes if n.get("id") == link.get("sink_id")), None)
|
||||
|
||||
if (
|
||||
source_node
|
||||
and sink_node
|
||||
and source_node.get("block_id") == CREATEDICT_BLOCK_ID
|
||||
and sink_node.get("block_id") == ADDTODICTIONARY_BLOCK_ID
|
||||
):
|
||||
nodes_to_remove.add(source_node.get("id"))
|
||||
links_to_remove.append(link)
|
||||
logger.debug(f"Removing CreateDictionary {source_node.get('id')}")
|
||||
|
||||
agent["nodes"] = [n for n in nodes if n.get("id") not in nodes_to_remove]
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
return agent
|
||||
|
||||
|
||||
def fix_code_execution_output(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix CodeExecutionBlock output: change 'response' to 'stdout_logs'."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
|
||||
for link in links:
|
||||
source_node = next(
|
||||
(n for n in nodes if n.get("id") == link.get("source_id")), None
|
||||
)
|
||||
if (
|
||||
source_node
|
||||
and source_node.get("block_id") == CODE_EXECUTION_BLOCK_ID
|
||||
and link.get("source_name") == "response"
|
||||
):
|
||||
link["source_name"] = "stdout_logs"
|
||||
logger.debug("Fixed CodeExecutionBlock output: response -> stdout_logs")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_sampling_sample_size(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix DataSamplingBlock by setting sample_size to 1 as default."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
links_to_remove = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("block_id") == DATA_SAMPLING_BLOCK_ID:
|
||||
node_id = node.get("id")
|
||||
input_default = node.get("input_default", {})
|
||||
|
||||
# Remove links to sample_size
|
||||
for link in links:
|
||||
if (
|
||||
link.get("sink_id") == node_id
|
||||
and link.get("sink_name") == "sample_size"
|
||||
):
|
||||
links_to_remove.append(link)
|
||||
|
||||
# Set default
|
||||
input_default["sample_size"] = 1
|
||||
node["input_default"] = input_default
|
||||
logger.debug(f"Fixed DataSamplingBlock {node_id} sample_size to 1")
|
||||
|
||||
if links_to_remove:
|
||||
agent["links"] = [link for link in links if link not in links_to_remove]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_node_x_coordinates(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix node x-coordinates to ensure 800+ unit spacing between linked nodes."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
for link in links:
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
source_node = node_lookup.get(source_id)
|
||||
sink_node = node_lookup.get(sink_id)
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_pos = source_node.get("metadata", {}).get("position", {})
|
||||
sink_pos = sink_node.get("metadata", {}).get("position", {})
|
||||
|
||||
source_x = source_pos.get("x", 0)
|
||||
sink_x = sink_pos.get("x", 0)
|
||||
|
||||
if abs(sink_x - source_x) < 800:
|
||||
new_x = source_x + 800
|
||||
if "metadata" not in sink_node:
|
||||
sink_node["metadata"] = {}
|
||||
if "position" not in sink_node["metadata"]:
|
||||
sink_node["metadata"]["position"] = {}
|
||||
sink_node["metadata"]["position"]["x"] = new_x
|
||||
logger.debug(f"Fixed node {sink_id} x: {sink_x} -> {new_x}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_getcurrentdate_offset(agent: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix GetCurrentDateBlock offset to ensure it's positive."""
|
||||
for node in agent.get("nodes", []):
|
||||
if node.get("block_id") == GET_CURRENT_DATE_BLOCK_ID:
|
||||
input_default = node.get("input_default", {})
|
||||
if "offset" in input_default:
|
||||
offset = input_default["offset"]
|
||||
if isinstance(offset, (int, float)) and offset < 0:
|
||||
input_default["offset"] = abs(offset)
|
||||
logger.debug(f"Fixed offset: {offset} -> {abs(offset)}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_ai_model_parameter(
|
||||
agent: dict[str, Any],
|
||||
blocks_info: list[dict[str, Any]],
|
||||
default_model: str = "gpt-4o",
|
||||
) -> dict[str, Any]:
|
||||
"""Add default model parameter to AI blocks if missing."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
# Check if block has AI category
|
||||
categories = block.get("categories", [])
|
||||
is_ai_block = any(
|
||||
cat.get("category") == "AI" for cat in categories if isinstance(cat, dict)
|
||||
)
|
||||
|
||||
if is_ai_block:
|
||||
input_default = node.get("input_default", {})
|
||||
if "model" not in input_default:
|
||||
input_default["model"] = default_model
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Added model '{default_model}' to AI block {node.get('id')}"
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_link_static_properties(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix is_static property based on source block's staticOutput."""
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
if not source_block:
|
||||
continue
|
||||
|
||||
static_output = source_block.get("staticOutput", False)
|
||||
if link.get("is_static") != static_output:
|
||||
link["is_static"] = static_output
|
||||
logger.debug(f"Fixed link {link.get('id')} is_static to {static_output}")
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def fix_data_type_mismatch(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Fix data type mismatches by inserting UniversalTypeConverterBlock."""
|
||||
nodes = agent.get("nodes", [])
|
||||
links = agent.get("links", [])
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in nodes}
|
||||
|
||||
def get_property_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_types_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
type_mapping = {
|
||||
"string": "string",
|
||||
"text": "string",
|
||||
"integer": "number",
|
||||
"number": "number",
|
||||
"float": "number",
|
||||
"boolean": "boolean",
|
||||
"bool": "boolean",
|
||||
"array": "list",
|
||||
"list": "list",
|
||||
"object": "dictionary",
|
||||
"dict": "dictionary",
|
||||
"dictionary": "dictionary",
|
||||
}
|
||||
|
||||
new_links = []
|
||||
nodes_to_add = []
|
||||
|
||||
for link in links:
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
new_links.append(link)
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_property_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_property_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if (
|
||||
source_type
|
||||
and sink_type
|
||||
and not are_types_compatible(source_type, sink_type)
|
||||
):
|
||||
# Insert type converter
|
||||
converter_id = str(uuid.uuid4())
|
||||
target_type = type_mapping.get(sink_type, sink_type)
|
||||
|
||||
converter_node = {
|
||||
"id": converter_id,
|
||||
"block_id": UNIVERSAL_TYPE_CONVERTER_BLOCK_ID,
|
||||
"input_default": {"type": target_type},
|
||||
"metadata": {"position": {"x": 0, "y": 100}},
|
||||
}
|
||||
nodes_to_add.append(converter_node)
|
||||
|
||||
# source -> converter
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": link["source_id"],
|
||||
"source_name": link["source_name"],
|
||||
"sink_id": converter_id,
|
||||
"sink_name": "value",
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
# converter -> sink
|
||||
new_links.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"source_id": converter_id,
|
||||
"source_name": "value",
|
||||
"sink_id": link["sink_id"],
|
||||
"sink_name": link["sink_name"],
|
||||
"is_static": False,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Inserted type converter: {source_type} -> {target_type}")
|
||||
else:
|
||||
new_links.append(link)
|
||||
|
||||
if nodes_to_add:
|
||||
agent["nodes"] = nodes + nodes_to_add
|
||||
agent["links"] = new_links
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def apply_all_fixes(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Apply all fixes to an agent JSON.
|
||||
|
||||
Args:
|
||||
agent: Agent JSON dict
|
||||
blocks_info: Optional list of block info dicts for advanced fixes
|
||||
|
||||
Returns:
|
||||
Fixed agent JSON
|
||||
"""
|
||||
# Basic fixes (no block info needed)
|
||||
agent = fix_agent_ids(agent)
|
||||
agent = fix_double_curly_braces(agent)
|
||||
agent = fix_storevalue_before_condition(agent)
|
||||
agent = fix_addtolist_blocks(agent)
|
||||
agent = fix_addtodictionary_blocks(agent)
|
||||
agent = fix_code_execution_output(agent)
|
||||
agent = fix_data_sampling_sample_size(agent)
|
||||
agent = fix_node_x_coordinates(agent)
|
||||
agent = fix_getcurrentdate_offset(agent)
|
||||
|
||||
# Advanced fixes (require block info)
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
agent = fix_ai_model_parameter(agent, blocks_info)
|
||||
agent = fix_link_static_properties(agent, blocks_info)
|
||||
agent = fix_data_type_mismatch(agent, blocks_info)
|
||||
|
||||
return agent
|
||||
@@ -0,0 +1,225 @@
|
||||
"""Prompt templates for agent generation."""
|
||||
|
||||
DECOMPOSITION_PROMPT = """
|
||||
You are an expert AutoGPT Workflow Decomposer. Your task is to analyze a user's high-level goal and break it down into a clear, step-by-step plan using the available blocks.
|
||||
|
||||
Each step should represent a distinct, automatable action suitable for execution by an AI automation system.
|
||||
|
||||
---
|
||||
|
||||
FIRST: Analyze the user's goal and determine:
|
||||
1) Design-time configuration (fixed settings that won't change per run)
|
||||
2) Runtime inputs (values the agent's end-user will provide each time it runs)
|
||||
|
||||
For anything that can vary per run (email addresses, names, dates, search terms, etc.):
|
||||
- DO NOT ask for the actual value
|
||||
- Instead, define it as an Agent Input with a clear name, type, and description
|
||||
|
||||
Only ask clarifying questions about design-time config that affects how you build the workflow:
|
||||
- Which external service to use (e.g., "Gmail vs Outlook", "Notion vs Google Docs")
|
||||
- Required formats or structures (e.g., "CSV, JSON, or PDF output?")
|
||||
- Business rules that must be hard-coded
|
||||
|
||||
IMPORTANT CLARIFICATIONS POLICY:
|
||||
- Ask no more than five essential questions
|
||||
- Do not ask for concrete values that can be provided at runtime as Agent Inputs
|
||||
- Do not ask for API keys or credentials; the platform handles those directly
|
||||
- If there is enough information to infer reasonable defaults, prefer to propose defaults
|
||||
|
||||
---
|
||||
|
||||
GUIDELINES:
|
||||
1. List each step as a numbered item
|
||||
2. Describe the action clearly and specify inputs/outputs
|
||||
3. Ensure steps are in logical, sequential order
|
||||
4. Mention block names naturally (e.g., "Use GetWeatherByLocationBlock to...")
|
||||
5. Help the user reach their goal efficiently
|
||||
|
||||
---
|
||||
|
||||
RULES:
|
||||
1. OUTPUT FORMAT: Only output either clarifying questions OR step-by-step instructions, not both
|
||||
2. USE ONLY THE BLOCKS PROVIDED
|
||||
3. ALL required_input fields must be provided
|
||||
4. Data types of linked properties must match
|
||||
5. Write expert-level prompts for AI-related blocks
|
||||
|
||||
---
|
||||
|
||||
CRITICAL BLOCK RESTRICTIONS:
|
||||
1. AddToListBlock: Outputs updated list EVERY addition, not after all additions
|
||||
2. SendEmailBlock: Draft the email for user review; set SMTP config based on email type
|
||||
3. ConditionBlock: value2 is reference, value1 is contrast
|
||||
4. CodeExecutionBlock: DO NOT USE - use AI blocks instead
|
||||
5. ReadCsvBlock: Only use the 'rows' output, not 'row'
|
||||
|
||||
---
|
||||
|
||||
OUTPUT FORMAT:
|
||||
|
||||
If more information is needed:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "Which email provider should be used? (Gmail, Outlook, custom SMTP)",
|
||||
"keyword": "email_provider",
|
||||
"example": "Gmail"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If ready to proceed:
|
||||
```json
|
||||
{{
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{{
|
||||
"step_number": 1,
|
||||
"block_name": "AgentShortTextInputBlock",
|
||||
"description": "Get the URL of the content to analyze.",
|
||||
"inputs": [{{"name": "name", "value": "URL"}}],
|
||||
"outputs": [{{"name": "result", "description": "The URL entered by user"}}]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
"""
|
||||
|
||||
GENERATION_PROMPT = """
|
||||
You are an expert AI workflow builder. Generate a valid agent JSON from the given instructions.
|
||||
|
||||
---
|
||||
|
||||
NODES:
|
||||
Each node must include:
|
||||
- `id`: Unique UUID v4 (e.g. `a8f5b1e2-c3d4-4e5f-8a9b-0c1d2e3f4a5b`)
|
||||
- `block_id`: The block identifier (must match an Allowed Block)
|
||||
- `input_default`: Dict of inputs (can be empty if no static inputs needed)
|
||||
- `metadata`: Must contain:
|
||||
- `position`: {{"x": number, "y": number}} - adjacent nodes should differ by 800+ in X
|
||||
- `customized_name`: Clear name describing this block's purpose in the workflow
|
||||
|
||||
---
|
||||
|
||||
LINKS:
|
||||
Each link connects a source node's output to a sink node's input:
|
||||
- `id`: MUST be UUID v4 (NOT "link-1", "link-2", etc.)
|
||||
- `source_id`: ID of the source node
|
||||
- `source_name`: Output field name from the source block
|
||||
- `sink_id`: ID of the sink node
|
||||
- `sink_name`: Input field name on the sink block
|
||||
- `is_static`: true only if source block has static_output: true
|
||||
|
||||
CRITICAL: All IDs must be valid UUID v4 format!
|
||||
|
||||
---
|
||||
|
||||
AGENT (GRAPH):
|
||||
Wrap nodes and links in:
|
||||
- `id`: UUID of the agent
|
||||
- `name`: Short, generic name (avoid specific company names, URLs)
|
||||
- `description`: Short, generic description
|
||||
- `nodes`: List of all nodes
|
||||
- `links`: List of all links
|
||||
- `version`: 1
|
||||
- `is_active`: true
|
||||
|
||||
---
|
||||
|
||||
TIPS:
|
||||
- All required_input fields must be provided via input_default or a valid link
|
||||
- Ensure consistent source_id and sink_id references
|
||||
- Avoid dangling links
|
||||
- Input/output pins must match block schemas
|
||||
- Do not invent unknown block_ids
|
||||
|
||||
---
|
||||
|
||||
ALLOWED BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
Generate the complete agent JSON. Output ONLY valid JSON, no explanation.
|
||||
"""
|
||||
|
||||
PATCH_PROMPT = """
|
||||
You are an expert at modifying AutoGPT agent workflows. Given the current agent and a modification request, generate a JSON patch to update the agent.
|
||||
|
||||
CURRENT AGENT:
|
||||
{current_agent}
|
||||
|
||||
AVAILABLE BLOCKS:
|
||||
{block_summaries}
|
||||
|
||||
---
|
||||
|
||||
PATCH FORMAT:
|
||||
Return a JSON object with the following structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"type": "patch",
|
||||
"intent": "Brief description of what the patch does",
|
||||
"patches": [
|
||||
{{
|
||||
"type": "modify",
|
||||
"node_id": "uuid-of-node-to-modify",
|
||||
"changes": {{
|
||||
"input_default": {{"field": "new_value"}},
|
||||
"metadata": {{"customized_name": "New Name"}}
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"type": "add",
|
||||
"new_nodes": [
|
||||
{{
|
||||
"id": "new-uuid",
|
||||
"block_id": "block-uuid",
|
||||
"input_default": {{}},
|
||||
"metadata": {{"position": {{"x": 0, "y": 0}}, "customized_name": "Name"}}
|
||||
}}
|
||||
],
|
||||
"new_links": [
|
||||
{{
|
||||
"id": "link-uuid",
|
||||
"source_id": "source-node-id",
|
||||
"source_name": "output_field",
|
||||
"sink_id": "sink-node-id",
|
||||
"sink_name": "input_field"
|
||||
}}
|
||||
]
|
||||
}},
|
||||
{{
|
||||
"type": "remove",
|
||||
"node_ids": ["uuid-of-node-to-remove"],
|
||||
"link_ids": ["uuid-of-link-to-remove"]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
If you need more information, return:
|
||||
```json
|
||||
{{
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{{
|
||||
"question": "What specific change do you want?",
|
||||
"keyword": "change_type",
|
||||
"example": "Add error handling"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
Generate the minimal patch needed. Output ONLY valid JSON.
|
||||
"""
|
||||
@@ -0,0 +1,213 @@
|
||||
"""Utilities for agent generation."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# UUID validation regex
|
||||
UUID_REGEX = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$"
|
||||
)
|
||||
|
||||
# Block IDs for various fixes
|
||||
STORE_VALUE_BLOCK_ID = "1ff065e9-88e8-4358-9d82-8dc91f622ba9"
|
||||
CONDITION_BLOCK_ID = "715696a0-e1da-45c8-b209-c2fa9c3b0be6"
|
||||
ADDTOLIST_BLOCK_ID = "aeb08fc1-2fc1-4141-bc8e-f758f183a822"
|
||||
ADDTODICTIONARY_BLOCK_ID = "31d1064e-7446-4693-a7d4-65e5ca1180d1"
|
||||
CREATELIST_BLOCK_ID = "a912d5c7-6e00-4542-b2a9-8034136930e4"
|
||||
CREATEDICT_BLOCK_ID = "b924ddf4-de4f-4b56-9a85-358930dcbc91"
|
||||
CODE_EXECUTION_BLOCK_ID = "0b02b072-abe7-11ef-8372-fb5d162dd712"
|
||||
DATA_SAMPLING_BLOCK_ID = "4a448883-71fa-49cf-91cf-70d793bd7d87"
|
||||
UNIVERSAL_TYPE_CONVERTER_BLOCK_ID = "95d1b990-ce13-4d88-9737-ba5c2070c97b"
|
||||
GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
|
||||
|
||||
DOUBLE_CURLY_BRACES_BLOCK_IDS = [
|
||||
"44f6c8ad-d75c-4ae1-8209-aad1c0326928", # FillTextTemplateBlock
|
||||
"6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
||||
"90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
||||
"363ae599-353e-4804-937e-b2ee3cef3da4", # AgentOutputBlock
|
||||
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
"db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
|
||||
"3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e",
|
||||
"ed1ae7a0-b770-4089-b520-1f0005fad19a",
|
||||
"a892b8d9-3e4e-4e9c-9c1e-75f8efcf1bfa",
|
||||
"b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1",
|
||||
"716a67b3-6760-42e7-86dc-18645c6e00fc",
|
||||
"530cf046-2ce0-4854-ae2c-659db17c7a46",
|
||||
"ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
"1f292d4a-41a4-4977-9684-7c8d560b9f91", # LLM blocks
|
||||
"32a87eab-381e-4dd4-bdb8-4c47151be35a",
|
||||
]
|
||||
|
||||
|
||||
def is_valid_uuid(value: str) -> bool:
|
||||
"""Check if a string is a valid UUID v4."""
|
||||
return isinstance(value, str) and UUID_REGEX.match(value) is not None
|
||||
|
||||
|
||||
def _compact_schema(schema: dict) -> dict[str, str]:
|
||||
"""Extract compact type info from a JSON schema properties dict.
|
||||
|
||||
Returns a dict of {field_name: type_string} for essential info only.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for name, prop in props.items():
|
||||
# Skip internal/complex fields
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Get type string
|
||||
type_str = prop.get("type", "any")
|
||||
|
||||
# Handle anyOf/oneOf (optional types)
|
||||
if "anyOf" in prop:
|
||||
types = [t.get("type", "?") for t in prop["anyOf"] if t.get("type")]
|
||||
type_str = "|".join(types) if types else "any"
|
||||
elif "allOf" in prop:
|
||||
type_str = "object"
|
||||
|
||||
# Add array item type if present
|
||||
if type_str == "array" and "items" in prop:
|
||||
items = prop["items"]
|
||||
if isinstance(items, dict):
|
||||
item_type = items.get("type", "any")
|
||||
type_str = f"array[{item_type}]"
|
||||
|
||||
result[name] = type_str
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_block_summaries(include_schemas: bool = True) -> str:
|
||||
"""Generate compact block summaries for prompts.
|
||||
|
||||
Args:
|
||||
include_schemas: Whether to include input/output type info
|
||||
|
||||
Returns:
|
||||
Formatted string of block summaries (compact format)
|
||||
"""
|
||||
blocks = get_blocks()
|
||||
summaries = []
|
||||
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
name = block.name
|
||||
desc = getattr(block, "description", "") or ""
|
||||
|
||||
# Truncate description
|
||||
if len(desc) > 150:
|
||||
desc = desc[:147] + "..."
|
||||
|
||||
if not include_schemas:
|
||||
summaries.append(f"- {name} (id: {block_id}): {desc}")
|
||||
else:
|
||||
# Compact format with type info only
|
||||
inputs = {}
|
||||
outputs = {}
|
||||
required = []
|
||||
|
||||
if hasattr(block, "input_schema"):
|
||||
try:
|
||||
schema = block.input_schema.jsonschema()
|
||||
inputs = _compact_schema(schema)
|
||||
required = schema.get("required", [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if hasattr(block, "output_schema"):
|
||||
try:
|
||||
schema = block.output_schema.jsonschema()
|
||||
outputs = _compact_schema(schema)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build compact line format
|
||||
# Format: NAME (id): desc | in: {field:type, ...} [required] | out: {field:type}
|
||||
in_str = ", ".join(f"{k}:{v}" for k, v in inputs.items())
|
||||
out_str = ", ".join(f"{k}:{v}" for k, v in outputs.items())
|
||||
req_str = f" req=[{','.join(required)}]" if required else ""
|
||||
|
||||
static = " [static]" if getattr(block, "static_output", False) else ""
|
||||
|
||||
line = f"- {name} (id: {block_id}): {desc}"
|
||||
if in_str:
|
||||
line += f"\n in: {{{in_str}}}{req_str}"
|
||||
if out_str:
|
||||
line += f"\n out: {{{out_str}}}{static}"
|
||||
|
||||
summaries.append(line)
|
||||
|
||||
return "\n".join(summaries)
|
||||
|
||||
|
||||
def get_blocks_info() -> list[dict[str, Any]]:
|
||||
"""Get block information with schemas for validation and fixing."""
|
||||
blocks = get_blocks()
|
||||
blocks_info = []
|
||||
for block_id, block_cls in blocks.items():
|
||||
block = block_cls()
|
||||
blocks_info.append(
|
||||
{
|
||||
"id": block_id,
|
||||
"name": block.name,
|
||||
"description": getattr(block, "description", ""),
|
||||
"categories": getattr(block, "categories", []),
|
||||
"staticOutput": getattr(block, "static_output", False),
|
||||
"inputSchema": (
|
||||
block.input_schema.jsonschema()
|
||||
if hasattr(block, "input_schema")
|
||||
else {}
|
||||
),
|
||||
"outputSchema": (
|
||||
block.output_schema.jsonschema()
|
||||
if hasattr(block, "output_schema")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
)
|
||||
return blocks_info
|
||||
|
||||
|
||||
def parse_json_from_llm(text: str) -> dict[str, Any] | None:
|
||||
"""Extract JSON from LLM response (handles markdown code blocks)."""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Try fenced code block
|
||||
match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1).strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try raw text
|
||||
try:
|
||||
return json.loads(text.strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding {...} span
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding [...] span
|
||||
start = text.find("[")
|
||||
end = text.rfind("]")
|
||||
if start != -1 and end > start:
|
||||
try:
|
||||
return json.loads(text[start : end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,279 @@
|
||||
"""Agent validator - Validates agent structure and connections."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .utils import get_blocks_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentValidator:
|
||||
"""Validator for AutoGPT agents with detailed error reporting."""
|
||||
|
||||
def __init__(self):
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_error(self, error: str) -> None:
|
||||
"""Add an error message."""
|
||||
self.errors.append(error)
|
||||
|
||||
def validate_block_existence(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate all block IDs exist in the blocks library."""
|
||||
valid = True
|
||||
valid_block_ids = {b.get("id") for b in blocks_info if b.get("id")}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
node_id = node.get("id")
|
||||
|
||||
if not block_id:
|
||||
self.add_error(f"Node '{node_id}' is missing 'block_id' field.")
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if block_id not in valid_block_ids:
|
||||
self.add_error(
|
||||
f"Node '{node_id}' references block_id '{block_id}' which does not exist."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_link_node_references(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate all node IDs referenced in links exist."""
|
||||
valid = True
|
||||
valid_node_ids = {n.get("id") for n in agent.get("nodes", []) if n.get("id")}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
link_id = link.get("id", "Unknown")
|
||||
source_id = link.get("source_id")
|
||||
sink_id = link.get("sink_id")
|
||||
|
||||
if not source_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'source_id'.")
|
||||
valid = False
|
||||
elif source_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent source_id '{source_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
if not sink_id:
|
||||
self.add_error(f"Link '{link_id}' is missing 'sink_id'.")
|
||||
valid = False
|
||||
elif sink_id not in valid_node_ids:
|
||||
self.add_error(
|
||||
f"Link '{link_id}' references non-existent sink_id '{sink_id}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_required_inputs(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate required inputs are provided."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
block_id = node.get("block_id")
|
||||
block = block_map.get(block_id)
|
||||
|
||||
if not block:
|
||||
continue
|
||||
|
||||
required_inputs = block.get("inputSchema", {}).get("required", [])
|
||||
input_defaults = node.get("input_default", {})
|
||||
node_id = node.get("id")
|
||||
|
||||
# Get linked inputs
|
||||
linked_inputs = {
|
||||
link["sink_name"]
|
||||
for link in agent.get("links", [])
|
||||
if link.get("sink_id") == node_id
|
||||
}
|
||||
|
||||
for req_input in required_inputs:
|
||||
if (
|
||||
req_input not in input_defaults
|
||||
and req_input not in linked_inputs
|
||||
and req_input != "credentials"
|
||||
):
|
||||
block_name = block.get("name", "Unknown Block")
|
||||
self.add_error(
|
||||
f"Node '{node_id}' ({block_name}) is missing required input '{req_input}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_data_type_compatibility(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate linked data types are compatible."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
def get_type(schema: dict, name: str) -> str | None:
|
||||
if "_#_" in name:
|
||||
parent, child = name.split("_#_", 1)
|
||||
parent_schema = schema.get(parent, {})
|
||||
if "properties" in parent_schema:
|
||||
return parent_schema["properties"].get(child, {}).get("type")
|
||||
return None
|
||||
return schema.get(name, {}).get("type")
|
||||
|
||||
def are_compatible(src: str, sink: str) -> bool:
|
||||
if {src, sink} <= {"integer", "number"}:
|
||||
return True
|
||||
return src == sink
|
||||
|
||||
for link in agent.get("links", []):
|
||||
source_node = node_lookup.get(link.get("source_id"))
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
|
||||
if not source_node or not sink_node:
|
||||
continue
|
||||
|
||||
source_block = block_map.get(source_node.get("block_id"))
|
||||
sink_block = block_map.get(sink_node.get("block_id"))
|
||||
|
||||
if not source_block or not sink_block:
|
||||
continue
|
||||
|
||||
source_outputs = source_block.get("outputSchema", {}).get("properties", {})
|
||||
sink_inputs = sink_block.get("inputSchema", {}).get("properties", {})
|
||||
|
||||
source_type = get_type(source_outputs, link.get("source_name", ""))
|
||||
sink_type = get_type(sink_inputs, link.get("sink_name", ""))
|
||||
|
||||
if source_type and sink_type and not are_compatible(source_type, sink_type):
|
||||
self.add_error(
|
||||
f"Type mismatch: {source_block.get('name')} output '{link['source_name']}' "
|
||||
f"({source_type}) -> {sink_block.get('name')} input '{link['sink_name']}' ({sink_type})."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_nested_sink_links(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]]
|
||||
) -> bool:
|
||||
"""Validate nested sink links (with _#_ notation)."""
|
||||
valid = True
|
||||
block_map = {b.get("id"): b for b in blocks_info}
|
||||
node_lookup = {n.get("id"): n for n in agent.get("nodes", [])}
|
||||
|
||||
for link in agent.get("links", []):
|
||||
sink_name = link.get("sink_name", "")
|
||||
|
||||
if "_#_" in sink_name:
|
||||
parent, child = sink_name.split("_#_", 1)
|
||||
|
||||
sink_node = node_lookup.get(link.get("sink_id"))
|
||||
if not sink_node:
|
||||
continue
|
||||
|
||||
block = block_map.get(sink_node.get("block_id"))
|
||||
if not block:
|
||||
continue
|
||||
|
||||
input_props = block.get("inputSchema", {}).get("properties", {})
|
||||
parent_schema = input_props.get(parent)
|
||||
|
||||
if not parent_schema:
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': parent '{parent}' not found."
|
||||
)
|
||||
valid = False
|
||||
continue
|
||||
|
||||
if not parent_schema.get("additionalProperties"):
|
||||
if not (
|
||||
isinstance(parent_schema, dict)
|
||||
and "properties" in parent_schema
|
||||
and child in parent_schema.get("properties", {})
|
||||
):
|
||||
self.add_error(
|
||||
f"Invalid nested link '{sink_name}': child '{child}' not found in '{parent}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate_prompt_spaces(self, agent: dict[str, Any]) -> bool:
|
||||
"""Validate prompts don't have spaces in template variables."""
|
||||
valid = True
|
||||
|
||||
for node in agent.get("nodes", []):
|
||||
input_default = node.get("input_default", {})
|
||||
prompt = input_default.get("prompt", "")
|
||||
|
||||
if not isinstance(prompt, str):
|
||||
continue
|
||||
|
||||
# Find {{...}} with spaces
|
||||
matches = re.finditer(r"\{\{([^}]+)\}\}", prompt)
|
||||
for match in matches:
|
||||
content = match.group(1)
|
||||
if " " in content:
|
||||
self.add_error(
|
||||
f"Node '{node.get('id')}' has spaces in template variable: "
|
||||
f"'{{{{{content}}}}}' should be '{{{{{content.replace(' ', '_')}}}}}'."
|
||||
)
|
||||
valid = False
|
||||
|
||||
return valid
|
||||
|
||||
def validate(
|
||||
self, agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Run all validations.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
self.errors = []
|
||||
|
||||
if blocks_info is None:
|
||||
blocks_info = get_blocks_info()
|
||||
|
||||
checks = [
|
||||
self.validate_block_existence(agent, blocks_info),
|
||||
self.validate_link_node_references(agent),
|
||||
self.validate_required_inputs(agent, blocks_info),
|
||||
self.validate_data_type_compatibility(agent, blocks_info),
|
||||
self.validate_nested_sink_links(agent, blocks_info),
|
||||
self.validate_prompt_spaces(agent),
|
||||
]
|
||||
|
||||
all_passed = all(checks)
|
||||
|
||||
if all_passed:
|
||||
logger.info("Agent validation successful")
|
||||
return True, None
|
||||
|
||||
error_message = "Agent validation failed:\n"
|
||||
for i, error in enumerate(self.errors, 1):
|
||||
error_message += f"{i}. {error}\n"
|
||||
|
||||
logger.warning(f"Agent validation failed with {len(self.errors)} errors")
|
||||
return False, error_message
|
||||
|
||||
|
||||
def validate_agent(
|
||||
agent: dict[str, Any], blocks_info: list[dict[str, Any]] | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Convenience function to validate an agent.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
validator = AgentValidator()
|
||||
return validator.validate(agent, blocks_info)
|
||||
@@ -5,6 +5,7 @@ import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
@@ -103,7 +104,7 @@ class AgentOutputTool(BaseTool):
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "agent_output"
|
||||
return "view_agent_output"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -328,6 +329,7 @@ class AgentOutputTool(BaseTool):
|
||||
total_executions=len(available_executions) if available_executions else 1,
|
||||
)
|
||||
|
||||
@observe(as_type="tool", name="view_agent_output")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
"""CreateAgentTool - Creates agents from natural language descriptions."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_all_fixes,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for agent generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class CreateAgentTool(BaseTool):
|
||||
"""Tool for creating agents from natural language descriptions."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "create_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent workflow from a natural language description. "
|
||||
"First generates a preview, then saves to library if save=true."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what the agent should do. "
|
||||
"Be specific about inputs, outputs, and the workflow steps."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions. "
|
||||
"Include any preferences or constraints mentioned by the user."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["description"],
|
||||
}
|
||||
|
||||
@observe(as_type="tool", name="create_agent")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the create_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Decompose the description into steps (may return clarifying questions)
|
||||
2. Generate agent JSON from the steps
|
||||
3. Apply fixes to correct common LLM errors
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
description = kwargs.get("description", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
error="Missing description parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Decompose goal into steps
|
||||
try:
|
||||
decomposition_result = await decompose_goal(description, context)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to analyze the goal. Please try rephrasing.",
|
||||
error="Decomposition failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if decomposition_result.get("type") == "clarifying_questions":
|
||||
questions = decomposition_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to create this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check for unachievable/vague goals
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"This goal cannot be accomplished with the available blocks. "
|
||||
f"{reason} "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="unachievable_goal",
|
||||
details={"suggested_goal": suggested, "reason": reason},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if decomposition_result.get("type") == "vague_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The goal is too vague to create a specific workflow. "
|
||||
f"Suggestion: {suggested}"
|
||||
),
|
||||
error="vague_goal",
|
||||
details={"suggested_goal": suggested},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 2: Generate agent JSON with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
agent_json = None
|
||||
validation_errors = None
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate agent (include validation errors from previous attempt)
|
||||
if attempt == 0:
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_instructions = {
|
||||
**decomposition_result,
|
||||
"previous_errors": validation_errors,
|
||||
"retry_instructions": (
|
||||
"The previous generation had validation errors. "
|
||||
"Please fix these issues in the new generation:\n"
|
||||
f"{validation_errors}"
|
||||
),
|
||||
}
|
||||
agent_json = await generate_agent(retry_instructions)
|
||||
|
||||
if agent_json is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. Please try again.",
|
||||
error="Generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Step 3: Apply fixes to correct common errors
|
||||
agent_json = apply_all_fixes(agent_json, blocks_info)
|
||||
|
||||
# Step 4: Validate the agent
|
||||
is_valid, validation_errors = validate_agent(agent_json, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent generated successfully on attempt {attempt + 1}")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Generated agent has validation errors after {MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the workflow."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Step 4: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||
f"Review it and call create_agent with save=true to save it to your library."
|
||||
),
|
||||
agent_json=agent_json,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1,297 @@
|
||||
"""EditAgentTool - Edits existing agents using natural language."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_generator import (
|
||||
apply_agent_patch,
|
||||
apply_all_fixes,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_blocks_info,
|
||||
save_agent_to_library,
|
||||
validate_agent,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum retries for patch generation with validation feedback
|
||||
MAX_GENERATION_RETRIES = 2
|
||||
|
||||
|
||||
class EditAgentTool(BaseTool):
|
||||
"""Tool for editing existing agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent from the user's library using natural language. "
|
||||
"Generates a patch to update the agent while preserving unchanged parts."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The ID of the agent to edit. "
|
||||
"Can be a graph ID or library agent ID."
|
||||
),
|
||||
},
|
||||
"changes": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of what changes to make. "
|
||||
"Be specific about what to add, remove, or modify."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the changes. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "changes"],
|
||||
}
|
||||
|
||||
@observe(as_type="tool", name="edit_agent")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the edit_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Fetch the current agent
|
||||
2. Generate a patch based on the requested changes
|
||||
3. Apply the patch to create an updated agent
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
changes = kwargs.get("changes", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
error="Missing agent_id parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not changes:
|
||||
return ErrorResponse(
|
||||
message="Please describe what changes you want to make.",
|
||||
error="Missing changes parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Fetch current agent
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
|
||||
if current_agent is None:
|
||||
return ErrorResponse(
|
||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build the update request with context
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
# Step 2: Generate patch with retry on validation failure
|
||||
blocks_info = get_blocks_info()
|
||||
updated_agent = None
|
||||
validation_errors = None
|
||||
intent = "Applied requested changes"
|
||||
|
||||
for attempt in range(MAX_GENERATION_RETRIES + 1):
|
||||
# Generate patch (include validation errors from previous attempt)
|
||||
try:
|
||||
if attempt == 0:
|
||||
patch_result = await generate_agent_patch(
|
||||
update_request, current_agent
|
||||
)
|
||||
else:
|
||||
# Retry with validation error feedback
|
||||
logger.info(
|
||||
f"Retry {attempt}/{MAX_GENERATION_RETRIES} with validation feedback"
|
||||
)
|
||||
retry_request = (
|
||||
f"{update_request}\n\n"
|
||||
f"IMPORTANT: The previous edit had validation errors. "
|
||||
f"Please fix these issues:\n{validation_errors}"
|
||||
)
|
||||
patch_result = await generate_agent_patch(
|
||||
retry_request, current_agent
|
||||
)
|
||||
except ValueError as e:
|
||||
# Handle missing API key or configuration errors
|
||||
return ErrorResponse(
|
||||
message=f"Agent generation is not configured: {str(e)}",
|
||||
error="configuration_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if patch_result is None:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate changes. Please try rephrasing.",
|
||||
error="Patch generation failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if patch_result.get("type") == "clarifying_questions":
|
||||
questions = patch_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information about the changes. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 3: Apply patch and fixes
|
||||
try:
|
||||
updated_agent = apply_agent_patch(current_agent, patch_result)
|
||||
updated_agent = apply_all_fixes(updated_agent, blocks_info)
|
||||
except Exception as e:
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to apply changes: {str(e)}",
|
||||
error="patch_apply_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
validation_errors = str(e)
|
||||
continue
|
||||
|
||||
# Step 4: Validate the updated agent
|
||||
is_valid, validation_errors = validate_agent(updated_agent, blocks_info)
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Agent edited successfully on attempt {attempt + 1}")
|
||||
intent = patch_result.get("intent", "Applied requested changes")
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"Validation failed on attempt {attempt + 1}: {validation_errors}"
|
||||
)
|
||||
|
||||
if attempt == MAX_GENERATION_RETRIES:
|
||||
# Return error with validation details
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Updated agent has validation errors after "
|
||||
f"{MAX_GENERATION_RETRIES + 1} attempts. "
|
||||
f"Please try rephrasing your request or simplify the changes."
|
||||
),
|
||||
error="validation_failed",
|
||||
details={"validation_errors": validation_errors},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# At this point, updated_agent is guaranteed to be set (we return on all failure paths)
|
||||
assert updated_agent is not None
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
agent_description = updated_agent.get("description", "")
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
# Step 5: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've updated the agent. Changes: {intent}. "
|
||||
f"The agent now has {node_count} blocks. "
|
||||
f"Review it and call edit_agent with save=true to save the changes."
|
||||
),
|
||||
agent_json=updated_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library (creates a new version)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
updated_agent, user_id, is_update=True
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Updated agent '{created_graph.name}' has been saved to your library! "
|
||||
f"Changes: {intent}"
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to save the updated agent: {str(e)}",
|
||||
error="save_failed",
|
||||
details={"exception": str(e)},
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
@@ -35,6 +37,7 @@ class FindAgentTool(BaseTool):
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@observe(as_type="tool", name="find_agent")
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||
from backend.api.features.chat.tools.models import (
|
||||
BlockInfoSummary,
|
||||
BlockInputFieldInfo,
|
||||
BlockListResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.data.block import get_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindBlockTool(BaseTool):
|
||||
"""Tool for searching available blocks."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for available blocks by name or description. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
|
||||
"The response includes each block's id, required_inputs, and input_schema."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find blocks by name or description. "
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@observe(as_type="tool", name="find_block")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for blocks matching the query.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
BlockListResponse: List of matching blocks
|
||||
NoResultsResponse: No blocks found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Search for blocks using hybrid search
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
"Check spelling of technical terms",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Enrich results with full block information
|
||||
blocks: list[BlockInfoSummary] = []
|
||||
for result in results:
|
||||
block_id = result["content_id"]
|
||||
block = get_block(block_id)
|
||||
|
||||
if block:
|
||||
# Get input/output schemas
|
||||
input_schema = {}
|
||||
output_schema = {}
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get categories from block instance
|
||||
categories = []
|
||||
if hasattr(block, "categories") and block.categories:
|
||||
categories = [cat.value for cat in block.categories]
|
||||
|
||||
# Extract required inputs for easier use
|
||||
required_inputs: list[BlockInputFieldInfo] = []
|
||||
if input_schema:
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = set(input_schema.get("required", []))
|
||||
# Get credential field names to exclude from required inputs
|
||||
credentials_fields = set(
|
||||
block.input_schema.get_credentials_fields().keys()
|
||||
)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields - they're handled separately
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
required_inputs.append(
|
||||
BlockInputFieldInfo(
|
||||
name=field_name,
|
||||
type=field_schema.get("type", "string"),
|
||||
description=field_schema.get("description", ""),
|
||||
required=field_name in required_fields,
|
||||
default=field_schema.get("default"),
|
||||
)
|
||||
)
|
||||
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=categories,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
required_inputs=required_inputs,
|
||||
)
|
||||
)
|
||||
|
||||
if not blocks:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block(s) matching '{query}'. "
|
||||
"To execute a block, use run_block with the block's 'id' field "
|
||||
"and provide 'input_data' matching the block's input_schema."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching blocks: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search blocks",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
|
||||
from .agent_search import search_agents
|
||||
@@ -41,6 +43,7 @@ class FindLibraryAgentTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@observe(as_type="tool", name="find_library_agent")
|
||||
async def _execute(
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
"""GetDocPageTool - Fetch full content of a documentation page."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocPageResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Base URL for documentation (can be configured)
|
||||
DOCS_BASE_URL = "https://docs.agpt.co"
|
||||
|
||||
|
||||
class GetDocPageTool(BaseTool):
|
||||
"""Tool for fetching full content of a documentation page."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_doc_page"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Get the full content of a documentation page by its path. "
|
||||
"Use this after search_docs to read the complete content of a relevant page."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The path to the documentation file, as returned by search_docs. "
|
||||
"Example: 'platform/block-sdk-guide.md'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False # Documentation is public
|
||||
|
||||
def _get_docs_root(self) -> Path:
|
||||
"""Get the documentation root directory."""
|
||||
this_file = Path(__file__)
|
||||
project_root = this_file.parent.parent.parent.parent.parent.parent.parent.parent
|
||||
return project_root / "docs"
|
||||
|
||||
def _extract_title(self, content: str, fallback: str) -> str:
|
||||
"""Extract title from markdown content."""
|
||||
lines = content.split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("# "):
|
||||
return line[2:].strip()
|
||||
return fallback
|
||||
|
||||
def _make_doc_url(self, path: str) -> str:
|
||||
"""Create a URL for a documentation page."""
|
||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||
return f"{DOCS_BASE_URL}/{url_path}"
|
||||
|
||||
@observe(as_type="tool", name="get_doc_page")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Fetch full content of a documentation page.
|
||||
|
||||
Args:
|
||||
user_id: User ID (not required for docs)
|
||||
session: Chat session
|
||||
path: Path to the documentation file
|
||||
|
||||
Returns:
|
||||
DocPageResponse: Full document content
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
path = kwargs.get("path", "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide a documentation path.",
|
||||
error="Missing path parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Sanitize path to prevent directory traversal
|
||||
if ".." in path or path.startswith("/"):
|
||||
return ErrorResponse(
|
||||
message="Invalid documentation path.",
|
||||
error="invalid_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
docs_root = self._get_docs_root()
|
||||
full_path = docs_root / path
|
||||
|
||||
if not full_path.exists():
|
||||
return ErrorResponse(
|
||||
message=f"Documentation page not found: {path}",
|
||||
error="not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure the path is within docs root
|
||||
try:
|
||||
full_path.resolve().relative_to(docs_root.resolve())
|
||||
except ValueError:
|
||||
return ErrorResponse(
|
||||
message="Invalid documentation path.",
|
||||
error="invalid_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
content = full_path.read_text(encoding="utf-8")
|
||||
title = self._extract_title(content, path)
|
||||
|
||||
return DocPageResponse(
|
||||
message=f"Retrieved documentation page: {title}",
|
||||
title=title,
|
||||
path=path,
|
||||
content=content,
|
||||
doc_url=self._make_doc_url(path),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read documentation page {path}: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read documentation page: {str(e)}",
|
||||
error="read_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -21,6 +21,13 @@ class ResponseType(str, Enum):
|
||||
NO_RESULTS = "no_results"
|
||||
AGENT_OUTPUT = "agent_output"
|
||||
UNDERSTANDING_UPDATED = "understanding_updated"
|
||||
AGENT_PREVIEW = "agent_preview"
|
||||
AGENT_SAVED = "agent_saved"
|
||||
CLARIFICATION_NEEDED = "clarification_needed"
|
||||
BLOCK_LIST = "block_list"
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||
DOC_PAGE = "doc_page"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -209,3 +216,121 @@ class UnderstandingUpdatedResponse(ToolResponseBase):
|
||||
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
|
||||
updated_fields: list[str] = Field(default_factory=list)
|
||||
current_understanding: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# Agent generation models
|
||||
class ClarifyingQuestion(BaseModel):
|
||||
"""A question that needs user clarification."""
|
||||
|
||||
question: str
|
||||
keyword: str
|
||||
example: str | None = None
|
||||
|
||||
|
||||
class AgentPreviewResponse(ToolResponseBase):
|
||||
"""Response for previewing a generated agent before saving."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_PREVIEW
|
||||
agent_json: dict[str, Any]
|
||||
agent_name: str
|
||||
description: str
|
||||
node_count: int
|
||||
link_count: int = 0
|
||||
|
||||
|
||||
class AgentSavedResponse(ToolResponseBase):
|
||||
"""Response when an agent is saved to the library."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_SAVED
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
library_agent_id: str
|
||||
library_agent_link: str
|
||||
agent_page_link: str # Link to the agent builder/editor page
|
||||
|
||||
|
||||
class ClarificationNeededResponse(ToolResponseBase):
|
||||
"""Response when the LLM needs more information from the user."""
|
||||
|
||||
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
|
||||
questions: list[ClarifyingQuestion] = Field(default_factory=list)
|
||||
|
||||
|
||||
# Documentation search models
|
||||
class DocSearchResult(BaseModel):
|
||||
"""A single documentation search result."""
|
||||
|
||||
title: str
|
||||
path: str
|
||||
section: str
|
||||
snippet: str # Short excerpt for UI display
|
||||
score: float
|
||||
doc_url: str | None = None
|
||||
|
||||
|
||||
class DocSearchResultsResponse(ToolResponseBase):
|
||||
"""Response for search_docs tool."""
|
||||
|
||||
type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
|
||||
results: list[DocSearchResult]
|
||||
count: int
|
||||
query: str
|
||||
|
||||
|
||||
class DocPageResponse(ToolResponseBase):
|
||||
"""Response for get_doc_page tool."""
|
||||
|
||||
type: ResponseType = ResponseType.DOC_PAGE
|
||||
title: str
|
||||
path: str
|
||||
content: str # Full document content
|
||||
doc_url: str | None = None
|
||||
|
||||
|
||||
# Block models
|
||||
class BlockInputFieldInfo(BaseModel):
|
||||
"""Information about a block input field."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
description: str = ""
|
||||
required: bool = False
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
class BlockInfoSummary(BaseModel):
|
||||
"""Summary of a block for search results."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||
default_factory=list,
|
||||
description="List of required input fields for this block",
|
||||
)
|
||||
|
||||
|
||||
class BlockListResponse(ToolResponseBase):
|
||||
"""Response for find_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_LIST
|
||||
blocks: list[BlockInfoSummary]
|
||||
count: int
|
||||
query: str
|
||||
usage_hint: str = Field(
|
||||
default="To execute a block, call run_block with block_id set to the block's "
|
||||
"'id' field and input_data containing the required fields from input_schema."
|
||||
)
|
||||
|
||||
|
||||
class BlockOutputResponse(ToolResponseBase):
|
||||
"""Response for run_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_OUTPUT
|
||||
block_id: str
|
||||
block_name: str
|
||||
outputs: dict[str, list[Any]]
|
||||
success: bool = True
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
@@ -154,6 +155,7 @@ class RunAgentTool(BaseTool):
|
||||
"""All operations require authentication."""
|
||||
return True
|
||||
|
||||
@observe(as_type="tool", name="run_agent")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
"""Tool for executing blocks directly."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunBlockTool(BaseTool):
|
||||
"""Tool for executing a block and returning its outputs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_block"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a specific block with the provided input data. "
|
||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||
"do NOT guess or make up block IDs. "
|
||||
"Use the 'id' from find_block results and provide input_data "
|
||||
"matching the block's required_inputs."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"block_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The block's 'id' field from find_block results. "
|
||||
"NEVER guess this - always get it from find_block first."
|
||||
),
|
||||
},
|
||||
"input_data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Input values for the block. Use the 'required_inputs' field "
|
||||
"from find_block to see what fields are needed."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["block_id", "input_data"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _check_block_credentials(
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
if not credentials_fields_info:
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
# Get user's available credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
# field_info.provider is a frozenset of acceptable providers
|
||||
# field_info.supported_types is a frozenset of acceptable types
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in field_info.provider
|
||||
and cred.type in field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if matching_cred:
|
||||
matched_credentials[field_name] = CredentialsMetaInput(
|
||||
id=matching_cred.id,
|
||||
provider=matching_cred.provider, # type: ignore
|
||||
type=matching_cred.type,
|
||||
title=matching_cred.title,
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
provider=provider, # type: ignore
|
||||
type=cred_type, # type: ignore
|
||||
title=field_name.replace("_", " ").title(),
|
||||
)
|
||||
)
|
||||
|
||||
return matched_credentials, missing_credentials
|
||||
|
||||
@observe(as_type="tool", name="run_block")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute a block with the given input data.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
block_id: Block UUID to execute
|
||||
input_data: Input values for the block
|
||||
|
||||
Returns:
|
||||
BlockOutputResponse: Block execution outputs
|
||||
SetupRequirementsResponse: Missing credentials
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
block_id = kwargs.get("block_id", "").strip()
|
||||
input_data = kwargs.get("input_data", {})
|
||||
session_id = session.session_id
|
||||
|
||||
if not block_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide a block_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not isinstance(input_data, dict):
|
||||
return ErrorResponse(
|
||||
message="input_data must be an object",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get the block
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
return ErrorResponse(
|
||||
message=f"Block '{block_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
# Check credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||
user_id, block
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
missing_creds_dict = {c.id: c.model_dump() for c in missing_credentials}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' requires credentials that are not configured. "
|
||||
"Please set up the required credentials before running this block."
|
||||
),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=block_id,
|
||||
agent_name=block.name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_creds_dict,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [c.model_dump() for c in missing_credentials],
|
||||
"inputs": self._get_inputs_list(block),
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
),
|
||||
graph_id=None,
|
||||
graph_version=None,
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch actual credentials and prepare kwargs for block execution
|
||||
# Create execution context with defaults (blocks may require it)
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": ExecutionContext(),
|
||||
}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
if field_name not in input_data:
|
||||
input_data[field_name] = cred_meta.model_dump()
|
||||
|
||||
# Fetch actual credentials and pass as kwargs (for execution)
|
||||
actual_credentials = await creds_manager.get(
|
||||
user_id, cred_meta.id, lock=False
|
||||
)
|
||||
if actual_credentials:
|
||||
exec_kwargs[field_name] = actual_credentials
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to retrieve credentials for {field_name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Execute the block and collect outputs
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in block.execute(
|
||||
input_data,
|
||||
**exec_kwargs,
|
||||
):
|
||||
outputs[output_name].append(output_data)
|
||||
|
||||
return BlockOutputResponse(
|
||||
message=f"Block '{block.name}' executed successfully",
|
||||
block_id=block_id,
|
||||
block_name=block.name,
|
||||
outputs=dict(outputs),
|
||||
success=True,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except BlockError as e:
|
||||
logger.warning(f"Block execution failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Block execution failed: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute block: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
||||
"""Extract non-credential inputs from block schema."""
|
||||
inputs_list = []
|
||||
schema = block.input_schema.jsonschema()
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = set(schema.get("required", []))
|
||||
|
||||
# Get credential field names to exclude
|
||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
inputs_list.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"title": field_schema.get("title", field_name),
|
||||
"type": field_schema.get("type", "string"),
|
||||
"description": field_schema.get("description", ""),
|
||||
"required": field_name in required_fields,
|
||||
}
|
||||
)
|
||||
|
||||
return inputs_list
|
||||
@@ -0,0 +1,210 @@
|
||||
"""SearchDocsTool - Search documentation using hybrid search."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langfuse import observe
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocSearchResult,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Base URL for documentation (can be configured)
|
||||
DOCS_BASE_URL = "https://docs.agpt.co"
|
||||
|
||||
# Maximum number of results to return
|
||||
MAX_RESULTS = 5
|
||||
|
||||
# Snippet length for preview
|
||||
SNIPPET_LENGTH = 200
|
||||
|
||||
|
||||
class SearchDocsTool(BaseTool):
|
||||
"""Tool for searching AutoGPT platform documentation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "search_docs"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search the AutoGPT platform documentation for information about "
|
||||
"how to use the platform, build agents, configure blocks, and more. "
|
||||
"Returns relevant documentation sections. Use get_doc_page to read full content."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find relevant documentation. "
|
||||
"Use natural language to describe what you're looking for."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False # Documentation is public
|
||||
|
||||
def _create_snippet(self, content: str, max_length: int = SNIPPET_LENGTH) -> str:
|
||||
"""Create a short snippet from content for preview."""
|
||||
# Remove markdown formatting for cleaner snippet
|
||||
clean_content = content.replace("#", "").replace("*", "").replace("`", "")
|
||||
# Remove extra whitespace
|
||||
clean_content = " ".join(clean_content.split())
|
||||
|
||||
if len(clean_content) <= max_length:
|
||||
return clean_content
|
||||
|
||||
# Truncate at word boundary
|
||||
truncated = clean_content[:max_length]
|
||||
last_space = truncated.rfind(" ")
|
||||
if last_space > max_length // 2:
|
||||
truncated = truncated[:last_space]
|
||||
|
||||
return truncated + "..."
|
||||
|
||||
def _make_doc_url(self, path: str) -> str:
|
||||
"""Create a URL for a documentation page."""
|
||||
# Remove file extension for URL
|
||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||
return f"{DOCS_BASE_URL}/{url_path}"
|
||||
|
||||
@observe(as_type="tool", name="search_docs")
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search documentation and return relevant sections.
|
||||
|
||||
Args:
|
||||
user_id: User ID (not required for docs)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
DocSearchResultsResponse: List of matching documentation sections
|
||||
NoResultsResponse: No results found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query.",
|
||||
error="Missing query parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Search using hybrid search for DOCUMENTATION content type only
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.DOCUMENTATION],
|
||||
page=1,
|
||||
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
|
||||
min_score=0.1, # Lower threshold for docs
|
||||
)
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{query}'.",
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use more general terms",
|
||||
"Check for typos in your query",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Deduplicate by document path (keep highest scoring section per doc)
|
||||
seen_docs: dict[str, dict[str, Any]] = {}
|
||||
for result in results:
|
||||
metadata = result.get("metadata", {})
|
||||
doc_path = metadata.get("path", "")
|
||||
|
||||
if not doc_path:
|
||||
continue
|
||||
|
||||
# Keep the highest scoring result for each document
|
||||
if doc_path not in seen_docs:
|
||||
seen_docs[doc_path] = result
|
||||
elif result.get("combined_score", 0) > seen_docs[doc_path].get(
|
||||
"combined_score", 0
|
||||
):
|
||||
seen_docs[doc_path] = result
|
||||
|
||||
# Sort by score and take top MAX_RESULTS
|
||||
deduplicated = sorted(
|
||||
seen_docs.values(),
|
||||
key=lambda x: x.get("combined_score", 0),
|
||||
reverse=True,
|
||||
)[:MAX_RESULTS]
|
||||
|
||||
if not deduplicated:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{query}'.",
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use more general terms",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build response
|
||||
doc_results: list[DocSearchResult] = []
|
||||
for result in deduplicated:
|
||||
metadata = result.get("metadata", {})
|
||||
doc_path = metadata.get("path", "")
|
||||
doc_title = metadata.get("doc_title", "")
|
||||
section_title = metadata.get("section_title", "")
|
||||
searchable_text = result.get("searchable_text", "")
|
||||
score = result.get("combined_score", 0)
|
||||
|
||||
doc_results.append(
|
||||
DocSearchResult(
|
||||
title=doc_title or section_title or doc_path,
|
||||
path=doc_path,
|
||||
section=section_title,
|
||||
snippet=self._create_snippet(searchable_text),
|
||||
score=round(score, 3),
|
||||
doc_url=self._make_doc_url(doc_path),
|
||||
)
|
||||
)
|
||||
|
||||
return DocSearchResultsResponse(
|
||||
message=f"Found {len(doc_results)} relevant documentation sections.",
|
||||
results=doc_results,
|
||||
count=len(doc_results),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Documentation search failed: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Failed to search documentation: {str(e)}",
|
||||
error="search_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -275,8 +275,22 @@ class BlockHandler(ContentHandler):
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarkdownSection:
|
||||
"""Represents a section of a markdown document."""
|
||||
|
||||
title: str # Section heading text
|
||||
content: str # Section content (including the heading line)
|
||||
level: int # Heading level (1 for #, 2 for ##, etc.)
|
||||
index: int # Section index within the document
|
||||
|
||||
|
||||
class DocumentationHandler(ContentHandler):
|
||||
"""Handler for documentation files (.md/.mdx)."""
|
||||
"""Handler for documentation files (.md/.mdx).
|
||||
|
||||
Chunks documents by markdown headings to create multiple embeddings per file.
|
||||
Each section (## heading) becomes a separate embedding for better retrieval.
|
||||
"""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
@@ -297,35 +311,162 @@ class DocumentationHandler(ContentHandler):
|
||||
docs_root = project_root / "docs"
|
||||
return docs_root
|
||||
|
||||
def _extract_title_and_content(self, file_path: Path) -> tuple[str, str]:
|
||||
"""Extract title and content from markdown file."""
|
||||
def _extract_doc_title(self, file_path: Path) -> str:
|
||||
"""Extract the document title from a markdown file."""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
lines = content.split("\n")
|
||||
|
||||
# Try to extract title from first # heading
|
||||
lines = content.split("\n")
|
||||
title = ""
|
||||
body_lines = []
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("# ") and not title:
|
||||
title = line[2:].strip()
|
||||
else:
|
||||
body_lines.append(line)
|
||||
if line.startswith("# "):
|
||||
return line[2:].strip()
|
||||
|
||||
# If no title found, use filename
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
return file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read title from {file_path}: {e}")
|
||||
return file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
|
||||
body = "\n".join(body_lines)
|
||||
def _chunk_markdown_by_headings(
|
||||
self, file_path: Path, min_heading_level: int = 2
|
||||
) -> list[MarkdownSection]:
|
||||
"""
|
||||
Split a markdown file into sections based on headings.
|
||||
|
||||
return title, body
|
||||
Args:
|
||||
file_path: Path to the markdown file
|
||||
min_heading_level: Minimum heading level to split on (default: 2 for ##)
|
||||
|
||||
Returns:
|
||||
List of MarkdownSection objects, one per section.
|
||||
If no headings found, returns a single section with all content.
|
||||
"""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {file_path}: {e}")
|
||||
return file_path.stem, ""
|
||||
return []
|
||||
|
||||
lines = content.split("\n")
|
||||
sections: list[MarkdownSection] = []
|
||||
current_section_lines: list[str] = []
|
||||
current_title = ""
|
||||
current_level = 0
|
||||
section_index = 0
|
||||
doc_title = ""
|
||||
|
||||
for line in lines:
|
||||
# Check if line is a heading
|
||||
if line.startswith("#"):
|
||||
# Count heading level
|
||||
level = 0
|
||||
for char in line:
|
||||
if char == "#":
|
||||
level += 1
|
||||
else:
|
||||
break
|
||||
|
||||
heading_text = line[level:].strip()
|
||||
|
||||
# Track document title (level 1 heading)
|
||||
if level == 1 and not doc_title:
|
||||
doc_title = heading_text
|
||||
# Don't create a section for just the title - add it to first section
|
||||
current_section_lines.append(line)
|
||||
continue
|
||||
|
||||
# Check if this heading should start a new section
|
||||
if level >= min_heading_level:
|
||||
# Save previous section if it has content
|
||||
if current_section_lines:
|
||||
section_content = "\n".join(current_section_lines).strip()
|
||||
if section_content:
|
||||
# Use doc title for first section if no specific title
|
||||
title = current_title if current_title else doc_title
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace(
|
||||
"_", " "
|
||||
)
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=section_content,
|
||||
level=current_level if current_level else 1,
|
||||
index=section_index,
|
||||
)
|
||||
)
|
||||
section_index += 1
|
||||
|
||||
# Start new section
|
||||
current_section_lines = [line]
|
||||
current_title = heading_text
|
||||
current_level = level
|
||||
else:
|
||||
# Lower level heading (e.g., # when splitting on ##)
|
||||
current_section_lines.append(line)
|
||||
else:
|
||||
current_section_lines.append(line)
|
||||
|
||||
# Don't forget the last section
|
||||
if current_section_lines:
|
||||
section_content = "\n".join(current_section_lines).strip()
|
||||
if section_content:
|
||||
title = current_title if current_title else doc_title
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace("_", " ")
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=section_content,
|
||||
level=current_level if current_level else 1,
|
||||
index=section_index,
|
||||
)
|
||||
)
|
||||
|
||||
# If no sections were created (no headings found), create one section with all content
|
||||
if not sections and content.strip():
|
||||
title = (
|
||||
doc_title
|
||||
if doc_title
|
||||
else file_path.stem.replace("-", " ").replace("_", " ")
|
||||
)
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=content.strip(),
|
||||
level=1,
|
||||
index=0,
|
||||
)
|
||||
)
|
||||
|
||||
return sections
|
||||
|
||||
def _make_section_content_id(self, doc_path: str, section_index: int) -> str:
|
||||
"""Create a unique content ID for a document section.
|
||||
|
||||
Format: doc_path::section_index
|
||||
Example: 'platform/getting-started.md::0'
|
||||
"""
|
||||
return f"{doc_path}::{section_index}"
|
||||
|
||||
def _parse_section_content_id(self, content_id: str) -> tuple[str, int]:
|
||||
"""Parse a section content ID back into doc_path and section_index.
|
||||
|
||||
Returns: (doc_path, section_index)
|
||||
"""
|
||||
if "::" in content_id:
|
||||
parts = content_id.rsplit("::", 1)
|
||||
return parts[0], int(parts[1])
|
||||
# Legacy format (whole document)
|
||||
return content_id, 0
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch documentation files without embeddings."""
|
||||
"""Fetch documentation sections without embeddings.
|
||||
|
||||
Chunks each document by markdown headings and creates embeddings for each section.
|
||||
Content IDs use the format: 'path/to/doc.md::section_index'
|
||||
"""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
@@ -335,14 +476,28 @@ class DocumentationHandler(ContentHandler):
|
||||
# Find all .md and .mdx files
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
|
||||
# Get relative paths for content IDs
|
||||
doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs]
|
||||
|
||||
if not doc_paths:
|
||||
if not all_docs:
|
||||
return []
|
||||
|
||||
# Build list of all sections from all documents
|
||||
all_sections: list[tuple[str, Path, MarkdownSection]] = []
|
||||
for doc_file in all_docs:
|
||||
doc_path = str(doc_file.relative_to(docs_root))
|
||||
sections = self._chunk_markdown_by_headings(doc_file)
|
||||
for section in sections:
|
||||
all_sections.append((doc_path, doc_file, section))
|
||||
|
||||
if not all_sections:
|
||||
return []
|
||||
|
||||
# Generate content IDs for all sections
|
||||
section_content_ids = [
|
||||
self._make_section_content_id(doc_path, section.index)
|
||||
for doc_path, _, section in all_sections
|
||||
]
|
||||
|
||||
# Check which ones have embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))])
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(section_content_ids))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
@@ -350,76 +505,100 @@ class DocumentationHandler(ContentHandler):
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*doc_paths,
|
||||
*section_content_ids,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_docs = [
|
||||
(doc_path, doc_file)
|
||||
for doc_path, doc_file in zip(doc_paths, all_docs)
|
||||
if doc_path not in existing_ids
|
||||
|
||||
# Filter to missing sections
|
||||
missing_sections = [
|
||||
(doc_path, doc_file, section, content_id)
|
||||
for (doc_path, doc_file, section), content_id in zip(
|
||||
all_sections, section_content_ids
|
||||
)
|
||||
if content_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
# Convert to ContentItem (up to batch_size)
|
||||
items = []
|
||||
for doc_path, doc_file in missing_docs[:batch_size]:
|
||||
for doc_path, doc_file, section, content_id in missing_sections[:batch_size]:
|
||||
try:
|
||||
title, content = self._extract_title_and_content(doc_file)
|
||||
# Get document title for context
|
||||
doc_title = self._extract_doc_title(doc_file)
|
||||
|
||||
# Build searchable text
|
||||
searchable_text = f"{title} {content}"
|
||||
# Build searchable text with context
|
||||
# Include doc title and section title for better search relevance
|
||||
searchable_text = f"{doc_title} - {section.title}\n\n{section.content}"
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=doc_path,
|
||||
content_id=content_id,
|
||||
content_type=ContentType.DOCUMENTATION,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"title": title,
|
||||
"doc_title": doc_title,
|
||||
"section_title": section.title,
|
||||
"section_index": section.index,
|
||||
"heading_level": section.level,
|
||||
"path": doc_path,
|
||||
},
|
||||
user_id=None, # Documentation is public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process doc {doc_path}: {e}")
|
||||
logger.warning(f"Failed to process section {content_id}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
def _get_all_section_content_ids(self, docs_root: Path) -> set[str]:
|
||||
"""Get all current section content IDs from the docs directory.
|
||||
|
||||
Used for stats and cleanup to know what sections should exist.
|
||||
"""
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
content_ids = set()
|
||||
|
||||
for doc_file in all_docs:
|
||||
doc_path = str(doc_file.relative_to(docs_root))
|
||||
sections = self._chunk_markdown_by_headings(doc_file)
|
||||
for section in sections:
|
||||
content_ids.add(self._make_section_content_id(doc_path, section.index))
|
||||
|
||||
return content_ids
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about documentation embedding coverage."""
|
||||
"""Get statistics about documentation embedding coverage.
|
||||
|
||||
Counts sections (not documents) since each section gets its own embedding.
|
||||
"""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
# Count all .md and .mdx files
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
total_docs = len(all_docs)
|
||||
# Get all section content IDs
|
||||
all_section_ids = self._get_all_section_content_ids(docs_root)
|
||||
total_sections = len(all_section_ids)
|
||||
|
||||
if total_docs == 0:
|
||||
if total_sections == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
doc_paths = [str(doc.relative_to(docs_root)) for doc in all_docs]
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(doc_paths))])
|
||||
|
||||
# Count embeddings in database for DOCUMENTATION type
|
||||
embedded_result = await query_raw_with_schema(
|
||||
f"""
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*doc_paths,
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{schema_prefix}"ContentType"
|
||||
"""
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_docs,
|
||||
"total": total_sections,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_docs - with_embeddings,
|
||||
"without_embeddings": total_sections - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -164,20 +164,20 @@ async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md
|
||||
# Check guide.md (content_id format: doc_path::section_index)
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md"), None
|
||||
(item for item in items if item.content_id == "guide.md::0"), None
|
||||
)
|
||||
assert guide_item is not None
|
||||
assert guide_item.content_type == ContentType.DOCUMENTATION
|
||||
assert "Getting Started" in guide_item.searchable_text
|
||||
assert "This is a guide" in guide_item.searchable_text
|
||||
assert guide_item.metadata["title"] == "Getting Started"
|
||||
assert guide_item.metadata["doc_title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx
|
||||
# Check api.mdx (content_id format: doc_path::section_index)
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx"), None
|
||||
(item for item in items if item.content_id == "api.mdx::0"), None
|
||||
)
|
||||
assert api_item is not None
|
||||
assert "API Reference" in api_item.searchable_text
|
||||
@@ -218,17 +218,74 @@ async def test_documentation_handler_title_extraction(tmp_path):
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title, content = handler._extract_title_and_content(doc_with_heading)
|
||||
title = handler._extract_doc_title(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
assert "# My Title" not in content
|
||||
assert "Content here" in content
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title, content = handler._extract_title_and_content(doc_without_heading)
|
||||
title = handler._extract_doc_title(doc_without_heading)
|
||||
assert title == "No Heading" # Uses filename
|
||||
assert "Just content" in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
"""Test DocumentationHandler chunks markdown by headings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test document with multiple sections
|
||||
doc_with_sections = tmp_path / "sections.md"
|
||||
doc_with_sections.write_text(
|
||||
"# Document Title\n\n"
|
||||
"Intro paragraph.\n\n"
|
||||
"## Section One\n\n"
|
||||
"Content for section one.\n\n"
|
||||
"## Section Two\n\n"
|
||||
"Content for section two.\n"
|
||||
)
|
||||
sections = handler._chunk_markdown_by_headings(doc_with_sections)
|
||||
|
||||
# Should have 3 sections: intro (with doc title), section one, section two
|
||||
assert len(sections) == 3
|
||||
assert sections[0].title == "Document Title"
|
||||
assert sections[0].index == 0
|
||||
assert "Intro paragraph" in sections[0].content
|
||||
|
||||
assert sections[1].title == "Section One"
|
||||
assert sections[1].index == 1
|
||||
assert "Content for section one" in sections[1].content
|
||||
|
||||
assert sections[2].title == "Section Two"
|
||||
assert sections[2].index == 2
|
||||
assert "Content for section two" in sections[2].content
|
||||
|
||||
# Test document without headings
|
||||
doc_no_sections = tmp_path / "no-sections.md"
|
||||
doc_no_sections.write_text("Just plain content without any headings.")
|
||||
sections = handler._chunk_markdown_by_headings(doc_no_sections)
|
||||
assert len(sections) == 1
|
||||
assert sections[0].index == 0
|
||||
assert "Just plain content" in sections[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_section_content_ids():
|
||||
"""Test DocumentationHandler creates and parses section content IDs."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test making content ID
|
||||
content_id = handler._make_section_content_id("docs/guide.md", 2)
|
||||
assert content_id == "docs/guide.md::2"
|
||||
|
||||
# Test parsing content ID
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
|
||||
assert doc_path == "docs/guide.md"
|
||||
assert section_index == 2
|
||||
|
||||
# Test parsing legacy format (no section index)
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
|
||||
assert doc_path == "docs/old-format.md"
|
||||
assert section_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -154,15 +154,16 @@ async def store_content_embedding(
|
||||
|
||||
# Upsert the embedding
|
||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||
# Use {pgvector_schema}.vector for explicit pgvector type qualification
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::{pgvector_schema}.vector, $5, $6::jsonb, NOW(), NOW())
|
||||
ON CONFLICT ("contentType", "contentId", "userId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $4::vector,
|
||||
"embedding" = $4::{pgvector_schema}.vector,
|
||||
"searchableText" = $5,
|
||||
"metadata" = $6::jsonb,
|
||||
"updatedAt" = NOW()
|
||||
@@ -177,7 +178,6 @@ async def store_content_embedding(
|
||||
searchable_text,
|
||||
metadata_json,
|
||||
client=client,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||
@@ -236,7 +236,6 @@ async def get_content_embedding(
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
@@ -683,20 +682,20 @@ async def cleanup_orphaned_embeddings() -> dict[str, Any]:
|
||||
|
||||
current_ids = set(get_blocks().keys())
|
||||
elif content_type == ContentType.DOCUMENTATION:
|
||||
from pathlib import Path
|
||||
|
||||
# embeddings.py is at: backend/backend/api/features/store/embeddings.py
|
||||
# Need to go up to project root then into docs/
|
||||
this_file = Path(__file__)
|
||||
project_root = (
|
||||
this_file.parent.parent.parent.parent.parent.parent.parent
|
||||
# Use DocumentationHandler to get section-based content IDs
|
||||
from backend.api.features.store.content_handlers import (
|
||||
DocumentationHandler,
|
||||
)
|
||||
docs_root = project_root / "docs"
|
||||
if docs_root.exists():
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(
|
||||
docs_root.rglob("*.mdx")
|
||||
)
|
||||
current_ids = {str(doc.relative_to(docs_root)) for doc in all_docs}
|
||||
|
||||
doc_handler = CONTENT_HANDLERS.get(ContentType.DOCUMENTATION)
|
||||
if isinstance(doc_handler, DocumentationHandler):
|
||||
docs_root = doc_handler._get_docs_root()
|
||||
if docs_root.exists():
|
||||
current_ids = doc_handler._get_all_section_content_ids(
|
||||
docs_root
|
||||
)
|
||||
else:
|
||||
current_ids = set()
|
||||
else:
|
||||
current_ids = set()
|
||||
else:
|
||||
@@ -871,31 +870,46 @@ async def semantic_search(
|
||||
# Add content type parameters and build placeholders dynamically
|
||||
content_type_start_idx = len(params) + 1
|
||||
content_type_placeholders = ", ".join(
|
||||
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
|
||||
"$" + str(content_type_start_idx + i) + '::{schema_prefix}"ContentType"'
|
||||
for i in range(len(content_types))
|
||||
)
|
||||
params.extend([ct.value for ct in content_types])
|
||||
|
||||
sql = f"""
|
||||
# Build min_similarity param index before appending
|
||||
min_similarity_idx = len(params) + 1
|
||||
params.append(min_similarity)
|
||||
|
||||
# Use regular string (not f-string) for template to preserve {schema_prefix} and {schema} placeholders
|
||||
# Use OPERATOR({pgvector_schema}.<=>) for explicit operator schema qualification
|
||||
sql = (
|
||||
"""
|
||||
SELECT
|
||||
"contentId" as content_id,
|
||||
"contentType" as content_type,
|
||||
"searchableText" as searchable_text,
|
||||
metadata,
|
||||
1 - (embedding <=> '{embedding_str}'::vector) as similarity
|
||||
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ({content_type_placeholders})
|
||||
{user_filter}
|
||||
AND 1 - (embedding <=> '{embedding_str}'::vector) >= ${len(params) + 1}
|
||||
1 - (embedding OPERATOR({pgvector_schema}.<=>) '"""
|
||||
+ embedding_str
|
||||
+ """'::{pgvector_schema}.vector) as similarity
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ("""
|
||||
+ content_type_placeholders
|
||||
+ """)
|
||||
"""
|
||||
+ user_filter
|
||||
+ """
|
||||
AND 1 - (embedding OPERATOR({pgvector_schema}.<=>) '"""
|
||||
+ embedding_str
|
||||
+ """'::{pgvector_schema}.vector) >= $"""
|
||||
+ str(min_similarity_idx)
|
||||
+ """
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $1
|
||||
"""
|
||||
params.append(min_similarity)
|
||||
)
|
||||
|
||||
try:
|
||||
results = await query_raw_with_schema(
|
||||
sql, *params, set_public_search_path=True
|
||||
)
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
@@ -922,31 +936,41 @@ async def semantic_search(
|
||||
# Add content type parameters and build placeholders dynamically
|
||||
content_type_start_idx = len(params_lexical) + 1
|
||||
content_type_placeholders_lexical = ", ".join(
|
||||
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
|
||||
"$" + str(content_type_start_idx + i) + '::{schema_prefix}"ContentType"'
|
||||
for i in range(len(content_types))
|
||||
)
|
||||
params_lexical.extend([ct.value for ct in content_types])
|
||||
|
||||
sql_lexical = f"""
|
||||
# Build query param index before appending
|
||||
query_param_idx = len(params_lexical) + 1
|
||||
params_lexical.append(f"%{query}%")
|
||||
|
||||
# Use regular string (not f-string) for template to preserve {schema_prefix} placeholders
|
||||
sql_lexical = (
|
||||
"""
|
||||
SELECT
|
||||
"contentId" as content_id,
|
||||
"contentType" as content_type,
|
||||
"searchableText" as searchable_text,
|
||||
metadata,
|
||||
0.0 as similarity
|
||||
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ({content_type_placeholders_lexical})
|
||||
{user_filter}
|
||||
AND "searchableText" ILIKE ${len(params_lexical) + 1}
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ("""
|
||||
+ content_type_placeholders_lexical
|
||||
+ """)
|
||||
"""
|
||||
+ user_filter
|
||||
+ """
|
||||
AND "searchableText" ILIKE $"""
|
||||
+ str(query_param_idx)
|
||||
+ """
|
||||
ORDER BY "updatedAt" DESC
|
||||
LIMIT $1
|
||||
"""
|
||||
params_lexical.append(f"%{query}%")
|
||||
)
|
||||
|
||||
try:
|
||||
results = await query_raw_with_schema(
|
||||
sql_lexical, *params_lexical, set_public_search_path=True
|
||||
)
|
||||
results = await query_raw_with_schema(sql_lexical, *params_lexical)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
|
||||
@@ -155,18 +155,14 @@ async def test_store_embedding_success(mocker):
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# execute_raw is called twice: once for SET search_path, once for INSERT
|
||||
assert mock_client.execute_raw.call_count == 2
|
||||
# execute_raw is called once for INSERT (no separate SET search_path needed)
|
||||
assert mock_client.execute_raw.call_count == 1
|
||||
|
||||
# First call: SET search_path
|
||||
first_call_args = mock_client.execute_raw.call_args_list[0][0]
|
||||
assert "SET search_path" in first_call_args[0]
|
||||
|
||||
# Second call: INSERT query with the actual data
|
||||
second_call_args = mock_client.execute_raw.call_args_list[1][0]
|
||||
assert "test-version-id" in second_call_args
|
||||
assert "[0.1,0.2,0.3]" in second_call_args
|
||||
assert None in second_call_args # userId should be None for store agents
|
||||
# Verify the INSERT query with the actual data
|
||||
call_args = mock_client.execute_raw.call_args_list[0][0]
|
||||
assert "test-version-id" in call_args
|
||||
assert "[0.1,0.2,0.3]" in call_args
|
||||
assert None in call_args # userId should be None for store agents
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -3,13 +3,16 @@ Unified Hybrid Search
|
||||
|
||||
Combines semantic (embedding) search with lexical (tsvector) search
|
||||
for improved relevance across all content types (agents, blocks, docs).
|
||||
Includes BM25 reranking for improved lexical relevance.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from prisma.enums import ContentType
|
||||
from rank_bm25 import BM25Okapi # type: ignore[import-untyped]
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
EMBEDDING_DIM,
|
||||
@@ -21,6 +24,84 @@ from backend.data.db import query_raw_with_schema
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# BM25 Reranking
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25 - lowercase and split on non-alphanumeric."""
|
||||
if not text:
|
||||
return []
|
||||
# Lowercase and split on non-alphanumeric characters
|
||||
tokens = re.findall(r"\b\w+\b", text.lower())
|
||||
return tokens
|
||||
|
||||
|
||||
def bm25_rerank(
|
||||
query: str,
|
||||
results: list[dict[str, Any]],
|
||||
text_field: str = "searchable_text",
|
||||
bm25_weight: float = 0.3,
|
||||
original_score_field: str = "combined_score",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Rerank search results using BM25.
|
||||
|
||||
Combines the original combined_score with BM25 score for improved
|
||||
lexical relevance, especially for exact term matches.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
results: List of result dicts with text_field and original_score_field
|
||||
text_field: Field name containing the text to score
|
||||
bm25_weight: Weight for BM25 score (0-1). Original score gets (1 - bm25_weight)
|
||||
original_score_field: Field name containing the original score
|
||||
|
||||
Returns:
|
||||
Results list sorted by combined score (BM25 + original)
|
||||
"""
|
||||
if not results or not query:
|
||||
return results
|
||||
|
||||
# Extract texts and tokenize
|
||||
corpus = [tokenize(r.get(text_field, "") or "") for r in results]
|
||||
|
||||
# Handle edge case where all documents are empty
|
||||
if all(len(doc) == 0 for doc in corpus):
|
||||
return results
|
||||
|
||||
# Build BM25 index
|
||||
bm25 = BM25Okapi(corpus)
|
||||
|
||||
# Score query against corpus
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
return results
|
||||
|
||||
bm25_scores = bm25.get_scores(query_tokens)
|
||||
|
||||
# Normalize BM25 scores to 0-1 range
|
||||
max_bm25 = max(bm25_scores) if max(bm25_scores) > 0 else 1.0
|
||||
normalized_bm25 = [s / max_bm25 for s in bm25_scores]
|
||||
|
||||
# Combine scores
|
||||
original_weight = 1.0 - bm25_weight
|
||||
for i, result in enumerate(results):
|
||||
original_score = result.get(original_score_field, 0) or 0
|
||||
result["bm25_score"] = normalized_bm25[i]
|
||||
final_score = (
|
||||
original_weight * original_score + bm25_weight * normalized_bm25[i]
|
||||
)
|
||||
result["final_score"] = final_score
|
||||
result["relevance"] = final_score
|
||||
|
||||
# Sort by relevance descending
|
||||
results.sort(key=lambda x: x.get("relevance", 0), reverse=True)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedSearchWeights:
|
||||
"""Weights for unified search (no popularity signal)."""
|
||||
@@ -214,7 +295,7 @@ async def unified_hybrid_search(
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
|
||||
{user_filter}
|
||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
||||
ORDER BY uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector
|
||||
LIMIT 200
|
||||
)
|
||||
),
|
||||
@@ -226,7 +307,7 @@ async def unified_hybrid_search(
|
||||
uce.metadata,
|
||||
uce."updatedAt" as updated_at,
|
||||
-- Semantic score: cosine similarity (1 - distance)
|
||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
COALESCE(1 - (uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector), 0) as semantic_score,
|
||||
-- Lexical score: ts_rank_cd
|
||||
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match from metadata
|
||||
@@ -273,9 +354,7 @@ async def unified_hybrid_search(
|
||||
FROM normalized
|
||||
),
|
||||
filtered AS (
|
||||
SELECT
|
||||
*,
|
||||
COUNT(*) OVER () as total_count
|
||||
SELECT *, COUNT(*) OVER () as total_count
|
||||
FROM scored
|
||||
WHERE combined_score >= {min_score_param}
|
||||
)
|
||||
@@ -284,11 +363,18 @@ async def unified_hybrid_search(
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(
|
||||
sql_query, *params, set_public_search_path=True
|
||||
)
|
||||
results = await query_raw_with_schema(sql_query, *params)
|
||||
|
||||
total = results[0]["total_count"] if results else 0
|
||||
# Apply BM25 reranking
|
||||
if results:
|
||||
results = bm25_rerank(
|
||||
query=query,
|
||||
results=results,
|
||||
text_field="searchable_text",
|
||||
bm25_weight=0.3,
|
||||
original_score_field="combined_score",
|
||||
)
|
||||
|
||||
# Clean up results
|
||||
for result in results:
|
||||
@@ -497,7 +583,7 @@ async def hybrid_search(
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND {where_clause}
|
||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
||||
ORDER BY uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector
|
||||
LIMIT 200
|
||||
) uce
|
||||
),
|
||||
@@ -516,8 +602,10 @@ async def hybrid_search(
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
-- Searchable text for BM25 reranking
|
||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||
-- Semantic score
|
||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
COALESCE(1 - (uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector), 0) as semantic_score,
|
||||
-- Lexical score (raw, will normalize)
|
||||
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match
|
||||
@@ -573,6 +661,7 @@ async def hybrid_search(
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
searchable_text,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
category_score,
|
||||
@@ -597,14 +686,23 @@ async def hybrid_search(
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(
|
||||
sql_query, *params, set_public_search_path=True
|
||||
)
|
||||
results = await query_raw_with_schema(sql_query, *params)
|
||||
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
# Apply BM25 reranking
|
||||
if results:
|
||||
results = bm25_rerank(
|
||||
query=query,
|
||||
results=results,
|
||||
text_field="searchable_text",
|
||||
bm25_weight=0.3,
|
||||
original_score_field="combined_score",
|
||||
)
|
||||
|
||||
for result in results:
|
||||
result.pop("total_count", None)
|
||||
result.pop("searchable_text", None)
|
||||
|
||||
logger.info(f"Hybrid search (store agents): {len(results)} results, {total} total")
|
||||
|
||||
|
||||
@@ -311,11 +311,43 @@ async def test_hybrid_search_min_score_filtering():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_hybrid_search_pagination():
|
||||
"""Test hybrid search pagination."""
|
||||
"""Test hybrid search pagination.
|
||||
|
||||
Pagination happens in SQL (LIMIT/OFFSET), then BM25 reranking is applied
|
||||
to the paginated results.
|
||||
"""
|
||||
# Create mock results that SQL would return for a page
|
||||
mock_results = [
|
||||
{
|
||||
"slug": f"agent-{i}",
|
||||
"agent_name": f"Agent {i}",
|
||||
"agent_image": "test.png",
|
||||
"creator_username": "test",
|
||||
"creator_avatar": "avatar.png",
|
||||
"sub_heading": "Test",
|
||||
"description": "Test description",
|
||||
"runs": 100 - i,
|
||||
"rating": 4.5,
|
||||
"categories": ["test"],
|
||||
"featured": False,
|
||||
"is_available": True,
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"searchable_text": f"Agent {i} test description",
|
||||
"combined_score": 0.9 - (i * 0.01),
|
||||
"semantic_score": 0.7,
|
||||
"lexical_score": 0.6,
|
||||
"category_score": 0.5,
|
||||
"recency_score": 0.4,
|
||||
"popularity_score": 0.3,
|
||||
"total_count": 25,
|
||||
}
|
||||
for i in range(10) # SQL returns page_size results
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
mock_query.return_value = []
|
||||
mock_query.return_value = mock_results
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
@@ -329,16 +361,18 @@ async def test_hybrid_search_pagination():
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Verify pagination parameters
|
||||
# Verify results returned
|
||||
assert len(results) == 10
|
||||
assert total == 25 # Total from SQL COUNT(*) OVER()
|
||||
|
||||
# Verify the SQL query uses page_size and offset
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0]
|
||||
|
||||
# Last two params should be LIMIT and OFFSET
|
||||
limit = params[-2]
|
||||
offset = params[-1]
|
||||
|
||||
assert limit == 10 # page_size
|
||||
assert offset == 10 # (page - 1) * page_size = (2 - 1) * 10
|
||||
# Last two params are page_size and offset
|
||||
page_size_param = params[-2]
|
||||
offset_param = params[-1]
|
||||
assert page_size_param == 10
|
||||
assert offset_param == 10 # (page 2 - 1) * 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -609,14 +643,36 @@ async def test_unified_hybrid_search_empty_query():
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.integration
|
||||
async def test_unified_hybrid_search_pagination():
|
||||
"""Test unified search pagination."""
|
||||
"""Test unified search pagination with BM25 reranking.
|
||||
|
||||
Pagination happens in SQL (LIMIT/OFFSET), then BM25 reranking is applied
|
||||
to the paginated results.
|
||||
"""
|
||||
# Create mock results that SQL would return for a page
|
||||
mock_results = [
|
||||
{
|
||||
"content_type": "STORE_AGENT",
|
||||
"content_id": f"agent-{i}",
|
||||
"searchable_text": f"Agent {i} description",
|
||||
"metadata": {"name": f"Agent {i}"},
|
||||
"updated_at": "2025-01-01T00:00:00Z",
|
||||
"semantic_score": 0.7,
|
||||
"lexical_score": 0.8 - (i * 0.01),
|
||||
"category_score": 0.5,
|
||||
"recency_score": 0.3,
|
||||
"combined_score": 0.6 - (i * 0.01),
|
||||
"total_count": 50,
|
||||
}
|
||||
for i in range(15) # SQL returns page_size results
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.query_raw_with_schema"
|
||||
) as mock_query:
|
||||
with patch(
|
||||
"backend.api.features.store.hybrid_search.embed_query"
|
||||
) as mock_embed:
|
||||
mock_query.return_value = []
|
||||
mock_query.return_value = mock_results
|
||||
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
|
||||
|
||||
results, total = await unified_hybrid_search(
|
||||
@@ -625,15 +681,18 @@ async def test_unified_hybrid_search_pagination():
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
# Verify pagination parameters (last two params are LIMIT and OFFSET)
|
||||
# Verify results returned
|
||||
assert len(results) == 15
|
||||
assert total == 50 # Total from SQL COUNT(*) OVER()
|
||||
|
||||
# Verify the SQL query uses page_size and offset
|
||||
call_args = mock_query.call_args
|
||||
params = call_args[0]
|
||||
|
||||
limit = params[-2]
|
||||
offset = params[-1]
|
||||
|
||||
assert limit == 15 # page_size
|
||||
assert offset == 30 # (page - 1) * page_size = (3 - 1) * 15
|
||||
# Last two params are page_size and offset
|
||||
page_size_param = params[-2]
|
||||
offset_param = params[-1]
|
||||
assert page_size_param == 15
|
||||
assert offset_param == 30 # (page 3 - 1) * 15
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -693,13 +693,13 @@ class DeleteGraphResponse(TypedDict):
|
||||
async def list_graphs(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphMeta]:
|
||||
graphs, _ = await graph_db.list_graphs_paginated(
|
||||
paginated_result = await graph_db.list_graphs_paginated(
|
||||
user_id=user_id,
|
||||
page=1,
|
||||
page_size=250,
|
||||
filter_by="active",
|
||||
)
|
||||
return graphs
|
||||
return paginated_result.graphs
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
|
||||
@@ -174,7 +174,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
)
|
||||
frame_rate: int = SchemaField(description="Frame rate of the video", default=60)
|
||||
generation_preset: GenerationPreset = SchemaField(
|
||||
description="Generation preset for visual style - only effects AI generated visuals",
|
||||
description="Generation preset for visual style - only affects AI-generated visuals",
|
||||
default=GenerationPreset.LEONARDO,
|
||||
placeholder=GenerationPreset.LEONARDO,
|
||||
)
|
||||
|
||||
@@ -381,7 +381,7 @@ Each range you add needs to be a string, with the upper and lower numbers of the
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location of the company headquarters. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appear in your search results, even if they match other parameters.
|
||||
|
||||
To exclude companies based on location, use the organization_not_locations parameter.
|
||||
""",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user