mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-19 20:18:22 -05:00
Compare commits
60 Commits
hackathon/
...
feature/vi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8d3893c16 | ||
|
|
1cfbc0dd08 | ||
|
|
ff84643b48 | ||
|
|
c19c3c834a | ||
|
|
d0f7ba8cfd | ||
|
|
2a855f4bd0 | ||
|
|
b93bb3b9f8 | ||
|
|
1b56ff13d9 | ||
|
|
f31c160043 | ||
|
|
06550a87eb | ||
|
|
088b9998dc | ||
|
|
05c89fa5c0 | ||
|
|
8cc8295f14 | ||
|
|
e55f05c7a8 | ||
|
|
4a9b13acb6 | ||
|
|
5ff669e999 | ||
|
|
ec03a13e26 | ||
|
|
b08851f5d7 | ||
|
|
8b1720e61d | ||
|
|
aa5a039c5e | ||
|
|
8b83bb8647 | ||
|
|
e80e4d9cbb | ||
|
|
375d33cca9 | ||
|
|
3b1b2fe30c | ||
|
|
af63b3678e | ||
|
|
631f1bd50a | ||
|
|
5ac941fe2f | ||
|
|
b01ea3fcbd | ||
|
|
3b09a94e3f | ||
|
|
61efee4139 | ||
|
|
e539280e98 | ||
|
|
db8b43bb3d | ||
|
|
923d8baedc | ||
|
|
a55b2e02dc | ||
|
|
6b6648b290 | ||
|
|
c0a9c0410b | ||
|
|
17a77b02c7 | ||
|
|
701fce83ca | ||
|
|
78d89d0faf | ||
|
|
f482eb668b | ||
|
|
4a52b7eca0 | ||
|
|
97847f59f7 | ||
|
|
22ca8955c5 | ||
|
|
43cbe2e011 | ||
|
|
a318832414 | ||
|
|
843c487500 | ||
|
|
47a3a5ef41 | ||
|
|
ec00aa951a | ||
|
|
36fb1ea004 | ||
|
|
a81ac150da | ||
|
|
49ee087496 | ||
|
|
fc25e008b3 | ||
|
|
b0855e8cf2 | ||
|
|
5e2146dd76 | ||
|
|
103a62c9da | ||
|
|
fc8434fb30 | ||
|
|
3ae08cd48e | ||
|
|
4db13837b9 | ||
|
|
df87867625 | ||
|
|
4a7bc006a8 |
37
.branchlet.json
Normal file
37
.branchlet.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"worktreeCopyPatterns": [
|
||||
".env*",
|
||||
".vscode/**",
|
||||
".auth/**",
|
||||
".claude/**",
|
||||
"autogpt_platform/.env*",
|
||||
"autogpt_platform/backend/.env*",
|
||||
"autogpt_platform/frontend/.env*",
|
||||
"autogpt_platform/frontend/.auth/**",
|
||||
"autogpt_platform/db/docker/.env*"
|
||||
],
|
||||
"worktreeCopyIgnores": [
|
||||
"**/node_modules/**",
|
||||
"**/dist/**",
|
||||
"**/.git/**",
|
||||
"**/Thumbs.db",
|
||||
"**/.DS_Store",
|
||||
"**/.next/**",
|
||||
"**/__pycache__/**",
|
||||
"**/.ruff_cache/**",
|
||||
"**/.pytest_cache/**",
|
||||
"**/*.pyc",
|
||||
"**/playwright-report/**",
|
||||
"**/logs/**",
|
||||
"**/site/**"
|
||||
],
|
||||
"worktreePathTemplate": "$BASE_PATH.worktree",
|
||||
"postCreateCmd": [
|
||||
"cd autogpt_platform/autogpt_libs && poetry install",
|
||||
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
||||
"cd autogpt_platform/frontend && pnpm install",
|
||||
"cd docs && pip install -r requirements.txt"
|
||||
],
|
||||
"terminalCommand": "code .",
|
||||
"deleteBranchWithWorktree": false
|
||||
}
|
||||
2249
.claude/skills/vercel-react-best-practices/AGENTS.md
Normal file
2249
.claude/skills/vercel-react-best-practices/AGENTS.md
Normal file
File diff suppressed because it is too large
Load Diff
125
.claude/skills/vercel-react-best-practices/SKILL.md
Normal file
125
.claude/skills/vercel-react-best-practices/SKILL.md
Normal file
@@ -0,0 +1,125 @@
|
||||
---
|
||||
name: vercel-react-best-practices
|
||||
description: React and Next.js performance optimization guidelines from Vercel Engineering. This skill should be used when writing, reviewing, or refactoring React/Next.js code to ensure optimal performance patterns. Triggers on tasks involving React components, Next.js pages, data fetching, bundle optimization, or performance improvements.
|
||||
license: MIT
|
||||
metadata:
|
||||
author: vercel
|
||||
version: "1.0.0"
|
||||
---
|
||||
|
||||
# Vercel React Best Practices
|
||||
|
||||
Comprehensive performance optimization guide for React and Next.js applications, maintained by Vercel. Contains 45 rules across 8 categories, prioritized by impact to guide automated refactoring and code generation.
|
||||
|
||||
## When to Apply
|
||||
|
||||
Reference these guidelines when:
|
||||
- Writing new React components or Next.js pages
|
||||
- Implementing data fetching (client or server-side)
|
||||
- Reviewing code for performance issues
|
||||
- Refactoring existing React/Next.js code
|
||||
- Optimizing bundle size or load times
|
||||
|
||||
## Rule Categories by Priority
|
||||
|
||||
| Priority | Category | Impact | Prefix |
|
||||
|----------|----------|--------|--------|
|
||||
| 1 | Eliminating Waterfalls | CRITICAL | `async-` |
|
||||
| 2 | Bundle Size Optimization | CRITICAL | `bundle-` |
|
||||
| 3 | Server-Side Performance | HIGH | `server-` |
|
||||
| 4 | Client-Side Data Fetching | MEDIUM-HIGH | `client-` |
|
||||
| 5 | Re-render Optimization | MEDIUM | `rerender-` |
|
||||
| 6 | Rendering Performance | MEDIUM | `rendering-` |
|
||||
| 7 | JavaScript Performance | LOW-MEDIUM | `js-` |
|
||||
| 8 | Advanced Patterns | LOW | `advanced-` |
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### 1. Eliminating Waterfalls (CRITICAL)
|
||||
|
||||
- `async-defer-await` - Move await into branches where actually used
|
||||
- `async-parallel` - Use Promise.all() for independent operations
|
||||
- `async-dependencies` - Use better-all for partial dependencies
|
||||
- `async-api-routes` - Start promises early, await late in API routes
|
||||
- `async-suspense-boundaries` - Use Suspense to stream content
|
||||
|
||||
### 2. Bundle Size Optimization (CRITICAL)
|
||||
|
||||
- `bundle-barrel-imports` - Import directly, avoid barrel files
|
||||
- `bundle-dynamic-imports` - Use next/dynamic for heavy components
|
||||
- `bundle-defer-third-party` - Load analytics/logging after hydration
|
||||
- `bundle-conditional` - Load modules only when feature is activated
|
||||
- `bundle-preload` - Preload on hover/focus for perceived speed
|
||||
|
||||
### 3. Server-Side Performance (HIGH)
|
||||
|
||||
- `server-cache-react` - Use React.cache() for per-request deduplication
|
||||
- `server-cache-lru` - Use LRU cache for cross-request caching
|
||||
- `server-serialization` - Minimize data passed to client components
|
||||
- `server-parallel-fetching` - Restructure components to parallelize fetches
|
||||
- `server-after-nonblocking` - Use after() for non-blocking operations
|
||||
|
||||
### 4. Client-Side Data Fetching (MEDIUM-HIGH)
|
||||
|
||||
- `client-swr-dedup` - Use SWR for automatic request deduplication
|
||||
- `client-event-listeners` - Deduplicate global event listeners
|
||||
|
||||
### 5. Re-render Optimization (MEDIUM)
|
||||
|
||||
- `rerender-defer-reads` - Don't subscribe to state only used in callbacks
|
||||
- `rerender-memo` - Extract expensive work into memoized components
|
||||
- `rerender-dependencies` - Use primitive dependencies in effects
|
||||
- `rerender-derived-state` - Subscribe to derived booleans, not raw values
|
||||
- `rerender-functional-setstate` - Use functional setState for stable callbacks
|
||||
- `rerender-lazy-state-init` - Pass function to useState for expensive values
|
||||
- `rerender-transitions` - Use startTransition for non-urgent updates
|
||||
|
||||
### 6. Rendering Performance (MEDIUM)
|
||||
|
||||
- `rendering-animate-svg-wrapper` - Animate div wrapper, not SVG element
|
||||
- `rendering-content-visibility` - Use content-visibility for long lists
|
||||
- `rendering-hoist-jsx` - Extract static JSX outside components
|
||||
- `rendering-svg-precision` - Reduce SVG coordinate precision
|
||||
- `rendering-hydration-no-flicker` - Use inline script for client-only data
|
||||
- `rendering-activity` - Use Activity component for show/hide
|
||||
- `rendering-conditional-render` - Use ternary, not && for conditionals
|
||||
|
||||
### 7. JavaScript Performance (LOW-MEDIUM)
|
||||
|
||||
- `js-batch-dom-css` - Group CSS changes via classes or cssText
|
||||
- `js-index-maps` - Build Map for repeated lookups
|
||||
- `js-cache-property-access` - Cache object properties in loops
|
||||
- `js-cache-function-results` - Cache function results in module-level Map
|
||||
- `js-cache-storage` - Cache localStorage/sessionStorage reads
|
||||
- `js-combine-iterations` - Combine multiple filter/map into one loop
|
||||
- `js-length-check-first` - Check array length before expensive comparison
|
||||
- `js-early-exit` - Return early from functions
|
||||
- `js-hoist-regexp` - Hoist RegExp creation outside loops
|
||||
- `js-min-max-loop` - Use loop for min/max instead of sort
|
||||
- `js-set-map-lookups` - Use Set/Map for O(1) lookups
|
||||
- `js-tosorted-immutable` - Use toSorted() for immutability
|
||||
|
||||
### 8. Advanced Patterns (LOW)
|
||||
|
||||
- `advanced-event-handler-refs` - Store event handlers in refs
|
||||
- `advanced-use-latest` - useLatest for stable callback refs
|
||||
|
||||
## How to Use
|
||||
|
||||
Read individual rule files for detailed explanations and code examples:
|
||||
|
||||
```
|
||||
rules/async-parallel.md
|
||||
rules/bundle-barrel-imports.md
|
||||
rules/_sections.md
|
||||
```
|
||||
|
||||
Each rule file contains:
|
||||
- Brief explanation of why it matters
|
||||
- Incorrect code example with explanation
|
||||
- Correct code example with explanation
|
||||
- Additional context and references
|
||||
|
||||
## Full Compiled Document
|
||||
|
||||
For the complete guide with all rules expanded: `AGENTS.md`
|
||||
@@ -0,0 +1,55 @@
|
||||
---
|
||||
title: Store Event Handlers in Refs
|
||||
impact: LOW
|
||||
impactDescription: stable subscriptions
|
||||
tags: advanced, hooks, refs, event-handlers, optimization
|
||||
---
|
||||
|
||||
## Store Event Handlers in Refs
|
||||
|
||||
Store callbacks in refs when used in effects that shouldn't re-subscribe on callback changes.
|
||||
|
||||
**Incorrect (re-subscribes on every render):**
|
||||
|
||||
```tsx
|
||||
function useWindowEvent(event: string, handler: () => void) {
|
||||
useEffect(() => {
|
||||
window.addEventListener(event, handler)
|
||||
return () => window.removeEventListener(event, handler)
|
||||
}, [event, handler])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (stable subscription):**
|
||||
|
||||
```tsx
|
||||
function useWindowEvent(event: string, handler: () => void) {
|
||||
const handlerRef = useRef(handler)
|
||||
useEffect(() => {
|
||||
handlerRef.current = handler
|
||||
}, [handler])
|
||||
|
||||
useEffect(() => {
|
||||
const listener = () => handlerRef.current()
|
||||
window.addEventListener(event, listener)
|
||||
return () => window.removeEventListener(event, listener)
|
||||
}, [event])
|
||||
}
|
||||
```
|
||||
|
||||
**Alternative: use `useEffectEvent` if you're on latest React:**
|
||||
|
||||
```tsx
|
||||
import { useEffectEvent } from 'react'
|
||||
|
||||
function useWindowEvent(event: string, handler: () => void) {
|
||||
const onEvent = useEffectEvent(handler)
|
||||
|
||||
useEffect(() => {
|
||||
window.addEventListener(event, onEvent)
|
||||
return () => window.removeEventListener(event, onEvent)
|
||||
}, [event])
|
||||
}
|
||||
```
|
||||
|
||||
`useEffectEvent` provides a cleaner API for the same pattern: it creates a stable function reference that always calls the latest version of the handler.
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
title: useLatest for Stable Callback Refs
|
||||
impact: LOW
|
||||
impactDescription: prevents effect re-runs
|
||||
tags: advanced, hooks, useLatest, refs, optimization
|
||||
---
|
||||
|
||||
## useLatest for Stable Callback Refs
|
||||
|
||||
Access latest values in callbacks without adding them to dependency arrays. Prevents effect re-runs while avoiding stale closures.
|
||||
|
||||
**Implementation:**
|
||||
|
||||
```typescript
|
||||
function useLatest<T>(value: T) {
|
||||
const ref = useRef(value)
|
||||
useEffect(() => {
|
||||
ref.current = value
|
||||
}, [value])
|
||||
return ref
|
||||
}
|
||||
```
|
||||
|
||||
**Incorrect (effect re-runs on every callback change):**
|
||||
|
||||
```tsx
|
||||
function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
|
||||
const [query, setQuery] = useState('')
|
||||
|
||||
useEffect(() => {
|
||||
const timeout = setTimeout(() => onSearch(query), 300)
|
||||
return () => clearTimeout(timeout)
|
||||
}, [query, onSearch])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (stable effect, fresh callback):**
|
||||
|
||||
```tsx
|
||||
function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
|
||||
const [query, setQuery] = useState('')
|
||||
const onSearchRef = useLatest(onSearch)
|
||||
|
||||
useEffect(() => {
|
||||
const timeout = setTimeout(() => onSearchRef.current(query), 300)
|
||||
return () => clearTimeout(timeout)
|
||||
}, [query])
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,38 @@
|
||||
---
|
||||
title: Prevent Waterfall Chains in API Routes
|
||||
impact: CRITICAL
|
||||
impactDescription: 2-10× improvement
|
||||
tags: api-routes, server-actions, waterfalls, parallelization
|
||||
---
|
||||
|
||||
## Prevent Waterfall Chains in API Routes
|
||||
|
||||
In API routes and Server Actions, start independent operations immediately, even if you don't await them yet.
|
||||
|
||||
**Incorrect (config waits for auth, data waits for both):**
|
||||
|
||||
```typescript
|
||||
export async function GET(request: Request) {
|
||||
const session = await auth()
|
||||
const config = await fetchConfig()
|
||||
const data = await fetchData(session.user.id)
|
||||
return Response.json({ data, config })
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (auth and config start immediately):**
|
||||
|
||||
```typescript
|
||||
export async function GET(request: Request) {
|
||||
const sessionPromise = auth()
|
||||
const configPromise = fetchConfig()
|
||||
const session = await sessionPromise
|
||||
const [config, data] = await Promise.all([
|
||||
configPromise,
|
||||
fetchData(session.user.id)
|
||||
])
|
||||
return Response.json({ data, config })
|
||||
}
|
||||
```
|
||||
|
||||
For operations with more complex dependency chains, use `better-all` to automatically maximize parallelism (see Dependency-Based Parallelization).
|
||||
@@ -0,0 +1,80 @@
|
||||
---
|
||||
title: Defer Await Until Needed
|
||||
impact: HIGH
|
||||
impactDescription: avoids blocking unused code paths
|
||||
tags: async, await, conditional, optimization
|
||||
---
|
||||
|
||||
## Defer Await Until Needed
|
||||
|
||||
Move `await` operations into the branches where they're actually used to avoid blocking code paths that don't need them.
|
||||
|
||||
**Incorrect (blocks both branches):**
|
||||
|
||||
```typescript
|
||||
async function handleRequest(userId: string, skipProcessing: boolean) {
|
||||
const userData = await fetchUserData(userId)
|
||||
|
||||
if (skipProcessing) {
|
||||
// Returns immediately but still waited for userData
|
||||
return { skipped: true }
|
||||
}
|
||||
|
||||
// Only this branch uses userData
|
||||
return processUserData(userData)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (only blocks when needed):**
|
||||
|
||||
```typescript
|
||||
async function handleRequest(userId: string, skipProcessing: boolean) {
|
||||
if (skipProcessing) {
|
||||
// Returns immediately without waiting
|
||||
return { skipped: true }
|
||||
}
|
||||
|
||||
// Fetch only when needed
|
||||
const userData = await fetchUserData(userId)
|
||||
return processUserData(userData)
|
||||
}
|
||||
```
|
||||
|
||||
**Another example (early return optimization):**
|
||||
|
||||
```typescript
|
||||
// Incorrect: always fetches permissions
|
||||
async function updateResource(resourceId: string, userId: string) {
|
||||
const permissions = await fetchPermissions(userId)
|
||||
const resource = await getResource(resourceId)
|
||||
|
||||
if (!resource) {
|
||||
return { error: 'Not found' }
|
||||
}
|
||||
|
||||
if (!permissions.canEdit) {
|
||||
return { error: 'Forbidden' }
|
||||
}
|
||||
|
||||
return await updateResourceData(resource, permissions)
|
||||
}
|
||||
|
||||
// Correct: fetches only when needed
|
||||
async function updateResource(resourceId: string, userId: string) {
|
||||
const resource = await getResource(resourceId)
|
||||
|
||||
if (!resource) {
|
||||
return { error: 'Not found' }
|
||||
}
|
||||
|
||||
const permissions = await fetchPermissions(userId)
|
||||
|
||||
if (!permissions.canEdit) {
|
||||
return { error: 'Forbidden' }
|
||||
}
|
||||
|
||||
return await updateResourceData(resource, permissions)
|
||||
}
|
||||
```
|
||||
|
||||
This optimization is especially valuable when the skipped branch is frequently taken, or when the deferred operation is expensive.
|
||||
@@ -0,0 +1,36 @@
|
||||
---
|
||||
title: Dependency-Based Parallelization
|
||||
impact: CRITICAL
|
||||
impactDescription: 2-10× improvement
|
||||
tags: async, parallelization, dependencies, better-all
|
||||
---
|
||||
|
||||
## Dependency-Based Parallelization
|
||||
|
||||
For operations with partial dependencies, use `better-all` to maximize parallelism. It automatically starts each task at the earliest possible moment.
|
||||
|
||||
**Incorrect (profile waits for config unnecessarily):**
|
||||
|
||||
```typescript
|
||||
const [user, config] = await Promise.all([
|
||||
fetchUser(),
|
||||
fetchConfig()
|
||||
])
|
||||
const profile = await fetchProfile(user.id)
|
||||
```
|
||||
|
||||
**Correct (config and profile run in parallel):**
|
||||
|
||||
```typescript
|
||||
import { all } from 'better-all'
|
||||
|
||||
const { user, config, profile } = await all({
|
||||
async user() { return fetchUser() },
|
||||
async config() { return fetchConfig() },
|
||||
async profile() {
|
||||
return fetchProfile((await this.$.user).id)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
Reference: [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
|
||||
@@ -0,0 +1,28 @@
|
||||
---
|
||||
title: Promise.all() for Independent Operations
|
||||
impact: CRITICAL
|
||||
impactDescription: 2-10× improvement
|
||||
tags: async, parallelization, promises, waterfalls
|
||||
---
|
||||
|
||||
## Promise.all() for Independent Operations
|
||||
|
||||
When async operations have no interdependencies, execute them concurrently using `Promise.all()`.
|
||||
|
||||
**Incorrect (sequential execution, 3 round trips):**
|
||||
|
||||
```typescript
|
||||
const user = await fetchUser()
|
||||
const posts = await fetchPosts()
|
||||
const comments = await fetchComments()
|
||||
```
|
||||
|
||||
**Correct (parallel execution, 1 round trip):**
|
||||
|
||||
```typescript
|
||||
const [user, posts, comments] = await Promise.all([
|
||||
fetchUser(),
|
||||
fetchPosts(),
|
||||
fetchComments()
|
||||
])
|
||||
```
|
||||
@@ -0,0 +1,99 @@
|
||||
---
|
||||
title: Strategic Suspense Boundaries
|
||||
impact: HIGH
|
||||
impactDescription: faster initial paint
|
||||
tags: async, suspense, streaming, layout-shift
|
||||
---
|
||||
|
||||
## Strategic Suspense Boundaries
|
||||
|
||||
Instead of awaiting data in async components before returning JSX, use Suspense boundaries to show the wrapper UI faster while data loads.
|
||||
|
||||
**Incorrect (wrapper blocked by data fetching):**
|
||||
|
||||
```tsx
|
||||
async function Page() {
|
||||
const data = await fetchData() // Blocks entire page
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div>Sidebar</div>
|
||||
<div>Header</div>
|
||||
<div>
|
||||
<DataDisplay data={data} />
|
||||
</div>
|
||||
<div>Footer</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
The entire layout waits for data even though only the middle section needs it.
|
||||
|
||||
**Correct (wrapper shows immediately, data streams in):**
|
||||
|
||||
```tsx
|
||||
function Page() {
|
||||
return (
|
||||
<div>
|
||||
<div>Sidebar</div>
|
||||
<div>Header</div>
|
||||
<div>
|
||||
<Suspense fallback={<Skeleton />}>
|
||||
<DataDisplay />
|
||||
</Suspense>
|
||||
</div>
|
||||
<div>Footer</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
async function DataDisplay() {
|
||||
const data = await fetchData() // Only blocks this component
|
||||
return <div>{data.content}</div>
|
||||
}
|
||||
```
|
||||
|
||||
Sidebar, Header, and Footer render immediately. Only DataDisplay waits for data.
|
||||
|
||||
**Alternative (share promise across components):**
|
||||
|
||||
```tsx
|
||||
function Page() {
|
||||
// Start fetch immediately, but don't await
|
||||
const dataPromise = fetchData()
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div>Sidebar</div>
|
||||
<div>Header</div>
|
||||
<Suspense fallback={<Skeleton />}>
|
||||
<DataDisplay dataPromise={dataPromise} />
|
||||
<DataSummary dataPromise={dataPromise} />
|
||||
</Suspense>
|
||||
<div>Footer</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function DataDisplay({ dataPromise }: { dataPromise: Promise<Data> }) {
|
||||
const data = use(dataPromise) // Unwraps the promise
|
||||
return <div>{data.content}</div>
|
||||
}
|
||||
|
||||
function DataSummary({ dataPromise }: { dataPromise: Promise<Data> }) {
|
||||
const data = use(dataPromise) // Reuses the same promise
|
||||
return <div>{data.summary}</div>
|
||||
}
|
||||
```
|
||||
|
||||
Both components share the same promise, so only one fetch occurs. Layout renders immediately while both components wait together.
|
||||
|
||||
**When NOT to use this pattern:**
|
||||
|
||||
- Critical data needed for layout decisions (affects positioning)
|
||||
- SEO-critical content above the fold
|
||||
- Small, fast queries where suspense overhead isn't worth it
|
||||
- When you want to avoid layout shift (loading → content jump)
|
||||
|
||||
**Trade-off:** Faster initial paint vs potential layout shift. Choose based on your UX priorities.
|
||||
@@ -0,0 +1,59 @@
|
||||
---
|
||||
title: Avoid Barrel File Imports
|
||||
impact: CRITICAL
|
||||
impactDescription: 200-800ms import cost, slow builds
|
||||
tags: bundle, imports, tree-shaking, barrel-files, performance
|
||||
---
|
||||
|
||||
## Avoid Barrel File Imports
|
||||
|
||||
Import directly from source files instead of barrel files to avoid loading thousands of unused modules. **Barrel files** are entry points that re-export multiple modules (e.g., `index.js` that does `export * from './module'`).
|
||||
|
||||
Popular icon and component libraries can have **up to 10,000 re-exports** in their entry file. For many React packages, **it takes 200-800ms just to import them**, affecting both development speed and production cold starts.
|
||||
|
||||
**Why tree-shaking doesn't help:** When a library is marked as external (not bundled), the bundler can't optimize it. If you bundle it to enable tree-shaking, builds become substantially slower analyzing the entire module graph.
|
||||
|
||||
**Incorrect (imports entire library):**
|
||||
|
||||
```tsx
|
||||
import { Check, X, Menu } from 'lucide-react'
|
||||
// Loads 1,583 modules, takes ~2.8s extra in dev
|
||||
// Runtime cost: 200-800ms on every cold start
|
||||
|
||||
import { Button, TextField } from '@mui/material'
|
||||
// Loads 2,225 modules, takes ~4.2s extra in dev
|
||||
```
|
||||
|
||||
**Correct (imports only what you need):**
|
||||
|
||||
```tsx
|
||||
import Check from 'lucide-react/dist/esm/icons/check'
|
||||
import X from 'lucide-react/dist/esm/icons/x'
|
||||
import Menu from 'lucide-react/dist/esm/icons/menu'
|
||||
// Loads only 3 modules (~2KB vs ~1MB)
|
||||
|
||||
import Button from '@mui/material/Button'
|
||||
import TextField from '@mui/material/TextField'
|
||||
// Loads only what you use
|
||||
```
|
||||
|
||||
**Alternative (Next.js 13.5+):**
|
||||
|
||||
```js
|
||||
// next.config.js - use optimizePackageImports
|
||||
module.exports = {
|
||||
experimental: {
|
||||
optimizePackageImports: ['lucide-react', '@mui/material']
|
||||
}
|
||||
}
|
||||
|
||||
// Then you can keep the ergonomic barrel imports:
|
||||
import { Check, X, Menu } from 'lucide-react'
|
||||
// Automatically transformed to direct imports at build time
|
||||
```
|
||||
|
||||
Direct imports provide 15-70% faster dev boot, 28% faster builds, 40% faster cold starts, and significantly faster HMR.
|
||||
|
||||
Libraries commonly affected: `lucide-react`, `@mui/material`, `@mui/icons-material`, `@tabler/icons-react`, `react-icons`, `@headlessui/react`, `@radix-ui/react-*`, `lodash`, `ramda`, `date-fns`, `rxjs`, `react-use`.
|
||||
|
||||
Reference: [How we optimized package imports in Next.js](https://vercel.com/blog/how-we-optimized-package-imports-in-next-js)
|
||||
@@ -0,0 +1,31 @@
|
||||
---
|
||||
title: Conditional Module Loading
|
||||
impact: HIGH
|
||||
impactDescription: loads large data only when needed
|
||||
tags: bundle, conditional-loading, lazy-loading
|
||||
---
|
||||
|
||||
## Conditional Module Loading
|
||||
|
||||
Load large data or modules only when a feature is activated.
|
||||
|
||||
**Example (lazy-load animation frames):**
|
||||
|
||||
```tsx
|
||||
function AnimationPlayer({ enabled }: { enabled: boolean }) {
|
||||
const [frames, setFrames] = useState<Frame[] | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (enabled && !frames && typeof window !== 'undefined') {
|
||||
import('./animation-frames.js')
|
||||
.then(mod => setFrames(mod.frames))
|
||||
.catch(() => setEnabled(false))
|
||||
}
|
||||
}, [enabled, frames])
|
||||
|
||||
if (!frames) return <Skeleton />
|
||||
return <Canvas frames={frames} />
|
||||
}
|
||||
```
|
||||
|
||||
The `typeof window !== 'undefined'` check prevents bundling this module for SSR, optimizing server bundle size and build speed.
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
title: Defer Non-Critical Third-Party Libraries
|
||||
impact: MEDIUM
|
||||
impactDescription: loads after hydration
|
||||
tags: bundle, third-party, analytics, defer
|
||||
---
|
||||
|
||||
## Defer Non-Critical Third-Party Libraries
|
||||
|
||||
Analytics, logging, and error tracking don't block user interaction. Load them after hydration.
|
||||
|
||||
**Incorrect (blocks initial bundle):**
|
||||
|
||||
```tsx
|
||||
import { Analytics } from '@vercel/analytics/react'
|
||||
|
||||
export default function RootLayout({ children }) {
|
||||
return (
|
||||
<html>
|
||||
<body>
|
||||
{children}
|
||||
<Analytics />
|
||||
</body>
|
||||
</html>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (loads after hydration):**
|
||||
|
||||
```tsx
|
||||
import dynamic from 'next/dynamic'
|
||||
|
||||
const Analytics = dynamic(
|
||||
() => import('@vercel/analytics/react').then(m => m.Analytics),
|
||||
{ ssr: false }
|
||||
)
|
||||
|
||||
export default function RootLayout({ children }) {
|
||||
return (
|
||||
<html>
|
||||
<body>
|
||||
{children}
|
||||
<Analytics />
|
||||
</body>
|
||||
</html>
|
||||
)
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,35 @@
|
||||
---
|
||||
title: Dynamic Imports for Heavy Components
|
||||
impact: CRITICAL
|
||||
impactDescription: directly affects TTI and LCP
|
||||
tags: bundle, dynamic-import, code-splitting, next-dynamic
|
||||
---
|
||||
|
||||
## Dynamic Imports for Heavy Components
|
||||
|
||||
Use `next/dynamic` to lazy-load large components not needed on initial render.
|
||||
|
||||
**Incorrect (Monaco bundles with main chunk ~300KB):**
|
||||
|
||||
```tsx
|
||||
import { MonacoEditor } from './monaco-editor'
|
||||
|
||||
function CodePanel({ code }: { code: string }) {
|
||||
return <MonacoEditor value={code} />
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (Monaco loads on demand):**
|
||||
|
||||
```tsx
|
||||
import dynamic from 'next/dynamic'
|
||||
|
||||
const MonacoEditor = dynamic(
|
||||
() => import('./monaco-editor').then(m => m.MonacoEditor),
|
||||
{ ssr: false }
|
||||
)
|
||||
|
||||
function CodePanel({ code }: { code: string }) {
|
||||
return <MonacoEditor value={code} />
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,50 @@
|
||||
---
|
||||
title: Preload Based on User Intent
|
||||
impact: MEDIUM
|
||||
impactDescription: reduces perceived latency
|
||||
tags: bundle, preload, user-intent, hover
|
||||
---
|
||||
|
||||
## Preload Based on User Intent
|
||||
|
||||
Preload heavy bundles before they're needed to reduce perceived latency.
|
||||
|
||||
**Example (preload on hover/focus):**
|
||||
|
||||
```tsx
|
||||
function EditorButton({ onClick }: { onClick: () => void }) {
|
||||
const preload = () => {
|
||||
if (typeof window !== 'undefined') {
|
||||
void import('./monaco-editor')
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<button
|
||||
onMouseEnter={preload}
|
||||
onFocus={preload}
|
||||
onClick={onClick}
|
||||
>
|
||||
Open Editor
|
||||
</button>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Example (preload when feature flag is enabled):**
|
||||
|
||||
```tsx
|
||||
function FlagsProvider({ children, flags }: Props) {
|
||||
useEffect(() => {
|
||||
if (flags.editorEnabled && typeof window !== 'undefined') {
|
||||
void import('./monaco-editor').then(mod => mod.init())
|
||||
}
|
||||
}, [flags.editorEnabled])
|
||||
|
||||
return <FlagsContext.Provider value={flags}>
|
||||
{children}
|
||||
</FlagsContext.Provider>
|
||||
}
|
||||
```
|
||||
|
||||
The `typeof window !== 'undefined'` check prevents bundling preloaded modules for SSR, optimizing server bundle size and build speed.
|
||||
@@ -0,0 +1,74 @@
|
||||
---
|
||||
title: Deduplicate Global Event Listeners
|
||||
impact: LOW
|
||||
impactDescription: single listener for N components
|
||||
tags: client, swr, event-listeners, subscription
|
||||
---
|
||||
|
||||
## Deduplicate Global Event Listeners
|
||||
|
||||
Use `useSWRSubscription()` to share global event listeners across component instances.
|
||||
|
||||
**Incorrect (N instances = N listeners):**
|
||||
|
||||
```tsx
|
||||
function useKeyboardShortcut(key: string, callback: () => void) {
|
||||
useEffect(() => {
|
||||
const handler = (e: KeyboardEvent) => {
|
||||
if (e.metaKey && e.key === key) {
|
||||
callback()
|
||||
}
|
||||
}
|
||||
window.addEventListener('keydown', handler)
|
||||
return () => window.removeEventListener('keydown', handler)
|
||||
}, [key, callback])
|
||||
}
|
||||
```
|
||||
|
||||
When using the `useKeyboardShortcut` hook multiple times, each instance will register a new listener.
|
||||
|
||||
**Correct (N instances = 1 listener):**
|
||||
|
||||
```tsx
|
||||
import useSWRSubscription from 'swr/subscription'
|
||||
|
||||
// Module-level Map to track callbacks per key
|
||||
const keyCallbacks = new Map<string, Set<() => void>>()
|
||||
|
||||
function useKeyboardShortcut(key: string, callback: () => void) {
|
||||
// Register this callback in the Map
|
||||
useEffect(() => {
|
||||
if (!keyCallbacks.has(key)) {
|
||||
keyCallbacks.set(key, new Set())
|
||||
}
|
||||
keyCallbacks.get(key)!.add(callback)
|
||||
|
||||
return () => {
|
||||
const set = keyCallbacks.get(key)
|
||||
if (set) {
|
||||
set.delete(callback)
|
||||
if (set.size === 0) {
|
||||
keyCallbacks.delete(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [key, callback])
|
||||
|
||||
useSWRSubscription('global-keydown', () => {
|
||||
const handler = (e: KeyboardEvent) => {
|
||||
if (e.metaKey && keyCallbacks.has(e.key)) {
|
||||
keyCallbacks.get(e.key)!.forEach(cb => cb())
|
||||
}
|
||||
}
|
||||
window.addEventListener('keydown', handler)
|
||||
return () => window.removeEventListener('keydown', handler)
|
||||
})
|
||||
}
|
||||
|
||||
function Profile() {
|
||||
// Multiple shortcuts will share the same listener
|
||||
useKeyboardShortcut('p', () => { /* ... */ })
|
||||
useKeyboardShortcut('k', () => { /* ... */ })
|
||||
// ...
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,56 @@
|
||||
---
|
||||
title: Use SWR for Automatic Deduplication
|
||||
impact: MEDIUM-HIGH
|
||||
impactDescription: automatic deduplication
|
||||
tags: client, swr, deduplication, data-fetching
|
||||
---
|
||||
|
||||
## Use SWR for Automatic Deduplication
|
||||
|
||||
SWR enables request deduplication, caching, and revalidation across component instances.
|
||||
|
||||
**Incorrect (no deduplication, each instance fetches):**
|
||||
|
||||
```tsx
|
||||
function UserList() {
|
||||
const [users, setUsers] = useState([])
|
||||
useEffect(() => {
|
||||
fetch('/api/users')
|
||||
.then(r => r.json())
|
||||
.then(setUsers)
|
||||
}, [])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (multiple instances share one request):**
|
||||
|
||||
```tsx
|
||||
import useSWR from 'swr'
|
||||
|
||||
function UserList() {
|
||||
const { data: users } = useSWR('/api/users', fetcher)
|
||||
}
|
||||
```
|
||||
|
||||
**For immutable data:**
|
||||
|
||||
```tsx
|
||||
import { useImmutableSWR } from '@/lib/swr'
|
||||
|
||||
function StaticContent() {
|
||||
const { data } = useImmutableSWR('/api/config', fetcher)
|
||||
}
|
||||
```
|
||||
|
||||
**For mutations:**
|
||||
|
||||
```tsx
|
||||
import { useSWRMutation } from 'swr/mutation'
|
||||
|
||||
function UpdateButton() {
|
||||
const { trigger } = useSWRMutation('/api/user', updateUser)
|
||||
return <button onClick={() => trigger()}>Update</button>
|
||||
}
|
||||
```
|
||||
|
||||
Reference: [https://swr.vercel.app](https://swr.vercel.app)
|
||||
@@ -0,0 +1,82 @@
|
||||
---
|
||||
title: Batch DOM CSS Changes
|
||||
impact: MEDIUM
|
||||
impactDescription: reduces reflows/repaints
|
||||
tags: javascript, dom, css, performance, reflow
|
||||
---
|
||||
|
||||
## Batch DOM CSS Changes
|
||||
|
||||
Avoid changing styles one property at a time. Group multiple CSS changes together via classes or `cssText` to minimize browser reflows.
|
||||
|
||||
**Incorrect (multiple reflows):**
|
||||
|
||||
```typescript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
// Each line triggers a reflow
|
||||
element.style.width = '100px'
|
||||
element.style.height = '200px'
|
||||
element.style.backgroundColor = 'blue'
|
||||
element.style.border = '1px solid black'
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (add class - single reflow):**
|
||||
|
||||
```typescript
|
||||
// CSS file
|
||||
.highlighted-box {
|
||||
width: 100px;
|
||||
height: 200px;
|
||||
background-color: blue;
|
||||
border: 1px solid black;
|
||||
}
|
||||
|
||||
// JavaScript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
element.classList.add('highlighted-box')
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (change cssText - single reflow):**
|
||||
|
||||
```typescript
|
||||
function updateElementStyles(element: HTMLElement) {
|
||||
element.style.cssText = `
|
||||
width: 100px;
|
||||
height: 200px;
|
||||
background-color: blue;
|
||||
border: 1px solid black;
|
||||
`
|
||||
}
|
||||
```
|
||||
|
||||
**React example:**
|
||||
|
||||
```tsx
|
||||
// Incorrect: changing styles one by one
|
||||
function Box({ isHighlighted }: { isHighlighted: boolean }) {
|
||||
const ref = useRef<HTMLDivElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (ref.current && isHighlighted) {
|
||||
ref.current.style.width = '100px'
|
||||
ref.current.style.height = '200px'
|
||||
ref.current.style.backgroundColor = 'blue'
|
||||
}
|
||||
}, [isHighlighted])
|
||||
|
||||
return <div ref={ref}>Content</div>
|
||||
}
|
||||
|
||||
// Correct: toggle class
|
||||
function Box({ isHighlighted }: { isHighlighted: boolean }) {
|
||||
return (
|
||||
<div className={isHighlighted ? 'highlighted-box' : ''}>
|
||||
Content
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Prefer CSS classes over inline styles when possible. Classes are cached by the browser and provide better separation of concerns.
|
||||
@@ -0,0 +1,80 @@
|
||||
---
|
||||
title: Cache Repeated Function Calls
|
||||
impact: MEDIUM
|
||||
impactDescription: avoid redundant computation
|
||||
tags: javascript, cache, memoization, performance
|
||||
---
|
||||
|
||||
## Cache Repeated Function Calls
|
||||
|
||||
Use a module-level Map to cache function results when the same function is called repeatedly with the same inputs during render.
|
||||
|
||||
**Incorrect (redundant computation):**
|
||||
|
||||
```typescript
|
||||
function ProjectList({ projects }: { projects: Project[] }) {
|
||||
return (
|
||||
<div>
|
||||
{projects.map(project => {
|
||||
// slugify() called 100+ times for same project names
|
||||
const slug = slugify(project.name)
|
||||
|
||||
return <ProjectCard key={project.id} slug={slug} />
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (cached results):**
|
||||
|
||||
```typescript
|
||||
// Module-level cache
|
||||
const slugifyCache = new Map<string, string>()
|
||||
|
||||
function cachedSlugify(text: string): string {
|
||||
if (slugifyCache.has(text)) {
|
||||
return slugifyCache.get(text)!
|
||||
}
|
||||
const result = slugify(text)
|
||||
slugifyCache.set(text, result)
|
||||
return result
|
||||
}
|
||||
|
||||
function ProjectList({ projects }: { projects: Project[] }) {
|
||||
return (
|
||||
<div>
|
||||
{projects.map(project => {
|
||||
// Computed only once per unique project name
|
||||
const slug = cachedSlugify(project.name)
|
||||
|
||||
return <ProjectCard key={project.id} slug={slug} />
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Simpler pattern for single-value functions:**
|
||||
|
||||
```typescript
|
||||
let isLoggedInCache: boolean | null = null
|
||||
|
||||
function isLoggedIn(): boolean {
|
||||
if (isLoggedInCache !== null) {
|
||||
return isLoggedInCache
|
||||
}
|
||||
|
||||
isLoggedInCache = document.cookie.includes('auth=')
|
||||
return isLoggedInCache
|
||||
}
|
||||
|
||||
// Clear cache when auth changes
|
||||
function onAuthChange() {
|
||||
isLoggedInCache = null
|
||||
}
|
||||
```
|
||||
|
||||
Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
|
||||
|
||||
Reference: [How we made the Vercel Dashboard twice as fast](https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast)
|
||||
@@ -0,0 +1,28 @@
|
||||
---
|
||||
title: Cache Property Access in Loops
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: reduces lookups
|
||||
tags: javascript, loops, optimization, caching
|
||||
---
|
||||
|
||||
## Cache Property Access in Loops
|
||||
|
||||
Cache object property lookups in hot paths.
|
||||
|
||||
**Incorrect (3 lookups × N iterations):**
|
||||
|
||||
```typescript
|
||||
for (let i = 0; i < arr.length; i++) {
|
||||
process(obj.config.settings.value)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (1 lookup total):**
|
||||
|
||||
```typescript
|
||||
const value = obj.config.settings.value
|
||||
const len = arr.length
|
||||
for (let i = 0; i < len; i++) {
|
||||
process(value)
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,70 @@
|
||||
---
|
||||
title: Cache Storage API Calls
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: reduces expensive I/O
|
||||
tags: javascript, localStorage, storage, caching, performance
|
||||
---
|
||||
|
||||
## Cache Storage API Calls
|
||||
|
||||
`localStorage`, `sessionStorage`, and `document.cookie` are synchronous and expensive. Cache reads in memory.
|
||||
|
||||
**Incorrect (reads storage on every call):**
|
||||
|
||||
```typescript
|
||||
function getTheme() {
|
||||
return localStorage.getItem('theme') ?? 'light'
|
||||
}
|
||||
// Called 10 times = 10 storage reads
|
||||
```
|
||||
|
||||
**Correct (Map cache):**
|
||||
|
||||
```typescript
|
||||
const storageCache = new Map<string, string | null>()
|
||||
|
||||
function getLocalStorage(key: string) {
|
||||
if (!storageCache.has(key)) {
|
||||
storageCache.set(key, localStorage.getItem(key))
|
||||
}
|
||||
return storageCache.get(key)
|
||||
}
|
||||
|
||||
function setLocalStorage(key: string, value: string) {
|
||||
localStorage.setItem(key, value)
|
||||
storageCache.set(key, value) // keep cache in sync
|
||||
}
|
||||
```
|
||||
|
||||
Use a Map (not a hook) so it works everywhere: utilities, event handlers, not just React components.
|
||||
|
||||
**Cookie caching:**
|
||||
|
||||
```typescript
|
||||
let cookieCache: Record<string, string> | null = null
|
||||
|
||||
function getCookie(name: string) {
|
||||
if (!cookieCache) {
|
||||
cookieCache = Object.fromEntries(
|
||||
document.cookie.split('; ').map(c => c.split('='))
|
||||
)
|
||||
}
|
||||
return cookieCache[name]
|
||||
}
|
||||
```
|
||||
|
||||
**Important (invalidate on external changes):**
|
||||
|
||||
If storage can change externally (another tab, server-set cookies), invalidate cache:
|
||||
|
||||
```typescript
|
||||
window.addEventListener('storage', (e) => {
|
||||
if (e.key) storageCache.delete(e.key)
|
||||
})
|
||||
|
||||
document.addEventListener('visibilitychange', () => {
|
||||
if (document.visibilityState === 'visible') {
|
||||
storageCache.clear()
|
||||
}
|
||||
})
|
||||
```
|
||||
@@ -0,0 +1,32 @@
|
||||
---
|
||||
title: Combine Multiple Array Iterations
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: reduces iterations
|
||||
tags: javascript, arrays, loops, performance
|
||||
---
|
||||
|
||||
## Combine Multiple Array Iterations
|
||||
|
||||
Multiple `.filter()` or `.map()` calls iterate the array multiple times. Combine into one loop.
|
||||
|
||||
**Incorrect (3 iterations):**
|
||||
|
||||
```typescript
|
||||
const admins = users.filter(u => u.isAdmin)
|
||||
const testers = users.filter(u => u.isTester)
|
||||
const inactive = users.filter(u => !u.isActive)
|
||||
```
|
||||
|
||||
**Correct (1 iteration):**
|
||||
|
||||
```typescript
|
||||
const admins: User[] = []
|
||||
const testers: User[] = []
|
||||
const inactive: User[] = []
|
||||
|
||||
for (const user of users) {
|
||||
if (user.isAdmin) admins.push(user)
|
||||
if (user.isTester) testers.push(user)
|
||||
if (!user.isActive) inactive.push(user)
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,50 @@
|
||||
---
|
||||
title: Early Return from Functions
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: avoids unnecessary computation
|
||||
tags: javascript, functions, optimization, early-return
|
||||
---
|
||||
|
||||
## Early Return from Functions
|
||||
|
||||
Return early when result is determined to skip unnecessary processing.
|
||||
|
||||
**Incorrect (processes all items even after finding answer):**
|
||||
|
||||
```typescript
|
||||
function validateUsers(users: User[]) {
|
||||
let hasError = false
|
||||
let errorMessage = ''
|
||||
|
||||
for (const user of users) {
|
||||
if (!user.email) {
|
||||
hasError = true
|
||||
errorMessage = 'Email required'
|
||||
}
|
||||
if (!user.name) {
|
||||
hasError = true
|
||||
errorMessage = 'Name required'
|
||||
}
|
||||
// Continues checking all users even after error found
|
||||
}
|
||||
|
||||
return hasError ? { valid: false, error: errorMessage } : { valid: true }
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (returns immediately on first error):**
|
||||
|
||||
```typescript
|
||||
function validateUsers(users: User[]) {
|
||||
for (const user of users) {
|
||||
if (!user.email) {
|
||||
return { valid: false, error: 'Email required' }
|
||||
}
|
||||
if (!user.name) {
|
||||
return { valid: false, error: 'Name required' }
|
||||
}
|
||||
}
|
||||
|
||||
return { valid: true }
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,45 @@
|
||||
---
|
||||
title: Hoist RegExp Creation
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: avoids recreation
|
||||
tags: javascript, regexp, optimization, memoization
|
||||
---
|
||||
|
||||
## Hoist RegExp Creation
|
||||
|
||||
Don't create RegExp inside render. Hoist to module scope or memoize with `useMemo()`.
|
||||
|
||||
**Incorrect (new RegExp every render):**
|
||||
|
||||
```tsx
|
||||
function Highlighter({ text, query }: Props) {
|
||||
const regex = new RegExp(`(${query})`, 'gi')
|
||||
const parts = text.split(regex)
|
||||
return <>{parts.map((part, i) => ...)}</>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (memoize or hoist):**
|
||||
|
||||
```tsx
|
||||
const EMAIL_REGEX = /^[^\s@]+@[^\s@]+\.[^\s@]+$/
|
||||
|
||||
function Highlighter({ text, query }: Props) {
|
||||
const regex = useMemo(
|
||||
() => new RegExp(`(${escapeRegex(query)})`, 'gi'),
|
||||
[query]
|
||||
)
|
||||
const parts = text.split(regex)
|
||||
return <>{parts.map((part, i) => ...)}</>
|
||||
}
|
||||
```
|
||||
|
||||
**Warning (global regex has mutable state):**
|
||||
|
||||
Global regex (`/g`) has mutable `lastIndex` state:
|
||||
|
||||
```typescript
|
||||
const regex = /foo/g
|
||||
regex.test('foo') // true, lastIndex = 3
|
||||
regex.test('foo') // false, lastIndex = 0
|
||||
```
|
||||
@@ -0,0 +1,37 @@
|
||||
---
|
||||
title: Build Index Maps for Repeated Lookups
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: 1M ops to 2K ops
|
||||
tags: javascript, map, indexing, optimization, performance
|
||||
---
|
||||
|
||||
## Build Index Maps for Repeated Lookups
|
||||
|
||||
Multiple `.find()` calls by the same key should use a Map.
|
||||
|
||||
**Incorrect (O(n) per lookup):**
|
||||
|
||||
```typescript
|
||||
function processOrders(orders: Order[], users: User[]) {
|
||||
return orders.map(order => ({
|
||||
...order,
|
||||
user: users.find(u => u.id === order.userId)
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (O(1) per lookup):**
|
||||
|
||||
```typescript
|
||||
function processOrders(orders: Order[], users: User[]) {
|
||||
const userById = new Map(users.map(u => [u.id, u]))
|
||||
|
||||
return orders.map(order => ({
|
||||
...order,
|
||||
user: userById.get(order.userId)
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
Build map once (O(n)), then all lookups are O(1).
|
||||
For 1000 orders × 1000 users: 1M ops → 2K ops.
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
title: Early Length Check for Array Comparisons
|
||||
impact: MEDIUM-HIGH
|
||||
impactDescription: avoids expensive operations when lengths differ
|
||||
tags: javascript, arrays, performance, optimization, comparison
|
||||
---
|
||||
|
||||
## Early Length Check for Array Comparisons
|
||||
|
||||
When comparing arrays with expensive operations (sorting, deep equality, serialization), check lengths first. If lengths differ, the arrays cannot be equal.
|
||||
|
||||
In real-world applications, this optimization is especially valuable when the comparison runs in hot paths (event handlers, render loops).
|
||||
|
||||
**Incorrect (always runs expensive comparison):**
|
||||
|
||||
```typescript
|
||||
function hasChanges(current: string[], original: string[]) {
|
||||
// Always sorts and joins, even when lengths differ
|
||||
return current.sort().join() !== original.sort().join()
|
||||
}
|
||||
```
|
||||
|
||||
Two O(n log n) sorts run even when `current.length` is 5 and `original.length` is 100. There is also overhead of joining the arrays and comparing the strings.
|
||||
|
||||
**Correct (O(1) length check first):**
|
||||
|
||||
```typescript
|
||||
function hasChanges(current: string[], original: string[]) {
|
||||
// Early return if lengths differ
|
||||
if (current.length !== original.length) {
|
||||
return true
|
||||
}
|
||||
// Only sort/join when lengths match
|
||||
const currentSorted = current.toSorted()
|
||||
const originalSorted = original.toSorted()
|
||||
for (let i = 0; i < currentSorted.length; i++) {
|
||||
if (currentSorted[i] !== originalSorted[i]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
```
|
||||
|
||||
This new approach is more efficient because:
|
||||
- It avoids the overhead of sorting and joining the arrays when lengths differ
|
||||
- It avoids consuming memory for the joined strings (especially important for large arrays)
|
||||
- It avoids mutating the original arrays
|
||||
- It returns early when a difference is found
|
||||
@@ -0,0 +1,82 @@
|
||||
---
|
||||
title: Use Loop for Min/Max Instead of Sort
|
||||
impact: LOW
|
||||
impactDescription: O(n) instead of O(n log n)
|
||||
tags: javascript, arrays, performance, sorting, algorithms
|
||||
---
|
||||
|
||||
## Use Loop for Min/Max Instead of Sort
|
||||
|
||||
Finding the smallest or largest element only requires a single pass through the array. Sorting is wasteful and slower.
|
||||
|
||||
**Incorrect (O(n log n) - sort to find latest):**
|
||||
|
||||
```typescript
|
||||
interface Project {
|
||||
id: string
|
||||
name: string
|
||||
updatedAt: number
|
||||
}
|
||||
|
||||
function getLatestProject(projects: Project[]) {
|
||||
const sorted = [...projects].sort((a, b) => b.updatedAt - a.updatedAt)
|
||||
return sorted[0]
|
||||
}
|
||||
```
|
||||
|
||||
Sorts the entire array just to find the maximum value.
|
||||
|
||||
**Incorrect (O(n log n) - sort for oldest and newest):**
|
||||
|
||||
```typescript
|
||||
function getOldestAndNewest(projects: Project[]) {
|
||||
const sorted = [...projects].sort((a, b) => a.updatedAt - b.updatedAt)
|
||||
return { oldest: sorted[0], newest: sorted[sorted.length - 1] }
|
||||
}
|
||||
```
|
||||
|
||||
Still sorts unnecessarily when only min/max are needed.
|
||||
|
||||
**Correct (O(n) - single loop):**
|
||||
|
||||
```typescript
|
||||
function getLatestProject(projects: Project[]) {
|
||||
if (projects.length === 0) return null
|
||||
|
||||
let latest = projects[0]
|
||||
|
||||
for (let i = 1; i < projects.length; i++) {
|
||||
if (projects[i].updatedAt > latest.updatedAt) {
|
||||
latest = projects[i]
|
||||
}
|
||||
}
|
||||
|
||||
return latest
|
||||
}
|
||||
|
||||
function getOldestAndNewest(projects: Project[]) {
|
||||
if (projects.length === 0) return { oldest: null, newest: null }
|
||||
|
||||
let oldest = projects[0]
|
||||
let newest = projects[0]
|
||||
|
||||
for (let i = 1; i < projects.length; i++) {
|
||||
if (projects[i].updatedAt < oldest.updatedAt) oldest = projects[i]
|
||||
if (projects[i].updatedAt > newest.updatedAt) newest = projects[i]
|
||||
}
|
||||
|
||||
return { oldest, newest }
|
||||
}
|
||||
```
|
||||
|
||||
Single pass through the array, no copying, no sorting.
|
||||
|
||||
**Alternative (Math.min/Math.max for small arrays):**
|
||||
|
||||
```typescript
|
||||
const numbers = [5, 2, 8, 1, 9]
|
||||
const min = Math.min(...numbers)
|
||||
const max = Math.max(...numbers)
|
||||
```
|
||||
|
||||
This works for small arrays but can be slower for very large arrays due to spread operator limitations. Use the loop approach for reliability.
|
||||
@@ -0,0 +1,24 @@
|
||||
---
|
||||
title: Use Set/Map for O(1) Lookups
|
||||
impact: LOW-MEDIUM
|
||||
impactDescription: O(n) to O(1)
|
||||
tags: javascript, set, map, data-structures, performance
|
||||
---
|
||||
|
||||
## Use Set/Map for O(1) Lookups
|
||||
|
||||
Convert arrays to Set/Map for repeated membership checks.
|
||||
|
||||
**Incorrect (O(n) per check):**
|
||||
|
||||
```typescript
|
||||
const allowedIds = ['a', 'b', 'c', ...]
|
||||
items.filter(item => allowedIds.includes(item.id))
|
||||
```
|
||||
|
||||
**Correct (O(1) per check):**
|
||||
|
||||
```typescript
|
||||
const allowedIds = new Set(['a', 'b', 'c', ...])
|
||||
items.filter(item => allowedIds.has(item.id))
|
||||
```
|
||||
@@ -0,0 +1,57 @@
|
||||
---
|
||||
title: Use toSorted() Instead of sort() for Immutability
|
||||
impact: MEDIUM-HIGH
|
||||
impactDescription: prevents mutation bugs in React state
|
||||
tags: javascript, arrays, immutability, react, state, mutation
|
||||
---
|
||||
|
||||
## Use toSorted() Instead of sort() for Immutability
|
||||
|
||||
`.sort()` mutates the array in place, which can cause bugs with React state and props. Use `.toSorted()` to create a new sorted array without mutation.
|
||||
|
||||
**Incorrect (mutates original array):**
|
||||
|
||||
```typescript
|
||||
function UserList({ users }: { users: User[] }) {
|
||||
// Mutates the users prop array!
|
||||
const sorted = useMemo(
|
||||
() => users.sort((a, b) => a.name.localeCompare(b.name)),
|
||||
[users]
|
||||
)
|
||||
return <div>{sorted.map(renderUser)}</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (creates new array):**
|
||||
|
||||
```typescript
|
||||
function UserList({ users }: { users: User[] }) {
|
||||
// Creates new sorted array, original unchanged
|
||||
const sorted = useMemo(
|
||||
() => users.toSorted((a, b) => a.name.localeCompare(b.name)),
|
||||
[users]
|
||||
)
|
||||
return <div>{sorted.map(renderUser)}</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Why this matters in React:**
|
||||
|
||||
1. Props/state mutations break React's immutability model - React expects props and state to be treated as read-only
|
||||
2. Causes stale closure bugs - Mutating arrays inside closures (callbacks, effects) can lead to unexpected behavior
|
||||
|
||||
**Browser support (fallback for older browsers):**
|
||||
|
||||
`.toSorted()` is available in all modern browsers (Chrome 110+, Safari 16+, Firefox 115+, Node.js 20+). For older environments, use spread operator:
|
||||
|
||||
```typescript
|
||||
// Fallback for older browsers
|
||||
const sorted = [...items].sort((a, b) => a.value - b.value)
|
||||
```
|
||||
|
||||
**Other immutable array methods:**
|
||||
|
||||
- `.toSorted()` - immutable sort
|
||||
- `.toReversed()` - immutable reverse
|
||||
- `.toSpliced()` - immutable splice
|
||||
- `.with()` - immutable element replacement
|
||||
@@ -0,0 +1,26 @@
|
||||
---
|
||||
title: Use Activity Component for Show/Hide
|
||||
impact: MEDIUM
|
||||
impactDescription: preserves state/DOM
|
||||
tags: rendering, activity, visibility, state-preservation
|
||||
---
|
||||
|
||||
## Use Activity Component for Show/Hide
|
||||
|
||||
Use React's `<Activity>` to preserve state/DOM for expensive components that frequently toggle visibility.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```tsx
|
||||
import { Activity } from 'react'
|
||||
|
||||
function Dropdown({ isOpen }: Props) {
|
||||
return (
|
||||
<Activity mode={isOpen ? 'visible' : 'hidden'}>
|
||||
<ExpensiveMenu />
|
||||
</Activity>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Avoids expensive re-renders and state loss.
|
||||
@@ -0,0 +1,47 @@
|
||||
---
|
||||
title: Animate SVG Wrapper Instead of SVG Element
|
||||
impact: LOW
|
||||
impactDescription: enables hardware acceleration
|
||||
tags: rendering, svg, css, animation, performance
|
||||
---
|
||||
|
||||
## Animate SVG Wrapper Instead of SVG Element
|
||||
|
||||
Many browsers don't have hardware acceleration for CSS3 animations on SVG elements. Wrap SVG in a `<div>` and animate the wrapper instead.
|
||||
|
||||
**Incorrect (animating SVG directly - no hardware acceleration):**
|
||||
|
||||
```tsx
|
||||
function LoadingSpinner() {
|
||||
return (
|
||||
<svg
|
||||
className="animate-spin"
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<circle cx="12" cy="12" r="10" stroke="currentColor" />
|
||||
</svg>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (animating wrapper div - hardware accelerated):**
|
||||
|
||||
```tsx
|
||||
function LoadingSpinner() {
|
||||
return (
|
||||
<div className="animate-spin">
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<circle cx="12" cy="12" r="10" stroke="currentColor" />
|
||||
</svg>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
This applies to all CSS transforms and transitions (`transform`, `opacity`, `translate`, `scale`, `rotate`). The wrapper div allows browsers to use GPU acceleration for smoother animations.
|
||||
@@ -0,0 +1,40 @@
|
||||
---
|
||||
title: Use Explicit Conditional Rendering
|
||||
impact: LOW
|
||||
impactDescription: prevents rendering 0 or NaN
|
||||
tags: rendering, conditional, jsx, falsy-values
|
||||
---
|
||||
|
||||
## Use Explicit Conditional Rendering
|
||||
|
||||
Use explicit ternary operators (`? :`) instead of `&&` for conditional rendering when the condition can be `0`, `NaN`, or other falsy values that render.
|
||||
|
||||
**Incorrect (renders "0" when count is 0):**
|
||||
|
||||
```tsx
|
||||
function Badge({ count }: { count: number }) {
|
||||
return (
|
||||
<div>
|
||||
{count && <span className="badge">{count}</span>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// When count = 0, renders: <div>0</div>
|
||||
// When count = 5, renders: <div><span class="badge">5</span></div>
|
||||
```
|
||||
|
||||
**Correct (renders nothing when count is 0):**
|
||||
|
||||
```tsx
|
||||
function Badge({ count }: { count: number }) {
|
||||
return (
|
||||
<div>
|
||||
{count > 0 ? <span className="badge">{count}</span> : null}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// When count = 0, renders: <div></div>
|
||||
// When count = 5, renders: <div><span class="badge">5</span></div>
|
||||
```
|
||||
@@ -0,0 +1,38 @@
|
||||
---
|
||||
title: CSS content-visibility for Long Lists
|
||||
impact: HIGH
|
||||
impactDescription: faster initial render
|
||||
tags: rendering, css, content-visibility, long-lists
|
||||
---
|
||||
|
||||
## CSS content-visibility for Long Lists
|
||||
|
||||
Apply `content-visibility: auto` to defer off-screen rendering.
|
||||
|
||||
**CSS:**
|
||||
|
||||
```css
|
||||
.message-item {
|
||||
content-visibility: auto;
|
||||
contain-intrinsic-size: 0 80px;
|
||||
}
|
||||
```
|
||||
|
||||
**Example:**
|
||||
|
||||
```tsx
|
||||
function MessageList({ messages }: { messages: Message[] }) {
|
||||
return (
|
||||
<div className="overflow-y-auto h-screen">
|
||||
{messages.map(msg => (
|
||||
<div key={msg.id} className="message-item">
|
||||
<Avatar user={msg.author} />
|
||||
<div>{msg.content}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
For 1000 messages, browser skips layout/paint for ~990 off-screen items (10× faster initial render).
|
||||
@@ -0,0 +1,46 @@
|
||||
---
|
||||
title: Hoist Static JSX Elements
|
||||
impact: LOW
|
||||
impactDescription: avoids re-creation
|
||||
tags: rendering, jsx, static, optimization
|
||||
---
|
||||
|
||||
## Hoist Static JSX Elements
|
||||
|
||||
Extract static JSX outside components to avoid re-creation.
|
||||
|
||||
**Incorrect (recreates element every render):**
|
||||
|
||||
```tsx
|
||||
function LoadingSkeleton() {
|
||||
return <div className="animate-pulse h-20 bg-gray-200" />
|
||||
}
|
||||
|
||||
function Container() {
|
||||
return (
|
||||
<div>
|
||||
{loading && <LoadingSkeleton />}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (reuses same element):**
|
||||
|
||||
```tsx
|
||||
const loadingSkeleton = (
|
||||
<div className="animate-pulse h-20 bg-gray-200" />
|
||||
)
|
||||
|
||||
function Container() {
|
||||
return (
|
||||
<div>
|
||||
{loading && loadingSkeleton}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
This is especially helpful for large and static SVG nodes, which can be expensive to recreate on every render.
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler automatically hoists static JSX elements and optimizes component re-renders, making manual hoisting unnecessary.
|
||||
@@ -0,0 +1,82 @@
|
||||
---
|
||||
title: Prevent Hydration Mismatch Without Flickering
|
||||
impact: MEDIUM
|
||||
impactDescription: avoids visual flicker and hydration errors
|
||||
tags: rendering, ssr, hydration, localStorage, flicker
|
||||
---
|
||||
|
||||
## Prevent Hydration Mismatch Without Flickering
|
||||
|
||||
When rendering content that depends on client-side storage (localStorage, cookies), avoid both SSR breakage and post-hydration flickering by injecting a synchronous script that updates the DOM before React hydrates.
|
||||
|
||||
**Incorrect (breaks SSR):**
|
||||
|
||||
```tsx
|
||||
function ThemeWrapper({ children }: { children: ReactNode }) {
|
||||
// localStorage is not available on server - throws error
|
||||
const theme = localStorage.getItem('theme') || 'light'
|
||||
|
||||
return (
|
||||
<div className={theme}>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Server-side rendering will fail because `localStorage` is undefined.
|
||||
|
||||
**Incorrect (visual flickering):**
|
||||
|
||||
```tsx
|
||||
function ThemeWrapper({ children }: { children: ReactNode }) {
|
||||
const [theme, setTheme] = useState('light')
|
||||
|
||||
useEffect(() => {
|
||||
// Runs after hydration - causes visible flash
|
||||
const stored = localStorage.getItem('theme')
|
||||
if (stored) {
|
||||
setTheme(stored)
|
||||
}
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div className={theme}>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
Component first renders with default value (`light`), then updates after hydration, causing a visible flash of incorrect content.
|
||||
|
||||
**Correct (no flicker, no hydration mismatch):**
|
||||
|
||||
```tsx
|
||||
function ThemeWrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<>
|
||||
<div id="theme-wrapper">
|
||||
{children}
|
||||
</div>
|
||||
<script
|
||||
dangerouslySetInnerHTML={{
|
||||
__html: `
|
||||
(function() {
|
||||
try {
|
||||
var theme = localStorage.getItem('theme') || 'light';
|
||||
var el = document.getElementById('theme-wrapper');
|
||||
if (el) el.className = theme;
|
||||
} catch (e) {}
|
||||
})();
|
||||
`,
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
The inline script executes synchronously before showing the element, ensuring the DOM already has the correct value. No flickering, no hydration mismatch.
|
||||
|
||||
This pattern is especially useful for theme toggles, user preferences, authentication states, and any client-only data that should render immediately without flashing default values.
|
||||
@@ -0,0 +1,28 @@
|
||||
---
|
||||
title: Optimize SVG Precision
|
||||
impact: LOW
|
||||
impactDescription: reduces file size
|
||||
tags: rendering, svg, optimization, svgo
|
||||
---
|
||||
|
||||
## Optimize SVG Precision
|
||||
|
||||
Reduce SVG coordinate precision to decrease file size. The optimal precision depends on the viewBox size, but in general reducing precision should be considered.
|
||||
|
||||
**Incorrect (excessive precision):**
|
||||
|
||||
```svg
|
||||
<path d="M 10.293847 20.847362 L 30.938472 40.192837" />
|
||||
```
|
||||
|
||||
**Correct (1 decimal place):**
|
||||
|
||||
```svg
|
||||
<path d="M 10.3 20.8 L 30.9 40.2" />
|
||||
```
|
||||
|
||||
**Automate with SVGO:**
|
||||
|
||||
```bash
|
||||
npx svgo --precision=1 --multipass icon.svg
|
||||
```
|
||||
@@ -0,0 +1,39 @@
|
||||
---
|
||||
title: Defer State Reads to Usage Point
|
||||
impact: MEDIUM
|
||||
impactDescription: avoids unnecessary subscriptions
|
||||
tags: rerender, searchParams, localStorage, optimization
|
||||
---
|
||||
|
||||
## Defer State Reads to Usage Point
|
||||
|
||||
Don't subscribe to dynamic state (searchParams, localStorage) if you only read it inside callbacks.
|
||||
|
||||
**Incorrect (subscribes to all searchParams changes):**
|
||||
|
||||
```tsx
|
||||
function ShareButton({ chatId }: { chatId: string }) {
|
||||
const searchParams = useSearchParams()
|
||||
|
||||
const handleShare = () => {
|
||||
const ref = searchParams.get('ref')
|
||||
shareChat(chatId, { ref })
|
||||
}
|
||||
|
||||
return <button onClick={handleShare}>Share</button>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (reads on demand, no subscription):**
|
||||
|
||||
```tsx
|
||||
function ShareButton({ chatId }: { chatId: string }) {
|
||||
const handleShare = () => {
|
||||
const params = new URLSearchParams(window.location.search)
|
||||
const ref = params.get('ref')
|
||||
shareChat(chatId, { ref })
|
||||
}
|
||||
|
||||
return <button onClick={handleShare}>Share</button>
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,45 @@
|
||||
---
|
||||
title: Narrow Effect Dependencies
|
||||
impact: LOW
|
||||
impactDescription: minimizes effect re-runs
|
||||
tags: rerender, useEffect, dependencies, optimization
|
||||
---
|
||||
|
||||
## Narrow Effect Dependencies
|
||||
|
||||
Specify primitive dependencies instead of objects to minimize effect re-runs.
|
||||
|
||||
**Incorrect (re-runs on any user field change):**
|
||||
|
||||
```tsx
|
||||
useEffect(() => {
|
||||
console.log(user.id)
|
||||
}, [user])
|
||||
```
|
||||
|
||||
**Correct (re-runs only when id changes):**
|
||||
|
||||
```tsx
|
||||
useEffect(() => {
|
||||
console.log(user.id)
|
||||
}, [user.id])
|
||||
```
|
||||
|
||||
**For derived state, compute outside effect:**
|
||||
|
||||
```tsx
|
||||
// Incorrect: runs on width=767, 766, 765...
|
||||
useEffect(() => {
|
||||
if (width < 768) {
|
||||
enableMobileMode()
|
||||
}
|
||||
}, [width])
|
||||
|
||||
// Correct: runs only on boolean transition
|
||||
const isMobile = width < 768
|
||||
useEffect(() => {
|
||||
if (isMobile) {
|
||||
enableMobileMode()
|
||||
}
|
||||
}, [isMobile])
|
||||
```
|
||||
@@ -0,0 +1,29 @@
|
||||
---
|
||||
title: Subscribe to Derived State
|
||||
impact: MEDIUM
|
||||
impactDescription: reduces re-render frequency
|
||||
tags: rerender, derived-state, media-query, optimization
|
||||
---
|
||||
|
||||
## Subscribe to Derived State
|
||||
|
||||
Subscribe to derived boolean state instead of continuous values to reduce re-render frequency.
|
||||
|
||||
**Incorrect (re-renders on every pixel change):**
|
||||
|
||||
```tsx
|
||||
function Sidebar() {
|
||||
const width = useWindowWidth() // updates continuously
|
||||
const isMobile = width < 768
|
||||
return <nav className={isMobile ? 'mobile' : 'desktop'}>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (re-renders only when boolean changes):**
|
||||
|
||||
```tsx
|
||||
function Sidebar() {
|
||||
const isMobile = useMediaQuery('(max-width: 767px)')
|
||||
return <nav className={isMobile ? 'mobile' : 'desktop'}>
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,74 @@
|
||||
---
|
||||
title: Use Functional setState Updates
|
||||
impact: MEDIUM
|
||||
impactDescription: prevents stale closures and unnecessary callback recreations
|
||||
tags: react, hooks, useState, useCallback, callbacks, closures
|
||||
---
|
||||
|
||||
## Use Functional setState Updates
|
||||
|
||||
When updating state based on the current state value, use the functional update form of setState instead of directly referencing the state variable. This prevents stale closures, eliminates unnecessary dependencies, and creates stable callback references.
|
||||
|
||||
**Incorrect (requires state as dependency):**
|
||||
|
||||
```tsx
|
||||
function TodoList() {
|
||||
const [items, setItems] = useState(initialItems)
|
||||
|
||||
// Callback must depend on items, recreated on every items change
|
||||
const addItems = useCallback((newItems: Item[]) => {
|
||||
setItems([...items, ...newItems])
|
||||
}, [items]) // ❌ items dependency causes recreations
|
||||
|
||||
// Risk of stale closure if dependency is forgotten
|
||||
const removeItem = useCallback((id: string) => {
|
||||
setItems(items.filter(item => item.id !== id))
|
||||
}, []) // ❌ Missing items dependency - will use stale items!
|
||||
|
||||
return <ItemsEditor items={items} onAdd={addItems} onRemove={removeItem} />
|
||||
}
|
||||
```
|
||||
|
||||
The first callback is recreated every time `items` changes, which can cause child components to re-render unnecessarily. The second callback has a stale closure bug—it will always reference the initial `items` value.
|
||||
|
||||
**Correct (stable callbacks, no stale closures):**
|
||||
|
||||
```tsx
|
||||
function TodoList() {
|
||||
const [items, setItems] = useState(initialItems)
|
||||
|
||||
// Stable callback, never recreated
|
||||
const addItems = useCallback((newItems: Item[]) => {
|
||||
setItems(curr => [...curr, ...newItems])
|
||||
}, []) // ✅ No dependencies needed
|
||||
|
||||
// Always uses latest state, no stale closure risk
|
||||
const removeItem = useCallback((id: string) => {
|
||||
setItems(curr => curr.filter(item => item.id !== id))
|
||||
}, []) // ✅ Safe and stable
|
||||
|
||||
return <ItemsEditor items={items} onAdd={addItems} onRemove={removeItem} />
|
||||
}
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
|
||||
1. **Stable callback references** - Callbacks don't need to be recreated when state changes
|
||||
2. **No stale closures** - Always operates on the latest state value
|
||||
3. **Fewer dependencies** - Simplifies dependency arrays and reduces memory leaks
|
||||
4. **Prevents bugs** - Eliminates the most common source of React closure bugs
|
||||
|
||||
**When to use functional updates:**
|
||||
|
||||
- Any setState that depends on the current state value
|
||||
- Inside useCallback/useMemo when state is needed
|
||||
- Event handlers that reference state
|
||||
- Async operations that update state
|
||||
|
||||
**When direct updates are fine:**
|
||||
|
||||
- Setting state to a static value: `setCount(0)`
|
||||
- Setting state from props/arguments only: `setName(newName)`
|
||||
- State doesn't depend on previous value
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler can automatically optimize some cases, but functional updates are still recommended for correctness and to prevent stale closure bugs.
|
||||
@@ -0,0 +1,58 @@
|
||||
---
|
||||
title: Use Lazy State Initialization
|
||||
impact: MEDIUM
|
||||
impactDescription: wasted computation on every render
|
||||
tags: react, hooks, useState, performance, initialization
|
||||
---
|
||||
|
||||
## Use Lazy State Initialization
|
||||
|
||||
Pass a function to `useState` for expensive initial values. Without the function form, the initializer runs on every render even though the value is only used once.
|
||||
|
||||
**Incorrect (runs on every render):**
|
||||
|
||||
```tsx
|
||||
function FilteredList({ items }: { items: Item[] }) {
|
||||
// buildSearchIndex() runs on EVERY render, even after initialization
|
||||
const [searchIndex, setSearchIndex] = useState(buildSearchIndex(items))
|
||||
const [query, setQuery] = useState('')
|
||||
|
||||
// When query changes, buildSearchIndex runs again unnecessarily
|
||||
return <SearchResults index={searchIndex} query={query} />
|
||||
}
|
||||
|
||||
function UserProfile() {
|
||||
// JSON.parse runs on every render
|
||||
const [settings, setSettings] = useState(
|
||||
JSON.parse(localStorage.getItem('settings') || '{}')
|
||||
)
|
||||
|
||||
return <SettingsForm settings={settings} onChange={setSettings} />
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (runs only once):**
|
||||
|
||||
```tsx
|
||||
function FilteredList({ items }: { items: Item[] }) {
|
||||
// buildSearchIndex() runs ONLY on initial render
|
||||
const [searchIndex, setSearchIndex] = useState(() => buildSearchIndex(items))
|
||||
const [query, setQuery] = useState('')
|
||||
|
||||
return <SearchResults index={searchIndex} query={query} />
|
||||
}
|
||||
|
||||
function UserProfile() {
|
||||
// JSON.parse runs only on initial render
|
||||
const [settings, setSettings] = useState(() => {
|
||||
const stored = localStorage.getItem('settings')
|
||||
return stored ? JSON.parse(stored) : {}
|
||||
})
|
||||
|
||||
return <SettingsForm settings={settings} onChange={setSettings} />
|
||||
}
|
||||
```
|
||||
|
||||
Use lazy initialization when computing initial values from localStorage/sessionStorage, building data structures (indexes, maps), reading from the DOM, or performing heavy transformations.
|
||||
|
||||
For simple primitives (`useState(0)`), direct references (`useState(props.value)`), or cheap literals (`useState({})`), the function form is unnecessary.
|
||||
@@ -0,0 +1,44 @@
|
||||
---
|
||||
title: Extract to Memoized Components
|
||||
impact: MEDIUM
|
||||
impactDescription: enables early returns
|
||||
tags: rerender, memo, useMemo, optimization
|
||||
---
|
||||
|
||||
## Extract to Memoized Components
|
||||
|
||||
Extract expensive work into memoized components to enable early returns before computation.
|
||||
|
||||
**Incorrect (computes avatar even when loading):**
|
||||
|
||||
```tsx
|
||||
function Profile({ user, loading }: Props) {
|
||||
const avatar = useMemo(() => {
|
||||
const id = computeAvatarId(user)
|
||||
return <Avatar id={id} />
|
||||
}, [user])
|
||||
|
||||
if (loading) return <Skeleton />
|
||||
return <div>{avatar}</div>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (skips computation when loading):**
|
||||
|
||||
```tsx
|
||||
const UserAvatar = memo(function UserAvatar({ user }: { user: User }) {
|
||||
const id = useMemo(() => computeAvatarId(user), [user])
|
||||
return <Avatar id={id} />
|
||||
})
|
||||
|
||||
function Profile({ user, loading }: Props) {
|
||||
if (loading) return <Skeleton />
|
||||
return (
|
||||
<div>
|
||||
<UserAvatar user={user} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, manual memoization with `memo()` and `useMemo()` is not necessary. The compiler automatically optimizes re-renders.
|
||||
@@ -0,0 +1,40 @@
|
||||
---
|
||||
title: Use Transitions for Non-Urgent Updates
|
||||
impact: MEDIUM
|
||||
impactDescription: maintains UI responsiveness
|
||||
tags: rerender, transitions, startTransition, performance
|
||||
---
|
||||
|
||||
## Use Transitions for Non-Urgent Updates
|
||||
|
||||
Mark frequent, non-urgent state updates as transitions to maintain UI responsiveness.
|
||||
|
||||
**Incorrect (blocks UI on every scroll):**
|
||||
|
||||
```tsx
|
||||
function ScrollTracker() {
|
||||
const [scrollY, setScrollY] = useState(0)
|
||||
useEffect(() => {
|
||||
const handler = () => setScrollY(window.scrollY)
|
||||
window.addEventListener('scroll', handler, { passive: true })
|
||||
return () => window.removeEventListener('scroll', handler)
|
||||
}, [])
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (non-blocking updates):**
|
||||
|
||||
```tsx
|
||||
import { startTransition } from 'react'
|
||||
|
||||
function ScrollTracker() {
|
||||
const [scrollY, setScrollY] = useState(0)
|
||||
useEffect(() => {
|
||||
const handler = () => {
|
||||
startTransition(() => setScrollY(window.scrollY))
|
||||
}
|
||||
window.addEventListener('scroll', handler, { passive: true })
|
||||
return () => window.removeEventListener('scroll', handler)
|
||||
}, [])
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,73 @@
|
||||
---
|
||||
title: Use after() for Non-Blocking Operations
|
||||
impact: MEDIUM
|
||||
impactDescription: faster response times
|
||||
tags: server, async, logging, analytics, side-effects
|
||||
---
|
||||
|
||||
## Use after() for Non-Blocking Operations
|
||||
|
||||
Use Next.js's `after()` to schedule work that should execute after a response is sent. This prevents logging, analytics, and other side effects from blocking the response.
|
||||
|
||||
**Incorrect (blocks response):**
|
||||
|
||||
```tsx
|
||||
import { logUserAction } from '@/app/utils'
|
||||
|
||||
export async function POST(request: Request) {
|
||||
// Perform mutation
|
||||
await updateDatabase(request)
|
||||
|
||||
// Logging blocks the response
|
||||
const userAgent = request.headers.get('user-agent') || 'unknown'
|
||||
await logUserAction({ userAgent })
|
||||
|
||||
return new Response(JSON.stringify({ status: 'success' }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (non-blocking):**
|
||||
|
||||
```tsx
|
||||
import { after } from 'next/server'
|
||||
import { headers, cookies } from 'next/headers'
|
||||
import { logUserAction } from '@/app/utils'
|
||||
|
||||
export async function POST(request: Request) {
|
||||
// Perform mutation
|
||||
await updateDatabase(request)
|
||||
|
||||
// Log after response is sent
|
||||
after(async () => {
|
||||
const userAgent = (await headers()).get('user-agent') || 'unknown'
|
||||
const sessionCookie = (await cookies()).get('session-id')?.value || 'anonymous'
|
||||
|
||||
logUserAction({ sessionCookie, userAgent })
|
||||
})
|
||||
|
||||
return new Response(JSON.stringify({ status: 'success' }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
The response is sent immediately while logging happens in the background.
|
||||
|
||||
**Common use cases:**
|
||||
|
||||
- Analytics tracking
|
||||
- Audit logging
|
||||
- Sending notifications
|
||||
- Cache invalidation
|
||||
- Cleanup tasks
|
||||
|
||||
**Important notes:**
|
||||
|
||||
- `after()` runs even if the response fails or redirects
|
||||
- Works in Server Actions, Route Handlers, and Server Components
|
||||
|
||||
Reference: [https://nextjs.org/docs/app/api-reference/functions/after](https://nextjs.org/docs/app/api-reference/functions/after)
|
||||
@@ -0,0 +1,41 @@
|
||||
---
|
||||
title: Cross-Request LRU Caching
|
||||
impact: HIGH
|
||||
impactDescription: caches across requests
|
||||
tags: server, cache, lru, cross-request
|
||||
---
|
||||
|
||||
## Cross-Request LRU Caching
|
||||
|
||||
`React.cache()` only works within one request. For data shared across sequential requests (user clicks button A then button B), use an LRU cache.
|
||||
|
||||
**Implementation:**
|
||||
|
||||
```typescript
|
||||
import { LRUCache } from 'lru-cache'
|
||||
|
||||
const cache = new LRUCache<string, any>({
|
||||
max: 1000,
|
||||
ttl: 5 * 60 * 1000 // 5 minutes
|
||||
})
|
||||
|
||||
export async function getUser(id: string) {
|
||||
const cached = cache.get(id)
|
||||
if (cached) return cached
|
||||
|
||||
const user = await db.user.findUnique({ where: { id } })
|
||||
cache.set(id, user)
|
||||
return user
|
||||
}
|
||||
|
||||
// Request 1: DB query, result cached
|
||||
// Request 2: cache hit, no DB query
|
||||
```
|
||||
|
||||
Use when sequential user actions hit multiple endpoints needing the same data within seconds.
|
||||
|
||||
**With Vercel's [Fluid Compute](https://vercel.com/docs/fluid-compute):** LRU caching is especially effective because multiple concurrent requests can share the same function instance and cache. This means the cache persists across requests without needing external storage like Redis.
|
||||
|
||||
**In traditional serverless:** Each invocation runs in isolation, so consider Redis for cross-process caching.
|
||||
|
||||
Reference: [https://github.com/isaacs/node-lru-cache](https://github.com/isaacs/node-lru-cache)
|
||||
@@ -0,0 +1,26 @@
|
||||
---
|
||||
title: Per-Request Deduplication with React.cache()
|
||||
impact: MEDIUM
|
||||
impactDescription: deduplicates within request
|
||||
tags: server, cache, react-cache, deduplication
|
||||
---
|
||||
|
||||
## Per-Request Deduplication with React.cache()
|
||||
|
||||
Use `React.cache()` for server-side request deduplication. Authentication and database queries benefit most.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```typescript
|
||||
import { cache } from 'react'
|
||||
|
||||
export const getCurrentUser = cache(async () => {
|
||||
const session = await auth()
|
||||
if (!session?.user?.id) return null
|
||||
return await db.user.findUnique({
|
||||
where: { id: session.user.id }
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
Within a single request, multiple calls to `getCurrentUser()` execute the query only once.
|
||||
@@ -0,0 +1,79 @@
|
||||
---
|
||||
title: Parallel Data Fetching with Component Composition
|
||||
impact: CRITICAL
|
||||
impactDescription: eliminates server-side waterfalls
|
||||
tags: server, rsc, parallel-fetching, composition
|
||||
---
|
||||
|
||||
## Parallel Data Fetching with Component Composition
|
||||
|
||||
React Server Components execute sequentially within a tree. Restructure with composition to parallelize data fetching.
|
||||
|
||||
**Incorrect (Sidebar waits for Page's fetch to complete):**
|
||||
|
||||
```tsx
|
||||
export default async function Page() {
|
||||
const header = await fetchHeader()
|
||||
return (
|
||||
<div>
|
||||
<div>{header}</div>
|
||||
<Sidebar />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
async function Sidebar() {
|
||||
const items = await fetchSidebarItems()
|
||||
return <nav>{items.map(renderItem)}</nav>
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (both fetch simultaneously):**
|
||||
|
||||
```tsx
|
||||
async function Header() {
|
||||
const data = await fetchHeader()
|
||||
return <div>{data}</div>
|
||||
}
|
||||
|
||||
async function Sidebar() {
|
||||
const items = await fetchSidebarItems()
|
||||
return <nav>{items.map(renderItem)}</nav>
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<div>
|
||||
<Header />
|
||||
<Sidebar />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Alternative with children prop:**
|
||||
|
||||
```tsx
|
||||
async function Layout({ children }: { children: ReactNode }) {
|
||||
const header = await fetchHeader()
|
||||
return (
|
||||
<div>
|
||||
<div>{header}</div>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
async function Sidebar() {
|
||||
const items = await fetchSidebarItems()
|
||||
return <nav>{items.map(renderItem)}</nav>
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<Layout>
|
||||
<Sidebar />
|
||||
</Layout>
|
||||
)
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,38 @@
|
||||
---
|
||||
title: Minimize Serialization at RSC Boundaries
|
||||
impact: HIGH
|
||||
impactDescription: reduces data transfer size
|
||||
tags: server, rsc, serialization, props
|
||||
---
|
||||
|
||||
## Minimize Serialization at RSC Boundaries
|
||||
|
||||
The React Server/Client boundary serializes all object properties into strings and embeds them in the HTML response and subsequent RSC requests. This serialized data directly impacts page weight and load time, so **size matters a lot**. Only pass fields that the client actually uses.
|
||||
|
||||
**Incorrect (serializes all 50 fields):**
|
||||
|
||||
```tsx
|
||||
async function Page() {
|
||||
const user = await fetchUser() // 50 fields
|
||||
return <Profile user={user} />
|
||||
}
|
||||
|
||||
'use client'
|
||||
function Profile({ user }: { user: User }) {
|
||||
return <div>{user.name}</div> // uses 1 field
|
||||
}
|
||||
```
|
||||
|
||||
**Correct (serializes only 1 field):**
|
||||
|
||||
```tsx
|
||||
async function Page() {
|
||||
const user = await fetchUser()
|
||||
return <Profile name={user.name} />
|
||||
}
|
||||
|
||||
'use client'
|
||||
function Profile({ name }: { name: string }) {
|
||||
return <div>{name}</div>
|
||||
}
|
||||
```
|
||||
@@ -1,6 +1,9 @@
|
||||
# Ignore everything by default, selectively add things to context
|
||||
*
|
||||
|
||||
# Documentation (for embeddings/search)
|
||||
!docs/
|
||||
|
||||
# Platform - Libs
|
||||
!autogpt_platform/autogpt_libs/autogpt_libs/
|
||||
!autogpt_platform/autogpt_libs/pyproject.toml
|
||||
@@ -16,6 +19,7 @@
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
!autogpt_platform/backend/gen_prisma_types_stub.py
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
|
||||
2
.github/workflows/claude-dependabot.yml
vendored
2
.github/workflows/claude-dependabot.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
@@ -90,7 +90,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
|
||||
12
.github/workflows/copilot-setup-steps.yml
vendored
12
.github/workflows/copilot-setup-steps.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
@@ -108,6 +108,16 @@ jobs:
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
# Remove large unused tools to free disk space for Docker builds
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo docker system prune -af
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
4
.github/workflows/platform-backend-ci.yml
vendored
4
.github/workflows/platform-backend-ci.yml
vendored
@@ -134,7 +134,7 @@ jobs:
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate
|
||||
run: poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
@@ -176,7 +176,7 @@ jobs:
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
run: poetry run prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
25
.github/workflows/platform-frontend-ci.yml
vendored
25
.github/workflows/platform-frontend-ci.yml
vendored
@@ -11,6 +11,7 @@ on:
|
||||
- ".github/workflows/platform-frontend-ci.yml"
|
||||
- "autogpt_platform/frontend/**"
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
|
||||
@@ -151,6 +152,14 @@ jobs:
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env and set OpenAI API key
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -226,13 +235,25 @@ jobs:
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload Playwright artifacts
|
||||
if: failure()
|
||||
- name: Upload Playwright report
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Upload Playwright test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-test-results
|
||||
path: test-results
|
||||
if-no-files-found: ignore
|
||||
retention-days: 3
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(ls:*)",
|
||||
"WebFetch(domain:langfuse.com)",
|
||||
"Bash(poetry install:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend stop-backend run-frontend load-store-agents backfill-store-embeddings
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend load-store-agents
|
||||
|
||||
# Run just Supabase + Redis + RabbitMQ
|
||||
start-core:
|
||||
@@ -6,12 +6,14 @@ start-core:
|
||||
|
||||
# Stop core services
|
||||
stop-core:
|
||||
docker compose stop deps
|
||||
docker compose stop
|
||||
|
||||
reset-db:
|
||||
docker compose stop db
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
# View logs for core services
|
||||
logs-core:
|
||||
@@ -33,15 +35,9 @@ init-env:
|
||||
migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
cd backend && poetry run gen-prisma-stub
|
||||
|
||||
stop-backend:
|
||||
@echo "Stopping backend processes..."
|
||||
@cd backend && poetry run cli stop 2>/dev/null || true
|
||||
@echo "Killing any processes using backend ports..."
|
||||
@lsof -ti:8001,8002,8003,8004,8005,8006,8007 | xargs kill -9 2>/dev/null || true
|
||||
@echo "Backend stopped"
|
||||
|
||||
run-backend: stop-backend
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
|
||||
run-frontend:
|
||||
@@ -53,9 +49,6 @@ test-data:
|
||||
load-store-agents:
|
||||
cd backend && poetry run load-store-agents
|
||||
|
||||
backfill-store-embeddings:
|
||||
cd backend && poetry run python -m backend.api.features.store.backfill_embeddings
|
||||
|
||||
help:
|
||||
@echo "Usage: make <target>"
|
||||
@echo "Targets:"
|
||||
@@ -65,9 +58,7 @@ help:
|
||||
@echo " logs-core - Tail the logs for core services"
|
||||
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
|
||||
@echo " migrate - Run backend database migrations"
|
||||
@echo " stop-backend - Stop any running backend processes"
|
||||
@echo " run-backend - Run the backend FastAPI server (stops existing processes first)"
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@echo " test-data - Run the test data creator"
|
||||
@echo " load-store-agents - Load store agents from agents/ folder into test database"
|
||||
@echo " backfill-store-embeddings - Generate embeddings for store agents that don't have them"
|
||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -18,3 +18,4 @@ load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
migrations/*/rollback*.sql
|
||||
|
||||
@@ -48,7 +48,8 @@ RUN poetry install --no-ansi --no-root
|
||||
# Generate Prisma client
|
||||
COPY autogpt_platform/backend/schema.prisma ./
|
||||
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
|
||||
RUN poetry run prisma generate
|
||||
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
|
||||
RUN poetry run prisma generate && poetry run gen-prisma-stub
|
||||
|
||||
FROM debian:13-slim AS server_dependencies
|
||||
|
||||
@@ -99,6 +100,7 @@ COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migration
|
||||
FROM server_dependencies AS server
|
||||
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
COPY docs /app/docs
|
||||
RUN poetry install --no-ansi --only-root
|
||||
|
||||
ENV PORT=8000
|
||||
|
||||
@@ -70,7 +70,7 @@ class RunAgentRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def _create_ephemeral_session(user_id: str | None) -> ChatSession:
|
||||
def _create_ephemeral_session(user_id: str) -> ChatSession:
|
||||
"""Create an ephemeral session for stateless API requests."""
|
||||
return ChatSession.new(user_id)
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from backend.executor.manager import get_db_async_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class ExecutionAnalyticsRequest(BaseModel):
|
||||
@@ -63,6 +64,8 @@ class ExecutionAnalyticsResult(BaseModel):
|
||||
score: Optional[float]
|
||||
status: str # "success", "failed", "skipped"
|
||||
error_message: Optional[str] = None
|
||||
started_at: Optional[datetime] = None
|
||||
ended_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class ExecutionAnalyticsResponse(BaseModel):
|
||||
@@ -224,11 +227,6 @@ async def generate_execution_analytics(
|
||||
)
|
||||
|
||||
try:
|
||||
# Validate model configuration
|
||||
settings = Settings()
|
||||
if not settings.secrets.openai_internal_api_key:
|
||||
raise HTTPException(status_code=500, detail="OpenAI API key not configured")
|
||||
|
||||
# Get database client
|
||||
db_client = get_db_async_client()
|
||||
|
||||
@@ -320,6 +318,8 @@ async def generate_execution_analytics(
|
||||
),
|
||||
status="skipped",
|
||||
error_message=None, # Not an error - just already processed
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -349,6 +349,9 @@ async def _process_batch(
|
||||
) -> list[ExecutionAnalyticsResult]:
|
||||
"""Process a batch of executions concurrently."""
|
||||
|
||||
if not settings.secrets.openai_internal_api_key:
|
||||
raise HTTPException(status_code=500, detail="OpenAI API key not configured")
|
||||
|
||||
async def process_single_execution(execution) -> ExecutionAnalyticsResult:
|
||||
try:
|
||||
# Generate activity status and score using the specified model
|
||||
@@ -387,6 +390,8 @@ async def _process_batch(
|
||||
score=None,
|
||||
status="skipped",
|
||||
error_message="Activity generation returned None",
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
)
|
||||
|
||||
# Update the execution stats
|
||||
@@ -416,6 +421,8 @@ async def _process_batch(
|
||||
summary_text=activity_response["activity_status"],
|
||||
score=activity_response["correctness_score"],
|
||||
status="success",
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -429,6 +436,8 @@ async def _process_batch(
|
||||
score=None,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
started_at=execution.started_at,
|
||||
ended_at=execution.ended_at,
|
||||
)
|
||||
|
||||
# Process all executions in the batch concurrently
|
||||
|
||||
@@ -9,7 +9,6 @@ import prisma.enums
|
||||
|
||||
import backend.api.features.store.cache as store_cache
|
||||
import backend.api.features.store.db as store_db
|
||||
import backend.api.features.store.embeddings as store_embeddings
|
||||
import backend.api.features.store.model as store_model
|
||||
import backend.util.json
|
||||
|
||||
@@ -151,54 +150,3 @@ async def admin_download_agent_file(
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/embeddings/stats",
|
||||
summary="Get Embedding Statistics",
|
||||
)
|
||||
async def get_embedding_stats() -> dict[str, typing.Any]:
|
||||
"""
|
||||
Get statistics about embedding coverage for store listings.
|
||||
|
||||
Returns counts of total approved listings, listings with embeddings,
|
||||
listings without embeddings, and coverage percentage.
|
||||
"""
|
||||
try:
|
||||
stats = await store_embeddings.get_embedding_stats()
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.exception("Error getting embedding stats: %s", e)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="An error occurred while retrieving embedding stats",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/embeddings/backfill",
|
||||
summary="Backfill Missing Embeddings",
|
||||
)
|
||||
async def backfill_embeddings(
|
||||
batch_size: int = 10,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""
|
||||
Trigger backfill of embeddings for approved listings that don't have them.
|
||||
|
||||
Args:
|
||||
batch_size: Number of embeddings to generate in one call (default 10)
|
||||
|
||||
Returns:
|
||||
Dict with processed count, success count, failure count, and message
|
||||
"""
|
||||
try:
|
||||
result = await store_embeddings.backfill_missing_embeddings(
|
||||
batch_size=batch_size
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception("Error backfilling embeddings: %s", e)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
detail="An error occurred while backfilling embeddings",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
@@ -27,12 +26,6 @@ class ChatConfig(BaseSettings):
|
||||
# Session TTL Configuration - 12 hours
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# System Prompt Configuration
|
||||
system_prompt_path: str = Field(
|
||||
default="prompts/chat_system.md",
|
||||
description="Path to system prompt file relative to chat module",
|
||||
)
|
||||
|
||||
# Streaming Configuration
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
@@ -89,73 +82,6 @@ class ChatConfig(BaseSettings):
|
||||
"onboarding": "prompts/onboarding_system.md",
|
||||
}
|
||||
|
||||
def get_system_prompt_for_type(
|
||||
self, prompt_type: str = "default", **template_vars
|
||||
) -> str:
|
||||
"""Load and render a system prompt by type.
|
||||
|
||||
Args:
|
||||
prompt_type: The type of prompt to load ("default" or "onboarding")
|
||||
**template_vars: Variables to substitute in the template
|
||||
|
||||
Returns:
|
||||
Rendered system prompt string
|
||||
"""
|
||||
prompt_path_str = self.PROMPT_PATHS.get(
|
||||
prompt_type, self.PROMPT_PATHS["default"]
|
||||
)
|
||||
return self._load_prompt_from_path(prompt_path_str, **template_vars)
|
||||
|
||||
def get_system_prompt(self, **template_vars) -> str:
|
||||
"""Load and render the default system prompt from file.
|
||||
|
||||
Args:
|
||||
**template_vars: Variables to substitute in the template
|
||||
|
||||
Returns:
|
||||
Rendered system prompt string
|
||||
|
||||
"""
|
||||
return self._load_prompt_from_path(self.system_prompt_path, **template_vars)
|
||||
|
||||
def _load_prompt_from_path(self, prompt_path_str: str, **template_vars) -> str:
|
||||
"""Load and render a system prompt from a given path.
|
||||
|
||||
Args:
|
||||
prompt_path_str: Path to the prompt file relative to chat module
|
||||
**template_vars: Variables to substitute in the template
|
||||
|
||||
Returns:
|
||||
Rendered system prompt string
|
||||
"""
|
||||
# Get the path relative to this module
|
||||
module_dir = Path(__file__).parent
|
||||
prompt_path = module_dir / prompt_path_str
|
||||
|
||||
# Check for .j2 extension first (Jinja2 template)
|
||||
j2_path = Path(str(prompt_path) + ".j2")
|
||||
if j2_path.exists():
|
||||
try:
|
||||
from jinja2 import Template
|
||||
|
||||
template = Template(j2_path.read_text())
|
||||
return template.render(**template_vars)
|
||||
except ImportError:
|
||||
# Jinja2 not installed, fall back to reading as plain text
|
||||
return j2_path.read_text()
|
||||
|
||||
# Check for markdown file
|
||||
if prompt_path.exists():
|
||||
content = prompt_path.read_text()
|
||||
|
||||
# Simple variable substitution if Jinja2 is not available
|
||||
for key, value in template_vars.items():
|
||||
placeholder = f"{{{key}}}"
|
||||
content = content.replace(placeholder, str(value))
|
||||
|
||||
return content
|
||||
raise FileNotFoundError(f"System prompt file not found: {prompt_path}")
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
"""Database operations for chat sessions."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import ChatSessionUpdateInput
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
)
|
||||
|
||||
from backend.data.db import transaction
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,23 +27,24 @@ async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python since Prisma doesn't support order_by in include
|
||||
# Sort messages by sequence in Python - Prisma Python client doesn't support
|
||||
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
user_id: str,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data = {
|
||||
"id": session_id,
|
||||
"userId": user_id,
|
||||
"credentials": SafeJson({}),
|
||||
"successfulAgentRuns": SafeJson({}),
|
||||
"successfulAgentSchedules": SafeJson({}),
|
||||
}
|
||||
data = ChatSessionCreateInput(
|
||||
id=session_id,
|
||||
userId=user_id,
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
)
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
@@ -74,6 +82,7 @@ async def update_chat_session(
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort in Python - Prisma Python doesn't support order_by in include clauses
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
@@ -90,12 +99,16 @@ async def add_chat_message(
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput directly
|
||||
# because Prisma's TypedDict validation rejects optional fields set to None.
|
||||
# We only include fields that have values, then cast at the end.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
@@ -104,18 +117,22 @@ async def add_chat_message(
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
data["functionCall"] = SafeJson(function_call)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
# Run message create and session timestamp update in parallel for lower latency
|
||||
_, message = await asyncio.gather(
|
||||
PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
),
|
||||
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
|
||||
)
|
||||
|
||||
return await PrismaChatMessage.prisma().create(data=data)
|
||||
return message
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
@@ -123,39 +140,55 @@ async def add_chat_messages_batch(
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch."""
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
for i, msg in enumerate(messages):
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
async with transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
created = await PrismaChatMessage.prisma().create(data=data)
|
||||
created_messages.append(created)
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
|
||||
@@ -179,10 +212,31 @@ async def get_user_session_count(user_id: str) -> int:
|
||||
return await PrismaChatSession.prisma().count(where={"userId": user_id})
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str) -> bool:
|
||||
"""Delete a chat session and all its messages."""
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session and all its messages.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: If provided, validates that the session belongs to this user
|
||||
before deletion. This prevents unauthorized deletion of other
|
||||
users' sessions.
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
await PrismaChatSession.prisma().delete(where={"id": session_id})
|
||||
# Build typed where clause with optional user_id validation
|
||||
where_clause: ChatSessionWhereInput = {"id": session_id}
|
||||
if user_id is not None:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
result = await PrismaChatSession.prisma().delete_many(where=where_clause)
|
||||
if result == 0:
|
||||
logger.warning(
|
||||
f"No session deleted for {session_id} "
|
||||
f"(user_id validation: {user_id is not None})"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete chat session {session_id}: {e}")
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
@@ -22,7 +25,7 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util import json
|
||||
from backend.util.exceptions import RedisError
|
||||
from backend.util.exceptions import DatabaseError, RedisError
|
||||
|
||||
from . import db as chat_db
|
||||
from .config import ChatConfig
|
||||
@@ -31,6 +34,48 @@ logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
|
||||
"""Parse a JSON field that may be stored as string or already parsed."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
|
||||
# Redis cache key prefix for chat sessions
|
||||
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
|
||||
|
||||
|
||||
def _get_session_cache_key(session_id: str) -> str:
|
||||
"""Get the Redis cache key for a chat session."""
|
||||
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
|
||||
|
||||
|
||||
# Session-level locks to prevent race conditions during concurrent upserts.
|
||||
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
|
||||
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
|
||||
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
|
||||
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
|
||||
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
_session_locks_mutex = asyncio.Lock()
|
||||
|
||||
|
||||
async def _get_session_lock(session_id: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific session to prevent concurrent upserts.
|
||||
|
||||
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
|
||||
when no coroutine holds a reference to them, preventing memory leaks from
|
||||
unbounded growth of session locks.
|
||||
"""
|
||||
async with _session_locks_mutex:
|
||||
lock = _session_locks.get(session_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_session_locks[session_id] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str | None = None
|
||||
@@ -49,7 +94,7 @@ class Usage(BaseModel):
|
||||
|
||||
class ChatSession(BaseModel):
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
user_id: str
|
||||
title: str | None = None
|
||||
messages: list[ChatMessage]
|
||||
usage: list[Usage]
|
||||
@@ -60,7 +105,7 @@ class ChatSession(BaseModel):
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str | None) -> "ChatSession":
|
||||
def new(user_id: str) -> "ChatSession":
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
@@ -73,7 +118,7 @@ class ChatSession(BaseModel):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_prisma(
|
||||
def from_db(
|
||||
prisma_session: PrismaChatSession,
|
||||
prisma_messages: list[PrismaChatMessage] | None = None,
|
||||
) -> "ChatSession":
|
||||
@@ -81,22 +126,6 @@ class ChatSession(BaseModel):
|
||||
messages = []
|
||||
if prisma_messages:
|
||||
for msg in prisma_messages:
|
||||
tool_calls = None
|
||||
if msg.toolCalls:
|
||||
tool_calls = (
|
||||
json.loads(msg.toolCalls)
|
||||
if isinstance(msg.toolCalls, str)
|
||||
else msg.toolCalls
|
||||
)
|
||||
|
||||
function_call = None
|
||||
if msg.functionCall:
|
||||
function_call = (
|
||||
json.loads(msg.functionCall)
|
||||
if isinstance(msg.functionCall, str)
|
||||
else msg.functionCall
|
||||
)
|
||||
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
@@ -104,26 +133,18 @@ class ChatSession(BaseModel):
|
||||
name=msg.name,
|
||||
tool_call_id=msg.toolCallId,
|
||||
refusal=msg.refusal,
|
||||
tool_calls=tool_calls,
|
||||
function_call=function_call,
|
||||
tool_calls=_parse_json_field(msg.toolCalls),
|
||||
function_call=_parse_json_field(msg.functionCall),
|
||||
)
|
||||
)
|
||||
|
||||
# Parse JSON fields from Prisma
|
||||
credentials = (
|
||||
json.loads(prisma_session.credentials)
|
||||
if isinstance(prisma_session.credentials, str)
|
||||
else prisma_session.credentials or {}
|
||||
credentials = _parse_json_field(prisma_session.credentials, default={})
|
||||
successful_agent_runs = _parse_json_field(
|
||||
prisma_session.successfulAgentRuns, default={}
|
||||
)
|
||||
successful_agent_runs = (
|
||||
json.loads(prisma_session.successfulAgentRuns)
|
||||
if isinstance(prisma_session.successfulAgentRuns, str)
|
||||
else prisma_session.successfulAgentRuns or {}
|
||||
)
|
||||
successful_agent_schedules = (
|
||||
json.loads(prisma_session.successfulAgentSchedules)
|
||||
if isinstance(prisma_session.successfulAgentSchedules, str)
|
||||
else prisma_session.successfulAgentSchedules or {}
|
||||
successful_agent_schedules = _parse_json_field(
|
||||
prisma_session.successfulAgentSchedules, default={}
|
||||
)
|
||||
|
||||
# Calculate usage from token counts
|
||||
@@ -242,7 +263,7 @@ class ChatSession(BaseModel):
|
||||
|
||||
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
"""Get a chat session from Redis cache."""
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
raw_session: bytes | None = await async_redis.get(redis_key)
|
||||
|
||||
@@ -264,7 +285,7 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
|
||||
async def _cache_session(session: ChatSession) -> None:
|
||||
"""Cache a chat session in Redis."""
|
||||
redis_key = f"chat:session:{session.session_id}"
|
||||
redis_key = _get_session_cache_key(session.session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
|
||||
|
||||
@@ -283,7 +304,7 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
)
|
||||
|
||||
return ChatSession.from_prisma(prisma_session, messages)
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
@@ -345,19 +366,24 @@ async def _save_session_to_db(
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ChatSession | None:
|
||||
"""Get a chat session by ID.
|
||||
|
||||
Checks Redis cache first, falls back to database if not found.
|
||||
Caches database results back to Redis.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to fetch.
|
||||
user_id: If provided, validates that the session belongs to this user.
|
||||
If None, ownership is not validated (admin/system access).
|
||||
"""
|
||||
# Try cache first
|
||||
try:
|
||||
session = await _get_session_from_cache(session_id)
|
||||
if session:
|
||||
# Verify user ownership
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
# Verify user ownership if user_id was provided for validation
|
||||
if user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
@@ -376,8 +402,8 @@ async def get_chat_session(
|
||||
logger.warning(f"Session {session_id} not found in cache or database")
|
||||
return None
|
||||
|
||||
# Verify user ownership
|
||||
if session.user_id is not None and session.user_id != user_id:
|
||||
# Verify user ownership if user_id was provided for validation
|
||||
if user_id is not None and session.user_id != user_id:
|
||||
logger.warning(
|
||||
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
|
||||
)
|
||||
@@ -396,49 +422,88 @@ async def get_chat_session(
|
||||
async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
) -> ChatSession:
|
||||
"""Update a chat session in both cache and database."""
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
"""Update a chat session in both cache and database.
|
||||
|
||||
# Save to database
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save session {session.session_id} to database: {e}")
|
||||
# Continue to cache even if DB fails
|
||||
Uses session-level locking to prevent race conditions when concurrent
|
||||
operations (e.g., background title update and main stream handler)
|
||||
attempt to upsert the same session simultaneously.
|
||||
|
||||
# Save to cache
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {e}"
|
||||
) from e
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. The cache is still updated
|
||||
as a best-effort optimization, but the error is propagated to ensure
|
||||
callers are aware of the persistence failure.
|
||||
RedisError: If the cache write fails (after successful DB write).
|
||||
"""
|
||||
# Acquire session-specific lock to prevent concurrent upserts
|
||||
lock = await _get_session_lock(session.session_id)
|
||||
|
||||
return session
|
||||
async with lock:
|
||||
# Get existing message count from DB for incremental saves
|
||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
|
||||
db_error: Exception | None = None
|
||||
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session {session.session_id} to database: {e}"
|
||||
)
|
||||
db_error = e
|
||||
|
||||
# Save to cache (best-effort, even if DB failed)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
# If DB succeeded but cache failed, raise cache error
|
||||
if db_error is None:
|
||||
raise RedisError(
|
||||
f"Failed to persist chat session {session.session_id} to Redis: {e}"
|
||||
) from e
|
||||
# If both failed, log cache error but raise DB error (more critical)
|
||||
logger.warning(
|
||||
f"Cache write also failed for session {session.session_id}: {e}"
|
||||
)
|
||||
|
||||
# Propagate DB error after attempting cache (prevents data loss)
|
||||
if db_error is not None:
|
||||
raise DatabaseError(
|
||||
f"Failed to persist chat session {session.session_id} to database"
|
||||
) from db_error
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(user_id: str | None) -> ChatSession:
|
||||
"""Create a new chat session and persist it."""
|
||||
async def create_chat_session(user_id: str) -> ChatSession:
|
||||
"""Create a new chat session and persist it.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. We fail fast to ensure
|
||||
callers never receive a non-persisted session that only exists
|
||||
in cache (which would be lost when the cache expires).
|
||||
"""
|
||||
session = ChatSession.new(user_id)
|
||||
|
||||
# Create in database first
|
||||
# Create in database first - fail fast if this fails
|
||||
try:
|
||||
await chat_db.create_chat_session(
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session in database: {e}")
|
||||
# Continue even if DB fails - cache will still work
|
||||
logger.error(f"Failed to create session {session.session_id} in database: {e}")
|
||||
raise DatabaseError(
|
||||
f"Failed to create chat session {session.session_id} in database"
|
||||
) from e
|
||||
|
||||
# Cache the session
|
||||
# Cache the session (best-effort optimization, DB is source of truth)
|
||||
try:
|
||||
await _cache_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache new session: {e}")
|
||||
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
|
||||
|
||||
return session
|
||||
|
||||
@@ -447,27 +512,86 @@ async def get_user_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[ChatSession]:
|
||||
"""Get all chat sessions for a user from the database."""
|
||||
) -> tuple[list[ChatSession], int]:
|
||||
"""Get chat sessions for a user from the database with total count.
|
||||
|
||||
Returns:
|
||||
A tuple of (sessions, total_count) where total_count is the overall
|
||||
number of sessions for the user (not just the current page).
|
||||
"""
|
||||
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
|
||||
total_count = await chat_db.get_user_session_count(user_id)
|
||||
|
||||
sessions = []
|
||||
for prisma_session in prisma_sessions:
|
||||
# Convert without messages for listing (lighter weight)
|
||||
sessions.append(ChatSession.from_prisma(prisma_session, None))
|
||||
sessions.append(ChatSession.from_db(prisma_session, None))
|
||||
|
||||
return sessions
|
||||
return sessions, total_count
|
||||
|
||||
|
||||
async def delete_chat_session(session_id: str) -> bool:
|
||||
"""Delete a chat session from both cache and database."""
|
||||
# Delete from cache
|
||||
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session from both cache and database.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: If provided, validates that the session belongs to this user
|
||||
before deletion. This prevents unauthorized deletion.
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False otherwise.
|
||||
"""
|
||||
# Delete from database first (with optional user_id validation)
|
||||
# This confirms ownership before invalidating cache
|
||||
deleted = await chat_db.delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
return False
|
||||
|
||||
# Only invalidate cache and clean up lock after DB confirms deletion
|
||||
try:
|
||||
redis_key = f"chat:session:{session_id}"
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
|
||||
|
||||
# Delete from database
|
||||
return await chat_db.delete_chat_session(session_id)
|
||||
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
|
||||
async with _session_locks_mutex:
|
||||
_session_locks.pop(session_id, None)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def update_session_title(session_id: str, title: str) -> bool:
|
||||
"""Update only the title of a chat session.
|
||||
|
||||
This is a lightweight operation that doesn't touch messages, avoiding
|
||||
race conditions with concurrent message updates. Use this for background
|
||||
title generation instead of upsert_chat_session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to update.
|
||||
title: The new title to set.
|
||||
|
||||
Returns:
|
||||
True if updated successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await chat_db.update_chat_session(session_id=session_id, title=title)
|
||||
if result is None:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update title for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
@@ -43,9 +43,9 @@ async def test_chatsession_serialization_deserialization():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage():
|
||||
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
|
||||
|
||||
s = ChatSession.new(user_id=None)
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
|
||||
s = await upsert_chat_session(s)
|
||||
@@ -59,26 +59,28 @@ async def test_chatsession_redis_storage():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_redis_storage_user_id_mismatch():
|
||||
async def test_chatsession_redis_storage_user_id_mismatch(
|
||||
setup_test_user, test_user_id
|
||||
):
|
||||
|
||||
s = ChatSession.new(user_id="abc123")
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(s.session_id, None)
|
||||
s2 = await get_chat_session(s.session_id, "different_user_id")
|
||||
|
||||
assert s2 is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_chatsession_db_storage():
|
||||
async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
||||
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=None)
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
|
||||
assert s.session_id is not None, "Session id is not set"
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find, create, and set up AutoGPT agents to solve their business problems.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
**Understanding & Discovery:**
|
||||
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||
4. **find_block** - Search for individual blocks (building components for agents)
|
||||
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||
|
||||
**Execution & Output:**
|
||||
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
9. **run_block** - Run a single block directly without creating an agent
|
||||
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||
</functions>
|
||||
|
||||
## ALWAYS GET THE USER'S NAME
|
||||
|
||||
**This is critical:** If you don't know the user's name, ask for it in your first response. Use a friendly, natural approach:
|
||||
- "Hi! I'm Otto. What's your name?"
|
||||
- "Hey there! Before we dive in, what should I call you?"
|
||||
|
||||
Once you have their name, immediately save it with `add_understanding(user_name="...")` and use it throughout the conversation.
|
||||
|
||||
## BUILDING USER UNDERSTANDING
|
||||
|
||||
**If no User Business Context is provided below**, gather information naturally during conversation - don't interrogate them.
|
||||
|
||||
**Key information to gather (in priority order):**
|
||||
1. Their name (ALWAYS first if unknown)
|
||||
2. Their job title and role
|
||||
3. Their business/company and industry
|
||||
4. Pain points and what they want to automate
|
||||
5. Tools they currently use
|
||||
|
||||
**How to gather this information:**
|
||||
- Ask naturally as part of helping them (e.g., "What's your role?" or "What industry are you in?")
|
||||
- When they share information, immediately save it using `add_understanding`
|
||||
- Don't ask all questions at once - spread them across the conversation
|
||||
- Prioritize understanding their immediate problem first
|
||||
|
||||
**Example:**
|
||||
```
|
||||
User: "I need help automating my social media"
|
||||
Otto: I can help with that! I'm Otto - what's your name?
|
||||
User: "I'm Sarah"
|
||||
Otto: [calls add_understanding with user_name="Sarah"]
|
||||
Nice to meet you, Sarah! What's your role - are you a social media manager or business owner?
|
||||
User: "I'm the marketing director at a fintech startup"
|
||||
Otto: [calls add_understanding with job_title="Marketing Director", industry="fintech", business_size="startup"]
|
||||
Great! Let me find social media automation agents for you.
|
||||
[calls find_agent with query="social media automation marketing"]
|
||||
```
|
||||
|
||||
## WHEN TO USE WHICH TOOL
|
||||
|
||||
**Finding existing agents:**
|
||||
- `find_agent` - Search the marketplace for pre-built agents others have created
|
||||
- `find_library_agent` - Search agents the user has already saved to their library
|
||||
|
||||
**Creating/editing agents:**
|
||||
- `create_agent` - When user wants a custom agent that doesn't exist, or has specific requirements
|
||||
- `edit_agent` - When user wants to modify an existing agent (change inputs, add blocks, etc.)
|
||||
|
||||
**Running agents:**
|
||||
- `run_agent` - To execute an agent (handles credentials and inputs automatically)
|
||||
- `agent_output` - To check the results of a running or completed agent execution
|
||||
|
||||
**Direct execution:**
|
||||
- `run_block` - Run a single block directly without needing a full agent
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
The `run_agent` tool automatically handles the entire setup flow:
|
||||
|
||||
1. **First call** (no inputs) → Returns available inputs so user can decide what values to use
|
||||
2. **Credentials check** → If missing, UI automatically prompts user to add them (you don't need to mention this)
|
||||
3. **Execution** → Runs when you provide `inputs` OR set `use_defaults=true`
|
||||
|
||||
Parameters:
|
||||
- `username_agent_slug` (required): Agent identifier like "creator/agent-name"
|
||||
- `inputs`: Object with input values for the agent
|
||||
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
|
||||
- `schedule_name` + `cron`: For scheduled execution
|
||||
|
||||
## HOW create_agent WORKS
|
||||
|
||||
Use `create_agent` when the user wants to build a custom automation:
|
||||
- Describe what the agent should do
|
||||
- The tool will create the agent structure with appropriate blocks
|
||||
- Returns the agent ID for further editing or running
|
||||
|
||||
## HOW agent_output WORKS
|
||||
|
||||
Use `agent_output` to get results from agent executions:
|
||||
- Pass the execution_id from a run_agent response
|
||||
- Returns the current status and any outputs produced
|
||||
- Useful for checking if an agent has completed and what it produced
|
||||
|
||||
## WORKFLOW
|
||||
|
||||
1. **Get their name** - If unknown, ask for it first
|
||||
2. **Understand context** - Ask 1-2 questions about their problem while helping
|
||||
3. **Find or create** - Use find_agent for existing solutions, create_agent for custom needs
|
||||
4. **Set up and run** - Use run_agent to execute, agent_output to get results
|
||||
|
||||
## YOUR APPROACH
|
||||
|
||||
**Step 1: Greet and Identify**
|
||||
- If you don't know their name, ask for it
|
||||
- Be friendly and conversational
|
||||
|
||||
**Step 2: Understand the Problem**
|
||||
- Ask maximum 1-2 targeted questions
|
||||
- Focus on: What business problem are they solving?
|
||||
- If they want to create/edit an agent, understand what it should do
|
||||
|
||||
**Step 3: Find or Create**
|
||||
- For existing solutions: Use `find_agent` with relevant keywords
|
||||
- For custom needs: Use `create_agent` with their requirements
|
||||
- For modifications: Use `edit_agent` on an existing agent
|
||||
|
||||
**Step 4: Execute**
|
||||
- Call `run_agent` without inputs first to see what's available
|
||||
- Ask user what values they want or if defaults are okay
|
||||
- Call `run_agent` again with inputs or `use_defaults=true`
|
||||
- Use `agent_output` to check results when needed
|
||||
|
||||
## USING add_understanding
|
||||
|
||||
Call `add_understanding` whenever you learn something about the user:
|
||||
|
||||
**User info:** `user_name`, `job_title`
|
||||
**Business:** `business_name`, `industry`, `business_size` (1-10, 11-50, 51-200, 201-1000, 1000+), `user_role` (decision maker, implementer, end user)
|
||||
**Processes:** `key_workflows` (array), `daily_activities` (array)
|
||||
**Pain points:** `pain_points` (array), `bottlenecks` (array), `manual_tasks` (array), `automation_goals` (array)
|
||||
**Tools:** `current_software` (array), `existing_automation` (array)
|
||||
**Other:** `additional_notes`
|
||||
|
||||
Example: `add_understanding(user_name="Sarah", job_title="Marketing Director", industry="fintech")`
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention or explain credentials to the user (frontend handles this automatically)
|
||||
- Don't run agents without first showing available inputs to the user
|
||||
- Don't use `use_defaults=true` without user explicitly confirming
|
||||
- Don't write responses longer than 3 sentences
|
||||
- Don't interrogate users with many questions - gather info naturally
|
||||
|
||||
**What You DO:**
|
||||
- ALWAYS ask for user's name if you don't have it
|
||||
- Save user information with `add_understanding` as you learn it
|
||||
- Use their name when addressing them
|
||||
- Always call run_agent first without inputs to see what's available
|
||||
- Ask user what values they want OR if they want to use defaults
|
||||
- Keep all responses to maximum 3 sentences
|
||||
- Include the agent link in your response after successful execution
|
||||
|
||||
**Error Handling:**
|
||||
- Authentication needed → "Please sign in via the interface"
|
||||
- Credentials missing → The UI handles this automatically. Focus on asking the user about input values instead.
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
|
||||
- Check if you know the user's name - if not, ask for it
|
||||
- Check if you have user context - if not, plan to gather some naturally
|
||||
- Extract the key business problem or request from the user's message
|
||||
- Determine what function call (if any) you need to make next
|
||||
- Plan your response to stay under the 3-sentence maximum
|
||||
|
||||
Example interaction:
|
||||
```
|
||||
User: "Hi, I want to build an agent that monitors my competitors"
|
||||
Otto: <thinking>I don't know this user's name. I should ask for it while acknowledging their request.</thinking>
|
||||
Hi! I'm Otto and I'd love to help you build a competitor monitoring agent. What's your name?
|
||||
User: "I'm Mike"
|
||||
Otto: [calls add_understanding with user_name="Mike"]
|
||||
<thinking>Now I know Mike wants competitor monitoring. I should search for existing agents first.</thinking>
|
||||
Great to meet you, Mike! Let me search for competitor monitoring agents.
|
||||
[calls find_agent with query="competitor monitoring analysis"]
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES
|
||||
@@ -1,155 +0,0 @@
|
||||
You are Otto, an AI Co-Pilot helping new users get started with AutoGPT, an AI Business Automation platform. Your mission is to welcome them, learn about their needs, and help them run their first successful agent.
|
||||
|
||||
Here are the functions available to you:
|
||||
|
||||
<functions>
|
||||
**Understanding & Discovery:**
|
||||
1. **add_understanding** - Save information about the user's business context (use this as you learn about them)
|
||||
2. **find_agent** - Search the marketplace for pre-built agents that solve the user's problem
|
||||
3. **find_library_agent** - Search the user's personal library of saved agents
|
||||
4. **find_block** - Search for individual blocks (building components for agents)
|
||||
5. **search_platform_docs** - Search AutoGPT documentation for help
|
||||
|
||||
**Agent Creation & Editing:**
|
||||
6. **create_agent** - Create a new custom agent from scratch based on user requirements
|
||||
7. **edit_agent** - Modify an existing agent (add/remove blocks, change configuration)
|
||||
|
||||
**Execution & Output:**
|
||||
8. **run_agent** - Run or schedule an agent (automatically handles setup)
|
||||
9. **run_block** - Run a single block directly without creating an agent
|
||||
10. **agent_output** - Get the output/results from a running or completed agent execution
|
||||
</functions>
|
||||
|
||||
## YOUR ONBOARDING MISSION
|
||||
|
||||
You are guiding a new user through their first experience with AutoGPT. Your goal is to:
|
||||
1. Welcome them warmly and get their name
|
||||
2. Learn about them and their business
|
||||
3. Find or create an agent that solves a real problem for them
|
||||
4. Get that agent running successfully
|
||||
5. Celebrate their success and point them to next steps
|
||||
|
||||
## PHASE 1: WELCOME & INTRODUCTION
|
||||
|
||||
**Start every conversation by:**
|
||||
- Giving a warm, friendly greeting
|
||||
- Introducing yourself as Otto, their AI assistant
|
||||
- Asking for their name immediately
|
||||
|
||||
**Example opening:**
|
||||
```
|
||||
Hi! I'm Otto, your AI assistant. Welcome to AutoGPT! I'm here to help you set up your first automation. What's your name?
|
||||
```
|
||||
|
||||
Once you have their name, save it immediately with `add_understanding(user_name="...")` and use it throughout.
|
||||
|
||||
## PHASE 2: DISCOVERY
|
||||
|
||||
**After getting their name, learn about them:**
|
||||
- What's their role/job title?
|
||||
- What industry/business are they in?
|
||||
- What's one thing they'd love to automate?
|
||||
|
||||
**Keep it conversational - don't interrogate. Example:**
|
||||
```
|
||||
Nice to meet you, Sarah! What do you do for work, and what's one task you wish you could automate?
|
||||
```
|
||||
|
||||
Save everything you learn with `add_understanding`.
|
||||
|
||||
## PHASE 3: FIND OR CREATE AN AGENT
|
||||
|
||||
**Once you understand their need:**
|
||||
- Search for existing agents with `find_agent`
|
||||
- Present the best match and explain how it helps them
|
||||
- If nothing fits, offer to create a custom agent with `create_agent`
|
||||
|
||||
**Be enthusiastic about the solution:**
|
||||
```
|
||||
I found a great agent for you! The "Social Media Scheduler" can automatically post to your accounts on a schedule. Want to try it?
|
||||
```
|
||||
|
||||
## PHASE 4: SETUP & RUN
|
||||
|
||||
**Guide them through running the agent:**
|
||||
1. Call `run_agent` without inputs first to see what's needed
|
||||
2. Explain each input in simple terms
|
||||
3. Ask what values they want to use
|
||||
4. Run the agent with their inputs or defaults
|
||||
|
||||
**Don't mention credentials** - the UI handles that automatically.
|
||||
|
||||
## PHASE 5: CELEBRATE & HANDOFF
|
||||
|
||||
**After successful execution:**
|
||||
- Congratulate them on their first automation!
|
||||
- Tell them where to find this agent (their Library)
|
||||
- Mention they can explore more agents in the Marketplace
|
||||
- Offer to help with anything else
|
||||
|
||||
**Example:**
|
||||
```
|
||||
You did it! Your first agent is running. You can find it anytime in your Library. Ready to explore more automations?
|
||||
```
|
||||
|
||||
## KEY RULES
|
||||
|
||||
**What You DON'T Do:**
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't mention credentials (UI handles automatically)
|
||||
- Don't run agents without showing inputs first
|
||||
- Don't use `use_defaults=true` without explicit confirmation
|
||||
- Don't write responses longer than 3 sentences
|
||||
- Don't overwhelm with too many questions at once
|
||||
|
||||
**What You DO:**
|
||||
- ALWAYS get the user's name first
|
||||
- Be warm, encouraging, and celebratory
|
||||
- Save info with `add_understanding` as you learn it
|
||||
- Use their name when addressing them
|
||||
- Keep responses to maximum 3 sentences
|
||||
- Make them feel successful at each step
|
||||
|
||||
## USING add_understanding
|
||||
|
||||
Save information as you learn it:
|
||||
|
||||
**User info:** `user_name`, `job_title`
|
||||
**Business:** `business_name`, `industry`, `business_size`, `user_role`
|
||||
**Pain points:** `pain_points`, `manual_tasks`, `automation_goals`
|
||||
**Tools:** `current_software`
|
||||
|
||||
Example: `add_understanding(user_name="Sarah", job_title="Marketing Manager", automation_goals=["social media scheduling"])`
|
||||
|
||||
## HOW run_agent WORKS
|
||||
|
||||
1. **First call** (no inputs) → Shows available inputs
|
||||
2. **Credentials** → UI handles automatically (don't mention)
|
||||
3. **Execution** → Run with `inputs={...}` or `use_defaults=true`
|
||||
|
||||
## RESPONSE STRUCTURE
|
||||
|
||||
Before responding, plan your approach in <thinking> tags:
|
||||
- What phase am I in? (Welcome/Discovery/Find/Setup/Celebrate)
|
||||
- Do I know their name? If not, ask for it
|
||||
- What's the next step to move them forward?
|
||||
- Keep response under 3 sentences
|
||||
|
||||
**Example flow:**
|
||||
```
|
||||
User: "Hi"
|
||||
Otto: <thinking>Phase 1 - I need to welcome them and get their name.</thinking>
|
||||
Hi! I'm Otto, welcome to AutoGPT! I'm here to help you set up your first automation - what's your name?
|
||||
|
||||
User: "I'm Alex"
|
||||
Otto: [calls add_understanding with user_name="Alex"]
|
||||
<thinking>Got their name. Phase 2 - learn about them.</thinking>
|
||||
Great to meet you, Alex! What do you do for work, and what's one task you'd love to automate?
|
||||
|
||||
User: "I run an e-commerce store and spend hours on customer support emails"
|
||||
Otto: [calls add_understanding with industry="e-commerce", pain_points=["customer support emails"]]
|
||||
<thinking>Phase 3 - search for agents.</thinking>
|
||||
[calls find_agent with query="customer support email automation"]
|
||||
```
|
||||
|
||||
KEEP ANSWERS TO 3 SENTENCES - Be warm, helpful, and focused on their success!
|
||||
@@ -1,3 +1,10 @@
|
||||
"""
|
||||
Response models for Vercel AI SDK UI Stream Protocol.
|
||||
|
||||
This module implements the AI SDK UI Stream Protocol (v1) for streaming chat responses.
|
||||
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -5,97 +12,133 @@ from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of streaming responses."""
|
||||
"""Types of streaming responses following AI SDK protocol."""
|
||||
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_ENDED = "text_ended"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_CALL_START = "tool_call_start"
|
||||
TOOL_RESPONSE = "tool_response"
|
||||
# Message lifecycle
|
||||
START = "start"
|
||||
FINISH = "finish"
|
||||
|
||||
# Text streaming
|
||||
TEXT_START = "text-start"
|
||||
TEXT_DELTA = "text-delta"
|
||||
TEXT_END = "text-end"
|
||||
|
||||
# Tool interaction
|
||||
TOOL_INPUT_START = "tool-input-start"
|
||||
TOOL_INPUT_AVAILABLE = "tool-input-available"
|
||||
TOOL_OUTPUT_AVAILABLE = "tool-output-available"
|
||||
|
||||
# Other
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
STREAM_END = "stream_end"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
"""Base response model for all streaming responses."""
|
||||
|
||||
type: ResponseType
|
||||
timestamp: str | None = None
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class StreamTextChunk(StreamBaseResponse):
|
||||
"""Streaming text content from the assistant."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_CHUNK
|
||||
content: str = Field(..., description="Text content chunk")
|
||||
# ========== Message Lifecycle ==========
|
||||
|
||||
|
||||
class StreamToolCallStart(StreamBaseResponse):
|
||||
class StreamStart(StreamBaseResponse):
|
||||
"""Start of a new message."""
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
|
||||
|
||||
class StreamFinish(StreamBaseResponse):
|
||||
"""End of message/stream."""
|
||||
|
||||
type: ResponseType = ResponseType.FINISH
|
||||
|
||||
|
||||
# ========== Text Streaming ==========
|
||||
|
||||
|
||||
class StreamTextStart(StreamBaseResponse):
|
||||
"""Start of a text block."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_START
|
||||
id: str = Field(..., description="Text block ID")
|
||||
|
||||
|
||||
class StreamTextDelta(StreamBaseResponse):
|
||||
"""Streaming text content delta."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_DELTA
|
||||
id: str = Field(..., description="Text block ID")
|
||||
delta: str = Field(..., description="Text content delta")
|
||||
|
||||
|
||||
class StreamTextEnd(StreamBaseResponse):
|
||||
"""End of a text block."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_END
|
||||
id: str = Field(..., description="Text block ID")
|
||||
|
||||
|
||||
# ========== Tool Interaction ==========
|
||||
|
||||
|
||||
class StreamToolInputStart(StreamBaseResponse):
|
||||
"""Tool call started notification."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_CALL_START
|
||||
tool_name: str = Field(..., description="Name of the tool that was executed")
|
||||
tool_id: str = Field(..., description="Unique tool call ID")
|
||||
type: ResponseType = ResponseType.TOOL_INPUT_START
|
||||
toolCallId: str = Field(..., description="Unique tool call ID")
|
||||
toolName: str = Field(..., description="Name of the tool being called")
|
||||
|
||||
|
||||
class StreamToolCall(StreamBaseResponse):
|
||||
"""Tool invocation notification."""
|
||||
class StreamToolInputAvailable(StreamBaseResponse):
|
||||
"""Tool input is ready for execution."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_CALL
|
||||
tool_id: str = Field(..., description="Unique tool call ID")
|
||||
tool_name: str = Field(..., description="Name of the tool being called")
|
||||
arguments: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Tool arguments"
|
||||
type: ResponseType = ResponseType.TOOL_INPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Unique tool call ID")
|
||||
toolName: str = Field(..., description="Name of the tool being called")
|
||||
input: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Tool input arguments"
|
||||
)
|
||||
|
||||
|
||||
class StreamToolExecutionResult(StreamBaseResponse):
|
||||
class StreamToolOutputAvailable(StreamBaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_RESPONSE
|
||||
tool_id: str = Field(..., description="Tool call ID this responds to")
|
||||
tool_name: str = Field(..., description="Name of the tool that was executed")
|
||||
result: str | dict[str, Any] = Field(..., description="Tool execution result")
|
||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
||||
toolName: str | None = Field(
|
||||
default=None, description="Name of the tool that was executed"
|
||||
)
|
||||
success: bool = Field(
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
|
||||
# ========== Other ==========
|
||||
|
||||
|
||||
class StreamUsage(StreamBaseResponse):
|
||||
"""Token usage statistics."""
|
||||
|
||||
type: ResponseType = ResponseType.USAGE
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
promptTokens: int = Field(..., description="Number of prompt tokens")
|
||||
completionTokens: int = Field(..., description="Number of completion tokens")
|
||||
totalTokens: int = Field(..., description="Total number of tokens")
|
||||
|
||||
|
||||
class StreamError(StreamBaseResponse):
|
||||
"""Error response."""
|
||||
|
||||
type: ResponseType = ResponseType.ERROR
|
||||
message: str = Field(..., description="Error message")
|
||||
errorText: str = Field(..., description="Error message text")
|
||||
code: str | None = Field(default=None, description="Error code")
|
||||
details: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional error details"
|
||||
)
|
||||
|
||||
|
||||
class StreamTextEnded(StreamBaseResponse):
|
||||
"""Text streaming completed marker."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_ENDED
|
||||
|
||||
|
||||
class StreamEnd(StreamBaseResponse):
|
||||
"""End of stream marker."""
|
||||
|
||||
type: ResponseType = ResponseType.STREAM_END
|
||||
summary: dict[str, Any] | None = Field(
|
||||
default=None, description="Stream summary statistics"
|
||||
)
|
||||
|
||||
@@ -13,12 +13,25 @@ from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> ChatSession:
|
||||
"""Validate session exists and belongs to user."""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
return session
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
)
|
||||
@@ -94,7 +107,7 @@ async def list_sessions(
|
||||
Returns:
|
||||
ListSessionsResponse: List of session summaries and total count.
|
||||
"""
|
||||
sessions = await chat_service.get_user_sessions(user_id, limit, offset)
|
||||
sessions, total_count = await get_user_sessions(user_id, limit, offset)
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=[
|
||||
@@ -102,11 +115,11 @@ async def list_sessions(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
title=None, # TODO: Add title support
|
||||
title=session.title,
|
||||
)
|
||||
for session in sessions
|
||||
],
|
||||
total=len(sessions),
|
||||
total=total_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -114,15 +127,15 @@ async def list_sessions(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
user_id: Annotated[str, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Initiates a new chat session for either an authenticated or anonymous user.
|
||||
Initiates a new chat session for the authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT. If missing, creates an anonymous session.
|
||||
user_id: The authenticated user ID parsed from the JWT (required).
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created session.
|
||||
@@ -130,15 +143,15 @@ async def create_session(
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
session = await create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -162,7 +175,7 @@ async def get_session(
|
||||
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
|
||||
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
@@ -206,14 +219,7 @@ async def stream_chat_post(
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
@@ -225,6 +231,8 @@ async def stream_chat_post(
|
||||
context=request.context,
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -233,6 +241,7 @@ async def stream_chat_post(
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
@@ -263,14 +272,7 @@ async def stream_chat_get(
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
# Validate session exists before starting the stream
|
||||
# This prevents errors after the response has already started
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found. ")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
@@ -281,6 +283,8 @@ async def stream_chat_get(
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -289,6 +293,7 @@ async def stream_chat_get(
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
||||
},
|
||||
)
|
||||
|
||||
@@ -319,133 +324,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Onboarding Routes ==========
|
||||
# These routes use a specialized onboarding system prompt
|
||||
|
||||
|
||||
@router.post(
|
||||
"/onboarding/sessions",
|
||||
)
|
||||
async def create_onboarding_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new onboarding chat session.
|
||||
|
||||
Initiates a new chat session specifically for user onboarding,
|
||||
using a specialized prompt that guides users through their first
|
||||
experience with AutoGPT.
|
||||
|
||||
Args:
|
||||
user_id: The optional authenticated user ID parsed from the JWT.
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: Details of the created onboarding session.
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating onboarding session with user_id: "
|
||||
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
|
||||
)
|
||||
|
||||
session = await chat_service.create_chat_session(user_id)
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/onboarding/sessions/{session_id}",
|
||||
)
|
||||
async def get_onboarding_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of an onboarding chat session.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the onboarding session.
|
||||
user_id: The optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session.
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
logger.info(
|
||||
f"Returning onboarding session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
created_at=session.started_at.isoformat(),
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/onboarding/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_onboarding_chat(
|
||||
session_id: str,
|
||||
request: StreamChatRequest,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
):
|
||||
"""
|
||||
Stream onboarding chat responses for a session.
|
||||
|
||||
Uses the specialized onboarding system prompt to guide new users
|
||||
through their first experience with AutoGPT. Streams AI responses
|
||||
in real time over Server-Sent Events (SSE).
|
||||
|
||||
Args:
|
||||
session_id: The onboarding session identifier.
|
||||
request: Request body containing message and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
"""
|
||||
session = await chat_service.get_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
if session.user_id is None and user_id is not None:
|
||||
session = await chat_service.assign_user_to_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
context=request.context,
|
||||
prompt_type="onboarding", # Use onboarding system prompt
|
||||
):
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
@@ -454,16 +332,28 @@ async def health_check() -> dict:
|
||||
"""
|
||||
Health check endpoint for the chat service.
|
||||
|
||||
Performs a full cycle test of session creation, assignment, and retrieval. Should always return healthy
|
||||
Performs a full cycle test of session creation and retrieval. Should always return healthy
|
||||
if the service and data layer are operational.
|
||||
|
||||
Returns:
|
||||
dict: A status dictionary indicating health, service name, and API version.
|
||||
|
||||
"""
|
||||
session = await chat_service.create_chat_session(None)
|
||||
await chat_service.assign_user_to_session(session.session_id, "test_user")
|
||||
await chat_service.get_session(session.session_id, "test_user")
|
||||
from backend.data.user import get_or_create_user
|
||||
|
||||
# Ensure health check user exists (required for FK constraint)
|
||||
health_check_user_id = "health-check-user"
|
||||
await get_or_create_user(
|
||||
{
|
||||
"sub": health_check_user_id,
|
||||
"email": "health-check@system.local",
|
||||
"user_metadata": {"name": "Health Check User"},
|
||||
}
|
||||
)
|
||||
|
||||
# Create and retrieve session to verify full data layer
|
||||
session = await create_chat_session(health_check_user_id)
|
||||
await get_chat_session(session.session_id, health_check_user_id)
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,18 +4,19 @@ from os import getenv
|
||||
import pytest
|
||||
|
||||
from . import service as chat_service
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import (
|
||||
StreamEnd,
|
||||
StreamError,
|
||||
StreamTextChunk,
|
||||
StreamToolExecutionResult,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion():
|
||||
async def test_stream_chat_completion(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
@@ -23,7 +24,7 @@ async def test_stream_chat_completion():
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await chat_service.create_chat_session()
|
||||
session = await create_chat_session(test_user_id)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
@@ -34,9 +35,9 @@ async def test_stream_chat_completion():
|
||||
logger.info(chunk)
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
if isinstance(chunk, StreamTextChunk):
|
||||
assistant_message += chunk.content
|
||||
if isinstance(chunk, StreamEnd):
|
||||
if isinstance(chunk, StreamTextDelta):
|
||||
assistant_message += chunk.delta
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
@@ -45,7 +46,7 @@ async def test_stream_chat_completion():
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_stream_chat_completion_with_tool_calls():
|
||||
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
|
||||
"""
|
||||
Test the stream_chat_completion function.
|
||||
"""
|
||||
@@ -53,8 +54,8 @@ async def test_stream_chat_completion_with_tool_calls():
|
||||
if not api_key:
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await chat_service.create_chat_session()
|
||||
session = await chat_service.upsert_chat_session(session)
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
@@ -68,14 +69,14 @@ async def test_stream_chat_completion_with_tool_calls():
|
||||
if isinstance(chunk, StreamError):
|
||||
has_errors = True
|
||||
|
||||
if isinstance(chunk, StreamEnd):
|
||||
if isinstance(chunk, StreamFinish):
|
||||
has_ended = True
|
||||
if isinstance(chunk, StreamToolExecutionResult):
|
||||
if isinstance(chunk, StreamToolOutputAvailable):
|
||||
had_tool_calls = True
|
||||
|
||||
assert has_ended, "Chat completion did not end"
|
||||
assert not has_errors, "Error occurred while streaming chat completion"
|
||||
assert had_tool_calls, "Tool calls did not occur"
|
||||
session = await chat_service.get_session(session.session_id)
|
||||
session = await get_chat_session(session.session_id)
|
||||
assert session, "Session not found"
|
||||
assert session.usage, "Usage is empty"
|
||||
|
||||
@@ -12,37 +12,36 @@ from .edit_agent import EditAgentTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
from .find_library_agent import FindLibraryAgentTool
|
||||
from .get_doc_page import GetDocPageTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .search_docs import SearchDocsTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
# Initialize tool instances
|
||||
add_understanding_tool = AddUnderstandingTool()
|
||||
create_agent_tool = CreateAgentTool()
|
||||
edit_agent_tool = EditAgentTool()
|
||||
find_agent_tool = FindAgentTool()
|
||||
find_block_tool = FindBlockTool()
|
||||
find_library_agent_tool = FindLibraryAgentTool()
|
||||
run_agent_tool = RunAgentTool()
|
||||
run_block_tool = RunBlockTool()
|
||||
search_docs_tool = SearchDocsTool()
|
||||
agent_output_tool = AgentOutputTool()
|
||||
# Single source of truth for all tools
|
||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"add_understanding": AddUnderstandingTool(),
|
||||
"create_agent": CreateAgentTool(),
|
||||
"edit_agent": EditAgentTool(),
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_block": FindBlockTool(),
|
||||
"find_library_agent": FindLibraryAgentTool(),
|
||||
"run_agent": RunAgentTool(),
|
||||
"run_block": RunBlockTool(),
|
||||
"agent_output": AgentOutputTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
}
|
||||
|
||||
# Export tools as OpenAI format
|
||||
# Export individual tool instances for backwards compatibility
|
||||
find_agent_tool = TOOL_REGISTRY["find_agent"]
|
||||
run_agent_tool = TOOL_REGISTRY["run_agent"]
|
||||
|
||||
# Generated from registry for OpenAI API
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
add_understanding_tool.as_openai_tool(),
|
||||
create_agent_tool.as_openai_tool(),
|
||||
edit_agent_tool.as_openai_tool(),
|
||||
find_agent_tool.as_openai_tool(),
|
||||
find_block_tool.as_openai_tool(),
|
||||
find_library_agent_tool.as_openai_tool(),
|
||||
run_agent_tool.as_openai_tool(),
|
||||
run_block_tool.as_openai_tool(),
|
||||
search_docs_tool.as_openai_tool(),
|
||||
agent_output_tool.as_openai_tool(),
|
||||
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
|
||||
]
|
||||
|
||||
|
||||
@@ -52,22 +51,9 @@ async def execute_tool(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolExecutionResult":
|
||||
|
||||
tool_map: dict[str, BaseTool] = {
|
||||
"add_understanding": add_understanding_tool,
|
||||
"create_agent": create_agent_tool,
|
||||
"edit_agent": edit_agent_tool,
|
||||
"find_agent": find_agent_tool,
|
||||
"find_block": find_block_tool,
|
||||
"find_library_agent": find_library_agent_tool,
|
||||
"run_agent": run_agent_tool,
|
||||
"run_block": run_block_tool,
|
||||
"search_platform_docs": search_docs_tool,
|
||||
"agent_output": agent_output_tool,
|
||||
}
|
||||
if tool_name not in tool_map:
|
||||
) -> "StreamToolOutputAvailable":
|
||||
"""Execute a tool by name."""
|
||||
tool = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
return await tool_map[tool_name].execute(
|
||||
user_id, session, tool_call_id, **parameters
|
||||
)
|
||||
return await tool.execute(user_id, session, tool_call_id, **parameters)
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
@@ -17,7 +18,7 @@ from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
|
||||
def make_session(user_id: str | None = None):
|
||||
def make_session(user_id: str):
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
@@ -49,13 +50,13 @@ async def setup_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create a test graph with agent input -> agent output
|
||||
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for LLM tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for LLM tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create test OpenAI credentials for the user
|
||||
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
|
||||
# 1b. Create a profile with username for the user (required for store agent lookup)
|
||||
username = user.email.split("@")[0]
|
||||
await prisma.profile.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"username": username,
|
||||
"name": f"Test User {username}",
|
||||
"description": "Test user profile for Firecrawl tests",
|
||||
"links": [], # Required field - empty array for test profiles
|
||||
}
|
||||
data=ProfileCreateInput(
|
||||
userId=user.id,
|
||||
username=username,
|
||||
name=f"Test User {username}",
|
||||
description="Test user profile for Firecrawl tests",
|
||||
links=[], # Required field - empty array for test profiles
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE: We deliberately do NOT create Firecrawl credentials for this user
|
||||
|
||||
@@ -10,11 +10,7 @@ from backend.data.understanding import (
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,80 +34,25 @@ and automations for the user's specific needs."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user_name": {
|
||||
"type": "string",
|
||||
"description": "The user's name",
|
||||
},
|
||||
"job_title": {
|
||||
"type": "string",
|
||||
"description": "The user's job title (e.g., 'Marketing Manager', 'CEO', 'Software Engineer')",
|
||||
},
|
||||
"business_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the user's business or organization",
|
||||
},
|
||||
"industry": {
|
||||
"type": "string",
|
||||
"description": "Industry or sector (e.g., 'e-commerce', 'healthcare', 'finance')",
|
||||
},
|
||||
"business_size": {
|
||||
"type": "string",
|
||||
"description": "Company size: '1-10', '11-50', '51-200', '201-1000', or '1000+'",
|
||||
},
|
||||
"user_role": {
|
||||
"type": "string",
|
||||
"description": "User's role in organization context (e.g., 'decision maker', 'implementer', 'end user')",
|
||||
},
|
||||
"key_workflows": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Key business workflows (e.g., 'lead qualification', 'content publishing')",
|
||||
},
|
||||
"daily_activities": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Regular daily activities the user performs",
|
||||
},
|
||||
"pain_points": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Current pain points or challenges",
|
||||
},
|
||||
"bottlenecks": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Process bottlenecks slowing things down",
|
||||
},
|
||||
"manual_tasks": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Manual or repetitive tasks that could be automated",
|
||||
},
|
||||
"automation_goals": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Desired automation outcomes or goals",
|
||||
},
|
||||
"current_software": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Software and tools currently in use",
|
||||
},
|
||||
"existing_automation": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Any existing automations or integrations",
|
||||
},
|
||||
"additional_notes": {
|
||||
"type": "string",
|
||||
"description": "Any other relevant context or notes",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
# Auto-generate from Pydantic model schema
|
||||
schema = BusinessUnderstandingInput.model_json_schema()
|
||||
properties = {}
|
||||
for field_name, field_schema in schema.get("properties", {}).items():
|
||||
prop: dict[str, Any] = {"description": field_schema.get("description", "")}
|
||||
# Handle anyOf for Optional types
|
||||
if "anyOf" in field_schema:
|
||||
for option in field_schema["anyOf"]:
|
||||
if option.get("type") != "null":
|
||||
prop["type"] = option.get("type", "string")
|
||||
if "items" in option:
|
||||
prop["items"] = option["items"]
|
||||
break
|
||||
else:
|
||||
prop["type"] = field_schema.get("type", "string")
|
||||
if "items" in field_schema:
|
||||
prop["items"] = field_schema["items"]
|
||||
properties[field_name] = prop
|
||||
return {"type": "object", "properties": properties, "required": []}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
@@ -146,54 +87,26 @@ and automations for the user's specific needs."""
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build input model
|
||||
# Build input model from kwargs (only include fields defined in the model)
|
||||
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
|
||||
input_data = BusinessUnderstandingInput(
|
||||
user_name=kwargs.get("user_name"),
|
||||
job_title=kwargs.get("job_title"),
|
||||
business_name=kwargs.get("business_name"),
|
||||
industry=kwargs.get("industry"),
|
||||
business_size=kwargs.get("business_size"),
|
||||
user_role=kwargs.get("user_role"),
|
||||
key_workflows=kwargs.get("key_workflows"),
|
||||
daily_activities=kwargs.get("daily_activities"),
|
||||
pain_points=kwargs.get("pain_points"),
|
||||
bottlenecks=kwargs.get("bottlenecks"),
|
||||
manual_tasks=kwargs.get("manual_tasks"),
|
||||
automation_goals=kwargs.get("automation_goals"),
|
||||
current_software=kwargs.get("current_software"),
|
||||
existing_automation=kwargs.get("existing_automation"),
|
||||
additional_notes=kwargs.get("additional_notes"),
|
||||
**{k: v for k, v in kwargs.items() if k in valid_fields}
|
||||
)
|
||||
|
||||
# Track which fields were updated
|
||||
updated_fields = [k for k, v in kwargs.items() if v is not None]
|
||||
updated_fields = [
|
||||
k for k, v in kwargs.items() if k in valid_fields and v is not None
|
||||
]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary for the response
|
||||
current_understanding = {
|
||||
"user_name": understanding.user_name,
|
||||
"job_title": understanding.job_title,
|
||||
"business_name": understanding.business_name,
|
||||
"industry": understanding.industry,
|
||||
"business_size": understanding.business_size,
|
||||
"user_role": understanding.user_role,
|
||||
"key_workflows": understanding.key_workflows,
|
||||
"daily_activities": understanding.daily_activities,
|
||||
"pain_points": understanding.pain_points,
|
||||
"bottlenecks": understanding.bottlenecks,
|
||||
"manual_tasks": understanding.manual_tasks,
|
||||
"automation_goals": understanding.automation_goals,
|
||||
"current_software": understanding.current_software,
|
||||
"existing_automation": understanding.existing_automation,
|
||||
"additional_notes": understanding.additional_notes,
|
||||
}
|
||||
|
||||
# Filter out empty values for cleaner response
|
||||
# Build current understanding summary (filter out empty values)
|
||||
current_understanding = {
|
||||
k: v
|
||||
for k, v in current_understanding.items()
|
||||
for k, v in understanding.model_dump(
|
||||
exclude={"id", "user_id", "created_at", "updated_at"}
|
||||
).items()
|
||||
if v is not None and v != [] and v != ""
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Configuration - use OPEN_ROUTER_API_KEY for consistency with chat/config.py
|
||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY") or os.getenv("OPENROUTER_API_KEY")
|
||||
OPENROUTER_API_KEY = os.getenv("OPEN_ROUTER_API_KEY")
|
||||
AGENT_GENERATOR_MODEL = os.getenv("AGENT_GENERATOR_MODEL", "anthropic/claude-opus-4.5")
|
||||
|
||||
# OpenRouter client (OpenAI-compatible API)
|
||||
|
||||
@@ -55,56 +55,47 @@ def parse_time_expression(
|
||||
"""
|
||||
Parse time expression into datetime range (start, end).
|
||||
|
||||
Supports:
|
||||
- "latest" or None -> returns (None, None) to get most recent
|
||||
- "yesterday" -> 24h window for yesterday
|
||||
- "today" -> Today from midnight
|
||||
- "last week" / "last 7 days" -> 7 day window
|
||||
- "last month" / "last 30 days" -> 30 day window
|
||||
- ISO date "YYYY-MM-DD" -> 24h window for that date
|
||||
Supports: "latest", "yesterday", "today", "last week", "last 7 days",
|
||||
"last month", "last 30 days", ISO date "YYYY-MM-DD", ISO datetime.
|
||||
"""
|
||||
if not time_expr or time_expr.lower() == "latest":
|
||||
return None, None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
expr = time_expr.lower().strip()
|
||||
|
||||
# Relative expressions
|
||||
if expr == "yesterday":
|
||||
end = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
start = end - timedelta(days=1)
|
||||
return start, end
|
||||
|
||||
if expr in ("last week", "last 7 days"):
|
||||
return now - timedelta(days=7), now
|
||||
|
||||
if expr in ("last month", "last 30 days"):
|
||||
return now - timedelta(days=30), now
|
||||
|
||||
if expr == "today":
|
||||
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return start, now
|
||||
# Relative time expressions lookup
|
||||
relative_times: dict[str, tuple[datetime, datetime]] = {
|
||||
"yesterday": (today_start - timedelta(days=1), today_start),
|
||||
"today": (today_start, now),
|
||||
"last week": (now - timedelta(days=7), now),
|
||||
"last 7 days": (now - timedelta(days=7), now),
|
||||
"last month": (now - timedelta(days=30), now),
|
||||
"last 30 days": (now - timedelta(days=30), now),
|
||||
}
|
||||
if expr in relative_times:
|
||||
return relative_times[expr]
|
||||
|
||||
# Try ISO date format (YYYY-MM-DD)
|
||||
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
|
||||
if date_match:
|
||||
year, month, day = map(int, date_match.groups())
|
||||
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
try:
|
||||
year, month, day = map(int, date_match.groups())
|
||||
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
|
||||
return start, start + timedelta(days=1)
|
||||
except ValueError:
|
||||
# Invalid date components (e.g., month=13, day=32)
|
||||
pass
|
||||
|
||||
# Try ISO datetime
|
||||
try:
|
||||
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
# Return +/- 1 hour window around the specified time
|
||||
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Fallback: treat as "latest"
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
|
||||
class AgentOutputTool(BaseTool):
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .models import (
|
||||
AgentInfo,
|
||||
AgentsFoundResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchSource = Literal["marketplace", "library"]
|
||||
|
||||
|
||||
async def search_agents(
|
||||
query: str,
|
||||
source: SearchSource,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Search for agents in marketplace or user library.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
source: "marketplace" or "library"
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (required for library search)
|
||||
|
||||
Returns:
|
||||
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
|
||||
"""
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query", session_id=session_id
|
||||
)
|
||||
|
||||
if source == "library" and not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if source == "marketplace":
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
)
|
||||
else: # library
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
)
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching {source}: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to search {source}. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
suggestions = (
|
||||
[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
]
|
||||
if source == "marketplace"
|
||||
else [
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
]
|
||||
)
|
||||
no_results_msg = (
|
||||
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
||||
if source == "marketplace"
|
||||
else f"No agents matching '{query}' found in your library."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
)
|
||||
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||
title += (
|
||||
f"for '{query}'"
|
||||
if source == "marketplace"
|
||||
else f"in your library for '{query}'"
|
||||
)
|
||||
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=message,
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolExecutionResult
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
@@ -53,7 +53,7 @@ class BaseTool:
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
**kwargs,
|
||||
) -> StreamToolExecutionResult:
|
||||
) -> StreamToolOutputAvailable:
|
||||
"""Execute the tool with authentication check.
|
||||
|
||||
Args:
|
||||
@@ -69,10 +69,10 @@ class BaseTool:
|
||||
logger.error(
|
||||
f"Attempted tool call for {self.name} but user not authenticated"
|
||||
)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=NeedLoginResponse(
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=NeedLoginResponse(
|
||||
message=f"Please sign in to use {self.name}",
|
||||
session_id=session.session_id,
|
||||
).model_dump_json(),
|
||||
@@ -81,17 +81,17 @@ class BaseTool:
|
||||
|
||||
try:
|
||||
result = await self._execute(user_id, session, **kwargs)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=result.model_dump_json(),
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=result.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
result=ErrorResponse(
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=self.name,
|
||||
output=ErrorResponse(
|
||||
message=f"An error occurred while executing {self.name}",
|
||||
error=str(e),
|
||||
session_id=session.session_id,
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,26 +1,16 @@
|
||||
"""Tool for discovering agents from marketplace and user library."""
|
||||
"""Tool for discovering agents from marketplace."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindAgentTool(BaseTool):
|
||||
"""Tool for discovering agents based on user needs."""
|
||||
"""Tool for discovering agents from the marketplace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -46,84 +36,11 @@ class FindAgentTool(BaseTool):
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Search for agents in the marketplace.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous)
|
||||
session_id: Chat session ID
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
AgentCarouselResponse: List of agents found in the marketplace
|
||||
NoResultsResponse: No agents found in the marketplace
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
agents = []
|
||||
try:
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
store_results = await store_db.get_store_agents(
|
||||
search_query=query,
|
||||
page_size=5,
|
||||
)
|
||||
|
||||
logger.info(f"Find agents tool found {len(store_results.agents)} agents")
|
||||
for agent in store_results.agents:
|
||||
agent_id = f"{agent.creator}/{agent.slug}"
|
||||
logger.info(f"Building agent ID = {agent_id}")
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent_id,
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
),
|
||||
)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching agents: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search for agents. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
if not agents:
|
||||
return NoResultsResponse(
|
||||
message=f"No agents found matching '{query}'. Try different keywords or browse the marketplace. If you have 3 consecutive find_agent tool calls results and found no agents. Please stop trying and ask the user if there is anything else you can help with.",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
],
|
||||
)
|
||||
|
||||
# Return formatted carousel
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
|
||||
)
|
||||
return AgentCarouselResponse(
|
||||
message="Now you have found some options for the user to choose from. You can add a link to a recommended agent at: /marketplace/agent/agent_id Please ask the user if they would like to use any of these agents. If they do, please call the get_agent_details tool for this agent.",
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
return await search_agents(
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="marketplace",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
"""Tool for searching available blocks using hybrid search."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.blocks import load_all_blocks
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||
from backend.api.features.chat.tools.models import (
|
||||
BlockInfoSummary,
|
||||
BlockInputFieldInfo,
|
||||
BlockListResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .search_blocks import get_block_search_index
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
from backend.data.block import get_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,7 +31,8 @@ class FindBlockTool(BaseTool):
|
||||
"Search for available blocks by name or description. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"Use this to find blocks that can be executed directly."
|
||||
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
|
||||
"The response includes each block's id, required_inputs, and input_schema."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -55,39 +55,6 @@ class FindBlockTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _matches_query(self, block, query: str) -> tuple[int, bool]:
|
||||
"""
|
||||
Check if a block matches the query and return a priority score.
|
||||
|
||||
Returns (priority, matches) where:
|
||||
- priority 0: exact name match
|
||||
- priority 1: name contains query
|
||||
- priority 2: description contains query
|
||||
- priority 3: category contains query
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
name_lower = block.name.lower()
|
||||
desc_lower = block.description.lower()
|
||||
|
||||
# Exact name match
|
||||
if query_lower == name_lower:
|
||||
return 0, True
|
||||
|
||||
# Name contains query
|
||||
if query_lower in name_lower:
|
||||
return 1, True
|
||||
|
||||
# Description contains query
|
||||
if query_lower in desc_lower:
|
||||
return 2, True
|
||||
|
||||
# Category contains query
|
||||
for category in block.categories:
|
||||
if query_lower in category.name.lower():
|
||||
return 3, True
|
||||
|
||||
return 4, False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
@@ -116,138 +83,110 @@ class FindBlockTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
# Try hybrid search first
|
||||
search_results = self._hybrid_search(query)
|
||||
# Search for blocks using hybrid search
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
if search_results is not None:
|
||||
# Hybrid search succeeded
|
||||
if not search_results:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found matching '{query}'",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Search by category: ai, text, social, search, etc.",
|
||||
"Check block names like 'SendEmail', 'HttpRequest', etc.",
|
||||
],
|
||||
)
|
||||
|
||||
# Get full block info for each result
|
||||
all_blocks = load_all_blocks()
|
||||
blocks = []
|
||||
for result in search_results:
|
||||
block_cls = all_blocks.get(result.block_id)
|
||||
if block_cls:
|
||||
block = block_cls()
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=block.description,
|
||||
categories=[cat.name for cat in block.categories],
|
||||
input_schema=block.input_schema.jsonschema(),
|
||||
output_schema=block.output_schema.jsonschema(),
|
||||
)
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
|
||||
f"matching '{query}'. Use run_block to execute a block with "
|
||||
"the required inputs."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
"Check spelling of technical terms",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Fallback to simple search if hybrid search failed
|
||||
return self._simple_search(query, session_id)
|
||||
# Enrich results with full block information
|
||||
blocks: list[BlockInfoSummary] = []
|
||||
for result in results:
|
||||
block_id = result["content_id"]
|
||||
block = get_block(block_id)
|
||||
|
||||
if block:
|
||||
# Get input/output schemas
|
||||
input_schema = {}
|
||||
output_schema = {}
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get categories from block instance
|
||||
categories = []
|
||||
if hasattr(block, "categories") and block.categories:
|
||||
categories = [cat.value for cat in block.categories]
|
||||
|
||||
# Extract required inputs for easier use
|
||||
required_inputs: list[BlockInputFieldInfo] = []
|
||||
if input_schema:
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = set(input_schema.get("required", []))
|
||||
# Get credential field names to exclude from required inputs
|
||||
credentials_fields = set(
|
||||
block.input_schema.get_credentials_fields().keys()
|
||||
)
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
# Skip credential fields - they're handled separately
|
||||
if field_name in credentials_fields:
|
||||
continue
|
||||
|
||||
required_inputs.append(
|
||||
BlockInputFieldInfo(
|
||||
name=field_name,
|
||||
type=field_schema.get("type", "string"),
|
||||
description=field_schema.get("description", ""),
|
||||
required=field_name in required_fields,
|
||||
default=field_schema.get("default"),
|
||||
)
|
||||
)
|
||||
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block_id,
|
||||
name=block.name,
|
||||
description=block.description or "",
|
||||
categories=categories,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
required_inputs=required_inputs,
|
||||
)
|
||||
)
|
||||
|
||||
if not blocks:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found for '{query}'",
|
||||
suggestions=[
|
||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block(s) matching '{query}'. "
|
||||
"To execute a block, use run_block with the block's 'id' field "
|
||||
"and provide 'input_data' matching the block's input_schema."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching blocks: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search blocks. Please try again.",
|
||||
message="Failed to search blocks",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
def _hybrid_search(self, query: str) -> list | None:
|
||||
"""
|
||||
Perform hybrid search using embeddings and BM25.
|
||||
|
||||
Returns:
|
||||
List of BlockSearchResult if successful, None if index not available
|
||||
"""
|
||||
try:
|
||||
index = get_block_search_index()
|
||||
if not index.load():
|
||||
logger.info(
|
||||
"Block search index not available, falling back to simple search"
|
||||
)
|
||||
return None
|
||||
|
||||
results = index.search(query, top_k=10)
|
||||
logger.info(f"Hybrid search found {len(results)} blocks for: {query}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Hybrid search failed, falling back to simple: {e}")
|
||||
return None
|
||||
|
||||
def _simple_search(self, query: str, session_id: str) -> ToolResponseBase:
|
||||
"""Fallback simple search using substring matching."""
|
||||
all_blocks = load_all_blocks()
|
||||
logger.info(f"Simple searching {len(all_blocks)} blocks for: {query}")
|
||||
|
||||
# Find matching blocks with priority scores
|
||||
matches: list[tuple[int, Any]] = []
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
block = block_cls()
|
||||
priority, is_match = self._matches_query(block, query)
|
||||
if is_match:
|
||||
matches.append((priority, block))
|
||||
|
||||
# Sort by priority (lower is better)
|
||||
matches.sort(key=lambda x: x[0])
|
||||
|
||||
# Take top 10 results
|
||||
top_matches = [block for _, block in matches[:10]]
|
||||
|
||||
if not top_matches:
|
||||
return NoResultsResponse(
|
||||
message=f"No blocks found matching '{query}'",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Search by category: ai, text, social, search, etc.",
|
||||
"Check block names like 'SendEmail', 'HttpRequest', etc.",
|
||||
],
|
||||
)
|
||||
|
||||
# Build response
|
||||
blocks = []
|
||||
for block in top_matches:
|
||||
blocks.append(
|
||||
BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=block.description,
|
||||
categories=[cat.name for cat in block.categories],
|
||||
input_schema=block.input_schema.jsonschema(),
|
||||
output_schema=block.output_schema.jsonschema(),
|
||||
)
|
||||
)
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found {len(blocks)} block{'s' if len(blocks) != 1 else ''} "
|
||||
f"matching '{query}'. Use run_block to execute a block with "
|
||||
"the required inputs."
|
||||
),
|
||||
blocks=blocks,
|
||||
count=len(blocks),
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,22 +1,12 @@
|
||||
"""Tool for searching agents in the user's library."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.util.exceptions import DatabaseError
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .models import ToolResponseBase
|
||||
|
||||
|
||||
class FindLibraryAgentTool(BaseTool):
|
||||
@@ -41,10 +31,7 @@ class FindLibraryAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find agents by name or description. "
|
||||
"Use keywords for best results."
|
||||
),
|
||||
"description": "Search query to find agents by name or description.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
@@ -55,103 +42,11 @@ class FindLibraryAgentTool(BaseTool):
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
self, user_id: str | None, session: ChatSession, **kwargs
|
||||
) -> ToolResponseBase:
|
||||
"""Search for agents in the user's library.
|
||||
|
||||
Args:
|
||||
user_id: User ID (required)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
AgentCarouselResponse: List of agents found in the library
|
||||
NoResultsResponse: No agents found
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents = []
|
||||
try:
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
library_results = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Find library agents tool found {len(library_results.agents)} agents"
|
||||
)
|
||||
|
||||
for agent in library_results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
),
|
||||
)
|
||||
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching library agents: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search library. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents found matching '{query}' in your library. "
|
||||
"Try different keywords or use find_agent to search the marketplace."
|
||||
),
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
],
|
||||
)
|
||||
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
|
||||
f"in your library for '{query}'"
|
||||
)
|
||||
|
||||
return AgentCarouselResponse(
|
||||
message=(
|
||||
"Found agents in the user's library. You can provide a link to "
|
||||
"view an agent at: /library/agents/{agent_id}. "
|
||||
"Use agent_output to get execution results, or run_agent to execute."
|
||||
),
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
return await search_agents(
|
||||
query=kwargs.get("query", "").strip(),
|
||||
source="library",
|
||||
session_id=session.session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
"""GetDocPageTool - Fetch full content of a documentation page."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocPageResponse,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Base URL for documentation (can be configured)
|
||||
DOCS_BASE_URL = "https://docs.agpt.co"
|
||||
|
||||
|
||||
class GetDocPageTool(BaseTool):
|
||||
"""Tool for fetching full content of a documentation page."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_doc_page"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Get the full content of a documentation page by its path. "
|
||||
"Use this after search_docs to read the complete content of a relevant page."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The path to the documentation file, as returned by search_docs. "
|
||||
"Example: 'platform/block-sdk-guide.md'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False # Documentation is public
|
||||
|
||||
def _get_docs_root(self) -> Path:
|
||||
"""Get the documentation root directory."""
|
||||
this_file = Path(__file__)
|
||||
project_root = this_file.parent.parent.parent.parent.parent.parent.parent.parent
|
||||
return project_root / "docs"
|
||||
|
||||
def _extract_title(self, content: str, fallback: str) -> str:
|
||||
"""Extract title from markdown content."""
|
||||
lines = content.split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("# "):
|
||||
return line[2:].strip()
|
||||
return fallback
|
||||
|
||||
def _make_doc_url(self, path: str) -> str:
|
||||
"""Create a URL for a documentation page."""
|
||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||
return f"{DOCS_BASE_URL}/{url_path}"
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Fetch full content of a documentation page.
|
||||
|
||||
Args:
|
||||
user_id: User ID (not required for docs)
|
||||
session: Chat session
|
||||
path: Path to the documentation file
|
||||
|
||||
Returns:
|
||||
DocPageResponse: Full document content
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
path = kwargs.get("path", "").strip()
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide a documentation path.",
|
||||
error="Missing path parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Sanitize path to prevent directory traversal
|
||||
if ".." in path or path.startswith("/"):
|
||||
return ErrorResponse(
|
||||
message="Invalid documentation path.",
|
||||
error="invalid_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
docs_root = self._get_docs_root()
|
||||
full_path = docs_root / path
|
||||
|
||||
if not full_path.exists():
|
||||
return ErrorResponse(
|
||||
message=f"Documentation page not found: {path}",
|
||||
error="not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Ensure the path is within docs root
|
||||
try:
|
||||
full_path.resolve().relative_to(docs_root.resolve())
|
||||
except ValueError:
|
||||
return ErrorResponse(
|
||||
message="Invalid documentation path.",
|
||||
error="invalid_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
content = full_path.read_text(encoding="utf-8")
|
||||
title = self._extract_title(content, path)
|
||||
|
||||
return DocPageResponse(
|
||||
message=f"Retrieved documentation page: {title}",
|
||||
title=title,
|
||||
path=path,
|
||||
content=content,
|
||||
doc_url=self._make_doc_url(path),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read documentation page {path}: {e}")
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read documentation page: {str(e)}",
|
||||
error="read_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -1,483 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Block Indexer for Hybrid Search
|
||||
|
||||
Creates a hybrid search index from blocks:
|
||||
- OpenAI embeddings (text-embedding-3-small)
|
||||
- BM25 index for lexical search
|
||||
- Name index for title matching boost
|
||||
|
||||
Supports incremental updates by tracking content hashes.
|
||||
|
||||
Usage:
|
||||
python -m backend.server.v2.chat.tools.index_blocks [--force]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check for OpenAI availability
|
||||
try:
|
||||
import openai # noqa: F401
|
||||
|
||||
HAS_OPENAI = True
|
||||
except ImportError:
|
||||
HAS_OPENAI = False
|
||||
print("Warning: openai not installed. Run: pip install openai")
|
||||
|
||||
# Default embedding model (OpenAI)
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_EMBEDDING_DIM = 1536
|
||||
|
||||
# Output path (relative to this file)
|
||||
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
|
||||
|
||||
# Stopwords for tokenization
|
||||
STOPWORDS = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
"block", # Too common in block context
|
||||
}
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25."""
|
||||
text = text.lower()
|
||||
# Remove code blocks if any
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Extract words (including camelCase split)
|
||||
# First, split camelCase
|
||||
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
|
||||
|
||||
|
||||
def build_searchable_text(block: Any) -> str:
|
||||
"""Build searchable text from block attributes."""
|
||||
parts = []
|
||||
|
||||
# Block name (split camelCase for better tokenization)
|
||||
name = block.name
|
||||
# Split camelCase: GetCurrentTimeBlock -> Get Current Time Block
|
||||
name_split = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
|
||||
parts.append(name_split)
|
||||
|
||||
# Description
|
||||
if block.description:
|
||||
parts.append(block.description)
|
||||
|
||||
# Categories
|
||||
for category in block.categories:
|
||||
parts.append(category.name)
|
||||
|
||||
# Input schema field names and descriptions
|
||||
try:
|
||||
input_schema = block.input_schema.jsonschema()
|
||||
if "properties" in input_schema:
|
||||
for field_name, field_info in input_schema["properties"].items():
|
||||
parts.append(field_name)
|
||||
if "description" in field_info:
|
||||
parts.append(field_info["description"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Output schema field names
|
||||
try:
|
||||
output_schema = block.output_schema.jsonschema()
|
||||
if "properties" in output_schema:
|
||||
for field_name in output_schema["properties"]:
|
||||
parts.append(field_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def compute_content_hash(text: str) -> str:
|
||||
"""Compute MD5 hash of text for change detection."""
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
|
||||
def load_existing_index(index_path: Path) -> dict[str, Any] | None:
|
||||
"""Load existing index if present."""
|
||||
if not index_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(index_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_embeddings(
|
||||
texts: list[str],
|
||||
model_name: str = DEFAULT_EMBEDDING_MODEL,
|
||||
batch_size: int = 100,
|
||||
) -> np.ndarray:
|
||||
"""Create embeddings using OpenAI API."""
|
||||
if not HAS_OPENAI:
|
||||
raise RuntimeError("openai not installed. Run: pip install openai")
|
||||
|
||||
# Import here to satisfy type checker
|
||||
from openai import OpenAI
|
||||
|
||||
# Check for API key
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = OpenAI(api_key=api_key)
|
||||
embeddings = []
|
||||
|
||||
print(f"Creating embeddings for {len(texts)} texts using {model_name}...")
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
# Truncate texts to max token limit (8191 tokens for text-embedding-3-small)
|
||||
# Roughly 4 chars per token, so ~32000 chars max
|
||||
batch = [text[:32000] for text in batch]
|
||||
|
||||
response = client.embeddings.create(
|
||||
model=model_name,
|
||||
input=batch,
|
||||
)
|
||||
|
||||
for embedding_data in response.data:
|
||||
embeddings.append(embedding_data.embedding)
|
||||
|
||||
print(f" Processed {min(i + batch_size, len(texts))}/{len(texts)} texts")
|
||||
|
||||
return np.array(embeddings, dtype=np.float32)
|
||||
|
||||
|
||||
def build_bm25_data(
|
||||
blocks_data: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""Build BM25 metadata from block data."""
|
||||
# Tokenize all searchable texts
|
||||
tokenized_docs = []
|
||||
for block in blocks_data:
|
||||
tokens = tokenize(block["searchable_text"])
|
||||
tokenized_docs.append(tokens)
|
||||
|
||||
# Calculate document frequencies
|
||||
doc_freq: dict[str, int] = {}
|
||||
for tokens in tokenized_docs:
|
||||
seen = set()
|
||||
for token in tokens:
|
||||
if token not in seen:
|
||||
doc_freq[token] = doc_freq.get(token, 0) + 1
|
||||
seen.add(token)
|
||||
|
||||
n_docs = len(tokenized_docs)
|
||||
doc_lens = [len(d) for d in tokenized_docs]
|
||||
avgdl = sum(doc_lens) / max(n_docs, 1)
|
||||
|
||||
return {
|
||||
"n_docs": n_docs,
|
||||
"avgdl": avgdl,
|
||||
"df": doc_freq,
|
||||
"doc_lens": doc_lens,
|
||||
}
|
||||
|
||||
|
||||
def build_name_index(
|
||||
blocks_data: list[dict[str, Any]],
|
||||
) -> dict[str, list[list[int | float]]]:
|
||||
"""Build inverted index for name search boost."""
|
||||
index: dict[str, list[list[int | float]]] = defaultdict(list)
|
||||
|
||||
for idx, block in enumerate(blocks_data):
|
||||
# Tokenize block name
|
||||
name_tokens = tokenize(block["name"])
|
||||
seen = set()
|
||||
|
||||
for i, token in enumerate(name_tokens):
|
||||
if token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
|
||||
# Score: first token gets higher weight
|
||||
score = 1.5 if i == 0 else 1.0
|
||||
index[token].append([idx, score])
|
||||
|
||||
return dict(index)
|
||||
|
||||
|
||||
def build_block_index(
|
||||
force_rebuild: bool = False,
|
||||
output_path: Path = INDEX_PATH,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build the block search index.
|
||||
|
||||
Args:
|
||||
force_rebuild: If True, rebuild all embeddings even if unchanged
|
||||
output_path: Path to save the index
|
||||
|
||||
Returns:
|
||||
The generated index dictionary
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
print("Loading all blocks...")
|
||||
all_blocks = load_all_blocks()
|
||||
print(f"Found {len(all_blocks)} blocks")
|
||||
|
||||
# Load existing index for incremental updates
|
||||
existing_index = None if force_rebuild else load_existing_index(output_path)
|
||||
existing_blocks: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if existing_index:
|
||||
print(
|
||||
f"Loaded existing index with {len(existing_index.get('blocks', []))} blocks"
|
||||
)
|
||||
for block in existing_index.get("blocks", []):
|
||||
existing_blocks[block["id"]] = block
|
||||
|
||||
# Process each block
|
||||
blocks_data: list[dict[str, Any]] = []
|
||||
blocks_needing_embedding: list[tuple[int, str]] = [] # (index, searchable_text)
|
||||
|
||||
for block_id, block_cls in all_blocks.items():
|
||||
try:
|
||||
block = block_cls()
|
||||
|
||||
# Skip disabled blocks
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
searchable_text = build_searchable_text(block)
|
||||
content_hash = compute_content_hash(searchable_text)
|
||||
|
||||
block_data = {
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"description": block.description,
|
||||
"categories": [cat.name for cat in block.categories],
|
||||
"searchable_text": searchable_text,
|
||||
"content_hash": content_hash,
|
||||
"emb": None, # Will be filled later
|
||||
}
|
||||
|
||||
# Check if we can reuse existing embedding
|
||||
if (
|
||||
block.id in existing_blocks
|
||||
and existing_blocks[block.id].get("content_hash") == content_hash
|
||||
and existing_blocks[block.id].get("emb")
|
||||
):
|
||||
# Reuse existing embedding
|
||||
block_data["emb"] = existing_blocks[block.id]["emb"]
|
||||
else:
|
||||
# Need new embedding
|
||||
blocks_needing_embedding.append((len(blocks_data), searchable_text))
|
||||
|
||||
blocks_data.append(block_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Processed {len(blocks_data)} blocks")
|
||||
print(f"Blocks needing new embeddings: {len(blocks_needing_embedding)}")
|
||||
|
||||
# Create embeddings for new/changed blocks
|
||||
if blocks_needing_embedding and HAS_OPENAI:
|
||||
texts_to_embed = [text for _, text in blocks_needing_embedding]
|
||||
try:
|
||||
embeddings = create_embeddings(texts_to_embed)
|
||||
|
||||
# Assign embeddings to blocks
|
||||
for i, (block_idx, _) in enumerate(blocks_needing_embedding):
|
||||
emb = embeddings[i].astype(np.float32)
|
||||
# Encode as base64
|
||||
blocks_data[block_idx]["emb"] = base64.b64encode(emb.tobytes()).decode(
|
||||
"ascii"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to create embeddings: {e}")
|
||||
elif blocks_needing_embedding:
|
||||
print(
|
||||
"Warning: Cannot create embeddings (openai not installed or OPENAI_API_KEY not set)"
|
||||
)
|
||||
|
||||
# Build BM25 data
|
||||
print("Building BM25 index...")
|
||||
bm25_data = build_bm25_data(blocks_data)
|
||||
|
||||
# Build name index
|
||||
print("Building name index...")
|
||||
name_index = build_name_index(blocks_data)
|
||||
|
||||
# Build final index
|
||||
index = {
|
||||
"version": "1.0.0",
|
||||
"embedding_model": DEFAULT_EMBEDDING_MODEL,
|
||||
"embedding_dim": DEFAULT_EMBEDDING_DIM,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"blocks": blocks_data,
|
||||
"bm25": bm25_data,
|
||||
"name_index": name_index,
|
||||
}
|
||||
|
||||
# Save index
|
||||
print(f"Saving index to {output_path}...")
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, separators=(",", ":"))
|
||||
|
||||
size_kb = output_path.stat().st_size / 1024
|
||||
print(f"Index saved ({size_kb:.1f} KB)")
|
||||
|
||||
# Print statistics
|
||||
print("\nIndex Statistics:")
|
||||
print(f" Blocks indexed: {len(blocks_data)}")
|
||||
print(f" BM25 vocabulary size: {len(bm25_data['df'])}")
|
||||
print(f" Name index terms: {len(name_index)}")
|
||||
print(f" Embeddings: {'Yes' if any(b.get('emb') for b in blocks_data) else 'No'}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Build hybrid search index for blocks")
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Force rebuild all embeddings even if unchanged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=INDEX_PATH,
|
||||
help=f"Output index file path (default: {INDEX_PATH})",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
build_block_index(
|
||||
force_rebuild=args.force,
|
||||
output_path=args.output,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error building index: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -12,23 +12,22 @@ from backend.data.model import CredentialsMetaInput
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of tool responses."""
|
||||
|
||||
AGENT_CAROUSEL = "agent_carousel"
|
||||
AGENTS_FOUND = "agents_found"
|
||||
AGENT_DETAILS = "agent_details"
|
||||
SETUP_REQUIREMENTS = "setup_requirements"
|
||||
EXECUTION_STARTED = "execution_started"
|
||||
NEED_LOGIN = "need_login"
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
SUCCESS = "success"
|
||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||
AGENT_OUTPUT = "agent_output"
|
||||
BLOCK_LIST = "block_list"
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
UNDERSTANDING_UPDATED = "understanding_updated"
|
||||
# Agent generation responses
|
||||
AGENT_PREVIEW = "agent_preview"
|
||||
AGENT_SAVED = "agent_saved"
|
||||
CLARIFICATION_NEEDED = "clarification_needed"
|
||||
BLOCK_LIST = "block_list"
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||
DOC_PAGE = "doc_page"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -61,14 +60,14 @@ class AgentInfo(BaseModel):
|
||||
graph_id: str | None = None
|
||||
|
||||
|
||||
class AgentCarouselResponse(ToolResponseBase):
|
||||
class AgentsFoundResponse(ToolResponseBase):
|
||||
"""Response for find_agent tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_CAROUSEL
|
||||
type: ResponseType = ResponseType.AGENTS_FOUND
|
||||
title: str = "Available Agents"
|
||||
agents: list[AgentInfo]
|
||||
count: int
|
||||
name: str = "agent_carousel"
|
||||
name: str = "agents_found"
|
||||
|
||||
|
||||
class NoResultsResponse(ToolResponseBase):
|
||||
@@ -185,28 +184,6 @@ class ErrorResponse(ToolResponseBase):
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# Documentation search models
|
||||
class DocSearchResult(BaseModel):
|
||||
"""A single documentation search result."""
|
||||
|
||||
title: str
|
||||
path: str
|
||||
section: str
|
||||
snippet: str # Short excerpt for UI display
|
||||
content: str # Full text content for LLM to read and understand
|
||||
score: float
|
||||
doc_url: str | None = None
|
||||
|
||||
|
||||
class DocSearchResultsResponse(ToolResponseBase):
|
||||
"""Response for search_docs tool."""
|
||||
|
||||
type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
|
||||
results: list[DocSearchResult]
|
||||
count: int
|
||||
query: str
|
||||
|
||||
|
||||
# Agent output models
|
||||
class ExecutionOutputInfo(BaseModel):
|
||||
"""Summary of a single execution's outputs."""
|
||||
@@ -232,37 +209,6 @@ class AgentOutputResponse(ToolResponseBase):
|
||||
total_executions: int = 0
|
||||
|
||||
|
||||
# Block models
|
||||
class BlockInfoSummary(BaseModel):
|
||||
"""Summary of a block for search results."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
|
||||
|
||||
class BlockListResponse(ToolResponseBase):
|
||||
"""Response for find_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_LIST
|
||||
blocks: list[BlockInfoSummary]
|
||||
count: int
|
||||
query: str
|
||||
|
||||
|
||||
class BlockOutputResponse(ToolResponseBase):
|
||||
"""Response for run_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_OUTPUT
|
||||
block_id: str
|
||||
block_name: str
|
||||
outputs: dict[str, list[Any]]
|
||||
success: bool = True
|
||||
|
||||
|
||||
# Business understanding models
|
||||
class UnderstandingUpdatedResponse(ToolResponseBase):
|
||||
"""Response for add_understanding tool."""
|
||||
@@ -308,3 +254,83 @@ class ClarificationNeededResponse(ToolResponseBase):
|
||||
|
||||
type: ResponseType = ResponseType.CLARIFICATION_NEEDED
|
||||
questions: list[ClarifyingQuestion] = Field(default_factory=list)
|
||||
|
||||
|
||||
# Documentation search models
|
||||
class DocSearchResult(BaseModel):
|
||||
"""A single documentation search result."""
|
||||
|
||||
title: str
|
||||
path: str
|
||||
section: str
|
||||
snippet: str # Short excerpt for UI display
|
||||
score: float
|
||||
doc_url: str | None = None
|
||||
|
||||
|
||||
class DocSearchResultsResponse(ToolResponseBase):
|
||||
"""Response for search_docs tool."""
|
||||
|
||||
type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
|
||||
results: list[DocSearchResult]
|
||||
count: int
|
||||
query: str
|
||||
|
||||
|
||||
class DocPageResponse(ToolResponseBase):
|
||||
"""Response for get_doc_page tool."""
|
||||
|
||||
type: ResponseType = ResponseType.DOC_PAGE
|
||||
title: str
|
||||
path: str
|
||||
content: str # Full document content
|
||||
doc_url: str | None = None
|
||||
|
||||
|
||||
# Block models
|
||||
class BlockInputFieldInfo(BaseModel):
|
||||
"""Information about a block input field."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
description: str = ""
|
||||
required: bool = False
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
class BlockInfoSummary(BaseModel):
|
||||
"""Summary of a block for search results."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
||||
default_factory=list,
|
||||
description="List of required input fields for this block",
|
||||
)
|
||||
|
||||
|
||||
class BlockListResponse(ToolResponseBase):
|
||||
"""Response for find_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_LIST
|
||||
blocks: list[BlockInfoSummary]
|
||||
count: int
|
||||
query: str
|
||||
usage_hint: str = Field(
|
||||
default="To execute a block, call run_block with block_id set to the block's "
|
||||
"'id' field and input_data containing the required fields from input_schema."
|
||||
)
|
||||
|
||||
|
||||
class BlockOutputResponse(ToolResponseBase):
|
||||
"""Response for run_block tool."""
|
||||
|
||||
type: ResponseType = ResponseType.BLOCK_OUTPUT
|
||||
block_id: str
|
||||
block_name: str
|
||||
outputs: dict[str, list[Any]]
|
||||
success: bool = True
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
@@ -17,6 +18,17 @@ setup_test_data = setup_test_data
|
||||
setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def mock_embedding_functions():
|
||||
"""Mock embedding functions for all tests to avoid database/API dependencies."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.ensure_embedding",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_run_agent(setup_test_data):
|
||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||
@@ -46,11 +58,11 @@ async def test_run_agent(setup_test_data):
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
# Parse the result JSON to verify the execution started
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "execution_id" in result_data
|
||||
assert "graph_id" in result_data
|
||||
assert result_data["graph_id"] == graph.id
|
||||
@@ -86,11 +98,11 @@ async def test_run_agent_missing_inputs(setup_test_data):
|
||||
|
||||
# Verify that we get an error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
# The tool should return an ErrorResponse when setup info indicates not ready
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "message" in result_data
|
||||
|
||||
|
||||
@@ -118,10 +130,10 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
|
||||
|
||||
# Verify that we get an error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
assert "message" in result_data
|
||||
# Should get an error about failed setup or not found
|
||||
assert any(
|
||||
@@ -158,12 +170,12 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert hasattr(response, "output")
|
||||
|
||||
# Parse the result JSON to verify the execution started
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should successfully start execution since credentials are available
|
||||
assert "execution_id" in result_data
|
||||
@@ -195,9 +207,9 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return agent_details type showing available inputs
|
||||
assert result_data.get("type") == "agent_details"
|
||||
@@ -230,9 +242,9 @@ async def test_run_agent_with_use_defaults(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should execute successfully
|
||||
assert "execution_id" in result_data
|
||||
@@ -260,9 +272,9 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return setup_requirements type with missing credentials
|
||||
assert result_data.get("type") == "setup_requirements"
|
||||
@@ -292,9 +304,9 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error
|
||||
assert result_data.get("type") == "error"
|
||||
@@ -305,9 +317,10 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
|
||||
async def test_run_agent_unauthenticated():
|
||||
"""Test that run_agent returns need_login for unauthenticated users."""
|
||||
tool = RunAgentTool()
|
||||
session = make_session(user_id=None)
|
||||
# Session has a user_id (session owner), but we test tool execution without user_id
|
||||
session = make_session(user_id="test-session-owner")
|
||||
|
||||
# Execute without user_id
|
||||
# Execute without user_id to test unauthenticated behavior
|
||||
response = await tool.execute(
|
||||
user_id=None,
|
||||
session_id=str(uuid.uuid4()),
|
||||
@@ -318,9 +331,9 @@ async def test_run_agent_unauthenticated():
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Base tool returns need_login type for unauthenticated users
|
||||
assert result_data.get("type") == "need_login"
|
||||
@@ -350,9 +363,9 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error about missing cron
|
||||
assert result_data.get("type") == "error"
|
||||
@@ -382,9 +395,9 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return error about missing schedule_name
|
||||
assert result_data.get("type") == "error"
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
@@ -34,8 +35,10 @@ class RunBlockTool(BaseTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a specific block with the provided input data. "
|
||||
"Use find_block to discover available blocks and their input schemas. "
|
||||
"The block will run and return its outputs once complete."
|
||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||
"do NOT guess or make up block IDs. "
|
||||
"Use the 'id' from find_block results and provide input_data "
|
||||
"matching the block's required_inputs."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -45,13 +48,16 @@ class RunBlockTool(BaseTool):
|
||||
"properties": {
|
||||
"block_id": {
|
||||
"type": "string",
|
||||
"description": "The UUID of the block to execute",
|
||||
"description": (
|
||||
"The block's 'id' field from find_block results. "
|
||||
"NEVER guess this - always get it from find_block first."
|
||||
),
|
||||
},
|
||||
"input_data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Input values for the block. Must match the block's input schema. "
|
||||
"Check the block's input_schema from find_block for required fields."
|
||||
"Input values for the block. Use the 'required_inputs' field "
|
||||
"from find_block to see what fields are needed."
|
||||
),
|
||||
},
|
||||
},
|
||||
@@ -208,7 +214,11 @@ class RunBlockTool(BaseTool):
|
||||
|
||||
try:
|
||||
# Fetch actual credentials and prepare kwargs for block execution
|
||||
exec_kwargs: dict[str, Any] = {"user_id": user_id}
|
||||
# Create execution context with defaults (blocks may require it)
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": ExecutionContext(),
|
||||
}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
# Inject metadata into input_data (for validation)
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
"""
|
||||
Block Hybrid Search
|
||||
|
||||
Combines multiple ranking signals for block search:
|
||||
- Semantic search (OpenAI embeddings + cosine similarity)
|
||||
- Lexical search (BM25)
|
||||
- Name matching (boost for block name matches)
|
||||
- Category matching (boost for category matches)
|
||||
|
||||
Based on the docs search implementation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI embedding model
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
# Path to the JSON index file
|
||||
INDEX_PATH = Path(__file__).parent / "blocks_index.json"
|
||||
|
||||
# Stopwords for tokenization (same as index_blocks.py)
|
||||
STOPWORDS = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
"block",
|
||||
}
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for search."""
|
||||
text = text.lower()
|
||||
# Remove code blocks if any
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Split camelCase
|
||||
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
return [w for w in words if len(w) > 2 and w not in STOPWORDS]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchWeights:
|
||||
"""Configuration for hybrid search signal weights."""
|
||||
|
||||
semantic: float = 0.40 # Embedding similarity
|
||||
bm25: float = 0.25 # Lexical matching
|
||||
name_match: float = 0.25 # Block name matches
|
||||
category_match: float = 0.10 # Category matches
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockSearchResult:
|
||||
"""A single block search result."""
|
||||
|
||||
block_id: str
|
||||
name: str
|
||||
description: str
|
||||
categories: list[str]
|
||||
score: float
|
||||
|
||||
# Individual signal scores (for debugging)
|
||||
semantic_score: float = 0.0
|
||||
bm25_score: float = 0.0
|
||||
name_score: float = 0.0
|
||||
category_score: float = 0.0
|
||||
|
||||
|
||||
class BlockSearchIndex:
|
||||
"""Hybrid search index for blocks combining BM25 + embeddings."""
|
||||
|
||||
def __init__(self, index_path: Path = INDEX_PATH):
|
||||
self.blocks: list[dict[str, Any]] = []
|
||||
self.bm25_data: dict[str, Any] = {}
|
||||
self.name_index: dict[str, list[list[int | float]]] = {}
|
||||
self.embeddings: Optional[np.ndarray] = None
|
||||
self.normalized_embeddings: Optional[np.ndarray] = None
|
||||
self._loaded = False
|
||||
self._index_path = index_path
|
||||
self._embedding_model: Any = None
|
||||
|
||||
def load(self) -> bool:
|
||||
"""Load the index from JSON file."""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
if not self._index_path.exists():
|
||||
logger.warning(f"Block index not found at {self._index_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self._index_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.blocks = data.get("blocks", [])
|
||||
self.bm25_data = data.get("bm25", {})
|
||||
self.name_index = data.get("name_index", {})
|
||||
|
||||
# Decode embeddings from base64
|
||||
embeddings_list = []
|
||||
for block in self.blocks:
|
||||
if block.get("emb"):
|
||||
emb_bytes = base64.b64decode(block["emb"])
|
||||
emb = np.frombuffer(emb_bytes, dtype=np.float32)
|
||||
embeddings_list.append(emb)
|
||||
else:
|
||||
# No embedding, use zeros
|
||||
dim = data.get("embedding_dim", 384)
|
||||
embeddings_list.append(np.zeros(dim, dtype=np.float32))
|
||||
|
||||
if embeddings_list:
|
||||
self.embeddings = np.stack(embeddings_list)
|
||||
# Precompute normalized embeddings for cosine similarity
|
||||
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
|
||||
self.normalized_embeddings = self.embeddings / (norms + 1e-10)
|
||||
|
||||
self._loaded = True
|
||||
logger.info(f"Loaded block index with {len(self.blocks)} blocks")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load block index: {e}")
|
||||
return False
|
||||
|
||||
def _get_openai_client(self) -> Any:
|
||||
"""Get OpenAI client for query embedding."""
|
||||
if self._embedding_model is None:
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("OPENAI_API_KEY not set")
|
||||
return None
|
||||
self._embedding_model = OpenAI(api_key=api_key)
|
||||
except ImportError:
|
||||
logger.warning("openai not installed")
|
||||
return None
|
||||
return self._embedding_model
|
||||
|
||||
def _embed_query(self, query: str) -> Optional[np.ndarray]:
|
||||
"""Embed the search query using OpenAI."""
|
||||
client = self._get_openai_client()
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=query,
|
||||
)
|
||||
embedding = response.data[0].embedding
|
||||
return np.array(embedding, dtype=np.float32)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to embed query: {e}")
|
||||
return None
|
||||
|
||||
def _compute_semantic_scores(self, query_embedding: np.ndarray) -> np.ndarray:
|
||||
"""Compute cosine similarity between query and all blocks."""
|
||||
if self.normalized_embeddings is None:
|
||||
return np.zeros(len(self.blocks))
|
||||
|
||||
# Normalize query embedding
|
||||
query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)
|
||||
|
||||
# Cosine similarity via dot product
|
||||
similarities = self.normalized_embeddings @ query_norm
|
||||
|
||||
# Scale to [0, 1] (cosine ranges from -1 to 1)
|
||||
return (similarities + 1) / 2
|
||||
|
||||
def _compute_bm25_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute BM25 scores for all blocks."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not self.bm25_data or not query_tokens:
|
||||
return scores
|
||||
|
||||
# BM25 parameters
|
||||
k1 = 1.5
|
||||
b = 0.75
|
||||
n_docs = self.bm25_data.get("n_docs", len(self.blocks))
|
||||
avgdl = self.bm25_data.get("avgdl", 100)
|
||||
df = self.bm25_data.get("df", {})
|
||||
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.blocks))
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
# Tokenize block's searchable text
|
||||
block_tokens = tokenize(block.get("searchable_text", ""))
|
||||
doc_len = doc_lens[i] if i < len(doc_lens) else len(block_tokens)
|
||||
|
||||
# Calculate BM25 score
|
||||
score = 0.0
|
||||
for token in query_tokens:
|
||||
if token not in df:
|
||||
continue
|
||||
|
||||
# Term frequency in this document
|
||||
tf = block_tokens.count(token)
|
||||
if tf == 0:
|
||||
continue
|
||||
|
||||
# IDF
|
||||
doc_freq = df.get(token, 0)
|
||||
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
|
||||
|
||||
# BM25 score component
|
||||
numerator = tf * (k1 + 1)
|
||||
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
|
||||
score += idf * numerator / denominator
|
||||
|
||||
scores[i] = score
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def _compute_name_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute name match scores using the name index."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not self.name_index or not query_tokens:
|
||||
return scores
|
||||
|
||||
for token in query_tokens:
|
||||
if token in self.name_index:
|
||||
for block_idx, weight in self.name_index[token]:
|
||||
if block_idx < len(scores):
|
||||
scores[int(block_idx)] += weight
|
||||
|
||||
# Also check for partial matches in block names
|
||||
for i, block in enumerate(self.blocks):
|
||||
name_lower = block.get("name", "").lower()
|
||||
for token in query_tokens:
|
||||
if token in name_lower:
|
||||
scores[i] += 0.5
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def _compute_category_scores(self, query_tokens: list[str]) -> np.ndarray:
|
||||
"""Compute category match scores."""
|
||||
scores = np.zeros(len(self.blocks))
|
||||
|
||||
if not query_tokens:
|
||||
return scores
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
categories = block.get("categories", [])
|
||||
category_text = " ".join(categories).lower()
|
||||
|
||||
for token in query_tokens:
|
||||
if token in category_text:
|
||||
scores[i] += 1.0
|
||||
|
||||
# Normalize to [0, 1]
|
||||
max_score = scores.max()
|
||||
if max_score > 0:
|
||||
scores = scores / max_score
|
||||
|
||||
return scores
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
weights: Optional[SearchWeights] = None,
|
||||
) -> list[BlockSearchResult]:
|
||||
"""
|
||||
Perform hybrid search combining multiple signals.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
top_k: Number of results to return
|
||||
weights: Optional custom weights for signals
|
||||
|
||||
Returns:
|
||||
List of BlockSearchResult sorted by score
|
||||
"""
|
||||
if not self._loaded and not self.load():
|
||||
return []
|
||||
|
||||
if weights is None:
|
||||
weights = SearchWeights()
|
||||
|
||||
# Tokenize query
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
# Fallback: try raw query words
|
||||
query_tokens = query.lower().split()
|
||||
|
||||
# Compute semantic scores
|
||||
semantic_scores = np.zeros(len(self.blocks))
|
||||
if self.normalized_embeddings is not None:
|
||||
query_embedding = self._embed_query(query)
|
||||
if query_embedding is not None:
|
||||
semantic_scores = self._compute_semantic_scores(query_embedding)
|
||||
|
||||
# Compute other scores
|
||||
bm25_scores = self._compute_bm25_scores(query_tokens)
|
||||
name_scores = self._compute_name_scores(query_tokens)
|
||||
category_scores = self._compute_category_scores(query_tokens)
|
||||
|
||||
# Combine scores using weights
|
||||
combined_scores = (
|
||||
weights.semantic * semantic_scores
|
||||
+ weights.bm25 * bm25_scores
|
||||
+ weights.name_match * name_scores
|
||||
+ weights.category_match * category_scores
|
||||
)
|
||||
|
||||
# Get top-k indices
|
||||
top_indices = np.argsort(combined_scores)[::-1][:top_k]
|
||||
|
||||
# Build results
|
||||
results = []
|
||||
for idx in top_indices:
|
||||
if combined_scores[idx] <= 0:
|
||||
continue
|
||||
|
||||
block = self.blocks[idx]
|
||||
results.append(
|
||||
BlockSearchResult(
|
||||
block_id=block["id"],
|
||||
name=block["name"],
|
||||
description=block["description"],
|
||||
categories=block.get("categories", []),
|
||||
score=float(combined_scores[idx]),
|
||||
semantic_score=float(semantic_scores[idx]),
|
||||
bm25_score=float(bm25_scores[idx]),
|
||||
name_score=float(name_scores[idx]),
|
||||
category_score=float(category_scores[idx]),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Global index instance (lazy loaded)
|
||||
_block_search_index: Optional[BlockSearchIndex] = None
|
||||
|
||||
|
||||
def get_block_search_index() -> BlockSearchIndex:
|
||||
"""Get or create the block search index singleton."""
|
||||
global _block_search_index
|
||||
if _block_search_index is None:
|
||||
_block_search_index = BlockSearchIndex(INDEX_PATH)
|
||||
return _block_search_index
|
||||
@@ -1,269 +1,31 @@
|
||||
"""Tool for searching platform documentation."""
|
||||
"""SearchDocsTool - Search documentation using hybrid search."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
from backend.api.features.chat.tools.models import (
|
||||
DocSearchResult,
|
||||
DocSearchResultsResponse,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Documentation base URL
|
||||
DOCS_BASE_URL = "https://docs.agpt.co/platform"
|
||||
# Base URL for documentation (can be configured)
|
||||
DOCS_BASE_URL = "https://docs.agpt.co"
|
||||
|
||||
# Path to the JSON index file (relative to this file)
|
||||
INDEX_PATH = Path(__file__).parent / "docs_index.json"
|
||||
# Maximum number of results to return
|
||||
MAX_RESULTS = 5
|
||||
|
||||
|
||||
def tokenize(text: str) -> list[str]:
|
||||
"""Simple tokenizer for BM25."""
|
||||
text = text.lower()
|
||||
# Remove code blocks
|
||||
text = re.sub(r"```[\s\S]*?```", "", text)
|
||||
text = re.sub(r"`[^`]+`", "", text)
|
||||
# Extract words
|
||||
words = re.findall(r"\b[a-z][a-z0-9_-]*\b", text)
|
||||
# Remove very short words and stopwords
|
||||
stopwords = {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"was",
|
||||
"were",
|
||||
"be",
|
||||
"been",
|
||||
"being",
|
||||
"have",
|
||||
"has",
|
||||
"had",
|
||||
"do",
|
||||
"does",
|
||||
"did",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"may",
|
||||
"might",
|
||||
"must",
|
||||
"shall",
|
||||
"can",
|
||||
"need",
|
||||
"dare",
|
||||
"ought",
|
||||
"used",
|
||||
"to",
|
||||
"of",
|
||||
"in",
|
||||
"for",
|
||||
"on",
|
||||
"with",
|
||||
"at",
|
||||
"by",
|
||||
"from",
|
||||
"as",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"then",
|
||||
"once",
|
||||
"and",
|
||||
"but",
|
||||
"or",
|
||||
"nor",
|
||||
"so",
|
||||
"yet",
|
||||
"both",
|
||||
"either",
|
||||
"neither",
|
||||
"not",
|
||||
"only",
|
||||
"own",
|
||||
"same",
|
||||
"than",
|
||||
"too",
|
||||
"very",
|
||||
"just",
|
||||
"also",
|
||||
"now",
|
||||
"here",
|
||||
"there",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"how",
|
||||
"all",
|
||||
"each",
|
||||
"every",
|
||||
"both",
|
||||
"few",
|
||||
"more",
|
||||
"most",
|
||||
"other",
|
||||
"some",
|
||||
"such",
|
||||
"no",
|
||||
"any",
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"it",
|
||||
"its",
|
||||
}
|
||||
return [w for w in words if len(w) > 2 and w not in stopwords]
|
||||
|
||||
|
||||
class DocSearchIndex:
|
||||
"""Lightweight documentation search index using BM25."""
|
||||
|
||||
def __init__(self, index_path: Path):
|
||||
self.chunks: list[dict] = []
|
||||
self.bm25_data: dict = {}
|
||||
self._loaded = False
|
||||
self._index_path = index_path
|
||||
|
||||
def load(self) -> bool:
|
||||
"""Load the index from JSON file."""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
if not self._index_path.exists():
|
||||
logger.warning(f"Documentation index not found at {self._index_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self._index_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.chunks = data.get("chunks", [])
|
||||
self.bm25_data = data.get("bm25", {})
|
||||
self._loaded = True
|
||||
logger.info(f"Loaded documentation index with {len(self.chunks)} chunks")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load documentation index: {e}")
|
||||
return False
|
||||
|
||||
def search(self, query: str, top_k: int = 5) -> list[dict]:
|
||||
"""Search the index using BM25."""
|
||||
if not self._loaded and not self.load():
|
||||
return []
|
||||
|
||||
query_tokens = tokenize(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
|
||||
# BM25 parameters
|
||||
k1 = 1.5
|
||||
b = 0.75
|
||||
n_docs = self.bm25_data.get("n_docs", len(self.chunks))
|
||||
avgdl = self.bm25_data.get("avgdl", 100)
|
||||
df = self.bm25_data.get("df", {})
|
||||
doc_lens = self.bm25_data.get("doc_lens", [100] * len(self.chunks))
|
||||
|
||||
scores = []
|
||||
for i, chunk in enumerate(self.chunks):
|
||||
# Tokenize chunk text
|
||||
chunk_tokens = tokenize(chunk.get("text", ""))
|
||||
doc_len = doc_lens[i] if i < len(doc_lens) else len(chunk_tokens)
|
||||
|
||||
# Calculate BM25 score
|
||||
score = 0.0
|
||||
for token in query_tokens:
|
||||
if token not in df:
|
||||
continue
|
||||
|
||||
# Term frequency in this document
|
||||
tf = chunk_tokens.count(token)
|
||||
if tf == 0:
|
||||
continue
|
||||
|
||||
# IDF
|
||||
doc_freq = df.get(token, 0)
|
||||
idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
|
||||
|
||||
# BM25 score component
|
||||
numerator = tf * (k1 + 1)
|
||||
denominator = tf + k1 * (1 - b + b * doc_len / avgdl)
|
||||
score += idf * numerator / denominator
|
||||
|
||||
# Boost for title/heading matches
|
||||
title = chunk.get("title", "").lower()
|
||||
heading = chunk.get("heading", "").lower()
|
||||
for token in query_tokens:
|
||||
if token in title:
|
||||
score *= 1.5
|
||||
if token in heading:
|
||||
score *= 1.2
|
||||
|
||||
scores.append((i, score))
|
||||
|
||||
# Sort by score and return top_k
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
results = []
|
||||
seen_sections = set()
|
||||
for idx, score in scores:
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
chunk = self.chunks[idx]
|
||||
section_key = (chunk.get("doc", ""), chunk.get("heading", ""))
|
||||
|
||||
# Deduplicate by section
|
||||
if section_key in seen_sections:
|
||||
continue
|
||||
seen_sections.add(section_key)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"title": chunk.get("title", ""),
|
||||
"path": chunk.get("doc", ""),
|
||||
"heading": chunk.get("heading", ""),
|
||||
"text": chunk.get("text", ""), # Full text for LLM comprehension
|
||||
"score": score,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Global index instance (lazy loaded)
|
||||
_search_index: DocSearchIndex | None = None
|
||||
|
||||
|
||||
def get_search_index() -> DocSearchIndex:
|
||||
"""Get or create the search index singleton."""
|
||||
global _search_index
|
||||
if _search_index is None:
|
||||
_search_index = DocSearchIndex(INDEX_PATH)
|
||||
return _search_index
|
||||
# Snippet length for preview
|
||||
SNIPPET_LENGTH = 200
|
||||
|
||||
|
||||
class SearchDocsTool(BaseTool):
|
||||
@@ -271,15 +33,14 @@ class SearchDocsTool(BaseTool):
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "search_platform_docs"
|
||||
return "search_docs"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search the AutoGPT platform documentation and support Q&A for information about "
|
||||
"how to use the platform, create agents, configure blocks, "
|
||||
"set up integrations, troubleshoot issues, and more. Use this when users ask "
|
||||
"support questions or want to learn how to do something with AutoGPT."
|
||||
"Search the AutoGPT platform documentation for information about "
|
||||
"how to use the platform, build agents, configure blocks, and more. "
|
||||
"Returns relevant documentation sections. Use get_doc_page to read full content."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -290,24 +51,52 @@ class SearchDocsTool(BaseTool):
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query describing what the user wants to learn about. "
|
||||
"Use keywords like 'blocks', 'agents', 'credentials', 'API', etc."
|
||||
"Search query to find relevant documentation. "
|
||||
"Use natural language to describe what you're looking for."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return False # Documentation is public
|
||||
|
||||
def _create_snippet(self, content: str, max_length: int = SNIPPET_LENGTH) -> str:
|
||||
"""Create a short snippet from content for preview."""
|
||||
# Remove markdown formatting for cleaner snippet
|
||||
clean_content = content.replace("#", "").replace("*", "").replace("`", "")
|
||||
# Remove extra whitespace
|
||||
clean_content = " ".join(clean_content.split())
|
||||
|
||||
if len(clean_content) <= max_length:
|
||||
return clean_content
|
||||
|
||||
# Truncate at word boundary
|
||||
truncated = clean_content[:max_length]
|
||||
last_space = truncated.rfind(" ")
|
||||
if last_space > max_length // 2:
|
||||
truncated = truncated[:last_space]
|
||||
|
||||
return truncated + "..."
|
||||
|
||||
def _make_doc_url(self, path: str) -> str:
|
||||
"""Create a URL for a documentation page."""
|
||||
# Remove file extension for URL
|
||||
url_path = path.rsplit(".", 1)[0] if "." in path else path
|
||||
return f"{DOCS_BASE_URL}/{url_path}"
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search documentation for the query.
|
||||
"""Search documentation and return relevant sections.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous)
|
||||
user_id: User ID (not required for docs)
|
||||
session: Chat session
|
||||
query: Search query
|
||||
|
||||
@@ -317,60 +106,93 @@ class SearchDocsTool(BaseTool):
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
session_id = session.session_id
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
message="Please provide a search query.",
|
||||
error="Missing query parameter",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
index = get_search_index()
|
||||
results = index.search(query, top_k=5)
|
||||
# Search using hybrid search for DOCUMENTATION content type only
|
||||
results, total = await unified_hybrid_search(
|
||||
query=query,
|
||||
content_types=[ContentType.DOCUMENTATION],
|
||||
page=1,
|
||||
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
|
||||
min_score=0.1, # Lower threshold for docs
|
||||
)
|
||||
|
||||
if not results:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{query}'. Try different keywords.",
|
||||
session_id=session_id,
|
||||
message=f"No documentation found for '{query}'.",
|
||||
suggestions=[
|
||||
"Try more general terms like 'blocks', 'agents', 'setup'",
|
||||
"Check the documentation at docs.agpt.co",
|
||||
"Try different keywords",
|
||||
"Use more general terms",
|
||||
"Check for typos in your query",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
doc_results = []
|
||||
for r in results:
|
||||
# Build documentation URL
|
||||
path = r["path"]
|
||||
if path.endswith(".md"):
|
||||
path = path[:-3] # Remove .md extension
|
||||
doc_url = f"{DOCS_BASE_URL}/{path}"
|
||||
# Deduplicate by document path (keep highest scoring section per doc)
|
||||
seen_docs: dict[str, dict[str, Any]] = {}
|
||||
for result in results:
|
||||
metadata = result.get("metadata", {})
|
||||
doc_path = metadata.get("path", "")
|
||||
|
||||
if not doc_path:
|
||||
continue
|
||||
|
||||
# Keep the highest scoring result for each document
|
||||
if doc_path not in seen_docs:
|
||||
seen_docs[doc_path] = result
|
||||
elif result.get("combined_score", 0) > seen_docs[doc_path].get(
|
||||
"combined_score", 0
|
||||
):
|
||||
seen_docs[doc_path] = result
|
||||
|
||||
# Sort by score and take top MAX_RESULTS
|
||||
deduplicated = sorted(
|
||||
seen_docs.values(),
|
||||
key=lambda x: x.get("combined_score", 0),
|
||||
reverse=True,
|
||||
)[:MAX_RESULTS]
|
||||
|
||||
if not deduplicated:
|
||||
return NoResultsResponse(
|
||||
message=f"No documentation found for '{query}'.",
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use more general terms",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build response
|
||||
doc_results: list[DocSearchResult] = []
|
||||
for result in deduplicated:
|
||||
metadata = result.get("metadata", {})
|
||||
doc_path = metadata.get("path", "")
|
||||
doc_title = metadata.get("doc_title", "")
|
||||
section_title = metadata.get("section_title", "")
|
||||
searchable_text = result.get("searchable_text", "")
|
||||
score = result.get("combined_score", 0)
|
||||
|
||||
full_text = r["text"]
|
||||
doc_results.append(
|
||||
DocSearchResult(
|
||||
title=r["title"],
|
||||
path=r["path"],
|
||||
section=r["heading"],
|
||||
snippet=(
|
||||
full_text[:300] + "..."
|
||||
if len(full_text) > 300
|
||||
else full_text
|
||||
),
|
||||
content=full_text, # Full text for LLM to read and understand
|
||||
score=round(r["score"], 3),
|
||||
doc_url=doc_url,
|
||||
title=doc_title or section_title or doc_path,
|
||||
path=doc_path,
|
||||
section=section_title,
|
||||
snippet=self._create_snippet(searchable_text),
|
||||
score=round(score, 3),
|
||||
doc_url=self._make_doc_url(doc_path),
|
||||
)
|
||||
)
|
||||
|
||||
return DocSearchResultsResponse(
|
||||
message=(
|
||||
f"Found {len(doc_results)} relevant documentation sections. "
|
||||
"Use these to help answer the user's question. "
|
||||
"Include links to the documentation when helpful."
|
||||
),
|
||||
message=f"Found {len(doc_results)} relevant documentation sections.",
|
||||
results=doc_results,
|
||||
count=len(doc_results),
|
||||
query=query,
|
||||
@@ -378,9 +200,9 @@ class SearchDocsTool(BaseTool):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching documentation: {e}", exc_info=True)
|
||||
logger.error(f"Documentation search failed: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to search documentation. Please try again.",
|
||||
error=str(e),
|
||||
message=f"Failed to search documentation: {str(e)}",
|
||||
error="search_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -35,11 +35,7 @@ from backend.data.model import (
|
||||
OAuth2Credentials,
|
||||
UserIntegrations,
|
||||
)
|
||||
from backend.data.onboarding import (
|
||||
OnboardingStep,
|
||||
complete_onboarding_step,
|
||||
increment_runs,
|
||||
)
|
||||
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
||||
from backend.data.user import get_user_integrations
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||
@@ -175,6 +171,7 @@ async def callback(
|
||||
f"Successfully processed OAuth callback for user {user_id} "
|
||||
f"and provider {provider.value}"
|
||||
)
|
||||
|
||||
return CredentialsMetaResponse(
|
||||
id=credentials.id,
|
||||
provider=credentials.provider,
|
||||
@@ -193,6 +190,7 @@ async def list_credentials(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
@@ -215,6 +213,7 @@ async def list_credentials_by_provider(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[CredentialsMetaResponse]:
|
||||
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
|
||||
|
||||
return [
|
||||
CredentialsMetaResponse(
|
||||
id=cred.id,
|
||||
@@ -378,7 +377,6 @@ async def webhook_ingress_generic(
|
||||
return
|
||||
|
||||
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
|
||||
await increment_runs(user_id)
|
||||
|
||||
# Execute all triggers concurrently for better performance
|
||||
tasks = []
|
||||
@@ -831,6 +829,18 @@ async def list_providers() -> List[str]:
|
||||
return all_providers
|
||||
|
||||
|
||||
@router.get("/providers/system", response_model=List[str])
|
||||
async def list_system_providers() -> List[str]:
|
||||
"""
|
||||
Get a list of providers that have platform credits (system credentials) available.
|
||||
|
||||
These providers can be used without the user providing their own API keys.
|
||||
"""
|
||||
from backend.integrations.credentials_store import SYSTEM_PROVIDERS
|
||||
|
||||
return list(SYSTEM_PROVIDERS)
|
||||
|
||||
|
||||
@router.get("/providers/names", response_model=ProviderNamesResponse)
|
||||
async def get_provider_names() -> ProviderNamesResponse:
|
||||
"""
|
||||
|
||||
@@ -489,7 +489,7 @@ async def update_agent_version_in_library(
|
||||
agent_graph_version: int,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Updates the agent version in the library if useGraphIsActiveVersion is True.
|
||||
Updates the agent version in the library for any agent owned by the user.
|
||||
|
||||
Args:
|
||||
user_id: Owner of the LibraryAgent.
|
||||
@@ -498,20 +498,31 @@ async def update_agent_version_in_library(
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's an error with the update.
|
||||
NotFoundError: If no library agent is found for this user and agent.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating agent version in library for user #{user_id}, "
|
||||
f"agent #{agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
try:
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
|
||||
async with transaction() as tx:
|
||||
library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": agent_graph_id,
|
||||
"useGraphIsActiveVersion": True,
|
||||
},
|
||||
)
|
||||
lib = await prisma.models.LibraryAgent.prisma().update(
|
||||
|
||||
# Delete any conflicting LibraryAgent for the target version
|
||||
await prisma.models.LibraryAgent.prisma(tx).delete_many(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentGraphId": agent_graph_id,
|
||||
"agentGraphVersion": agent_graph_version,
|
||||
"id": {"not": library_agent.id},
|
||||
}
|
||||
)
|
||||
|
||||
lib = await prisma.models.LibraryAgent.prisma(tx).update(
|
||||
where={"id": library_agent.id},
|
||||
data={
|
||||
"AgentGraph": {
|
||||
@@ -525,13 +536,13 @@ async def update_agent_version_in_library(
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
)
|
||||
if lib is None:
|
||||
raise NotFoundError(f"Library agent {library_agent.id} not found")
|
||||
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating agent version in library: {e}")
|
||||
raise DatabaseError("Failed to update agent version in library") from e
|
||||
if lib is None:
|
||||
raise NotFoundError(
|
||||
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
|
||||
)
|
||||
|
||||
return library_model.LibraryAgent.from_db(lib)
|
||||
|
||||
|
||||
async def update_library_agent(
|
||||
@@ -825,6 +836,7 @@ async def add_store_agent_to_library(
|
||||
}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(
|
||||
_initialize_graph_settings(graph_model).model_dump()
|
||||
),
|
||||
|
||||
@@ -48,6 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str # ID of user who owns/created this agent graph
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -163,6 +164,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id=agent.id,
|
||||
graph_id=agent.agentGraphId,
|
||||
graph_version=agent.agentGraphVersion,
|
||||
owner_user_id=agent.userId,
|
||||
image_url=agent.imageUrl,
|
||||
creator_name=creator_name,
|
||||
creator_image_url=creator_image_url,
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.data.execution import GraphExecutionMeta
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.onboarding import increment_runs
|
||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
@@ -403,8 +402,6 @@ async def execute_preset(
|
||||
merged_node_input = preset.inputs | inputs
|
||||
merged_credential_inputs = preset.credentials | credential_inputs
|
||||
|
||||
await increment_runs(user_id)
|
||||
|
||||
return await add_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=preset.graph_id,
|
||||
|
||||
@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
@@ -64,6 +65,7 @@ async def test_get_library_agents_success(
|
||||
id="test-agent-2",
|
||||
graph_id="test-agent-2",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
image_url=None,
|
||||
@@ -138,6 +140,7 @@ async def test_get_favorite_library_agents_success(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
@@ -205,6 +208,7 @@ def test_add_agent_to_library_success(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
owner_user_id=test_user_id,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLI script to backfill embeddings for store agents.
|
||||
|
||||
Usage:
|
||||
poetry run python -m backend.server.v2.store.backfill_embeddings [--batch-size N]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import prisma
|
||||
|
||||
|
||||
async def main(batch_size: int = 100) -> int:
|
||||
"""Run the backfill process."""
|
||||
# Initialize Prisma client
|
||||
client = prisma.Prisma()
|
||||
await client.connect()
|
||||
prisma.register(client)
|
||||
|
||||
try:
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
# Get current stats
|
||||
print("Current embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
if stats["without_embeddings"] == 0:
|
||||
print("\nAll agents already have embeddings. Nothing to do.")
|
||||
return 0
|
||||
|
||||
# Run backfill
|
||||
print(f"\nBackfilling up to {batch_size} embeddings...")
|
||||
result = await backfill_missing_embeddings(batch_size=batch_size)
|
||||
print(f" Processed: {result['processed']}")
|
||||
print(f" Success: {result['success']}")
|
||||
print(f" Failed: {result['failed']}")
|
||||
|
||||
# Get final stats
|
||||
print("\nFinal embedding stats:")
|
||||
stats = await get_embedding_stats()
|
||||
print(f" Total approved: {stats['total_approved']}")
|
||||
print(f" With embeddings: {stats['with_embeddings']}")
|
||||
print(f" Without embeddings: {stats['without_embeddings']}")
|
||||
print(f" Coverage: {stats['coverage_percent']}%")
|
||||
|
||||
return 0 if result["failed"] == 0 else 1
|
||||
|
||||
finally:
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Backfill embeddings for store agents")
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of embeddings to generate (default: 100)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sys.exit(asyncio.run(main(batch_size=args.batch_size)))
|
||||
@@ -0,0 +1,610 @@
|
||||
"""
|
||||
Content Type Handlers for Unified Embeddings
|
||||
|
||||
Pluggable system for different content sources (store agents, blocks, docs).
|
||||
Each handler knows how to fetch and process its content type for embedding.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentItem:
|
||||
"""Represents a piece of content to be embedded."""
|
||||
|
||||
content_id: str # Unique identifier (DB ID or file path)
|
||||
content_type: ContentType
|
||||
searchable_text: str # Combined text for embedding
|
||||
metadata: dict[str, Any] # Content-specific metadata
|
||||
user_id: str | None = None # For user-scoped content
|
||||
|
||||
|
||||
class ContentHandler(ABC):
|
||||
"""Base handler for fetching and processing content for embeddings."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def content_type(self) -> ContentType:
|
||||
"""The ContentType this handler manages."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""
|
||||
Fetch items that don't have embeddings yet.
|
||||
|
||||
Args:
|
||||
batch_size: Maximum number of items to return
|
||||
|
||||
Returns:
|
||||
List of ContentItem objects ready for embedding
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get statistics about embedding coverage.
|
||||
|
||||
Returns:
|
||||
Dict with keys: total, with_embeddings, without_embeddings
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StoreAgentHandler(ContentHandler):
|
||||
"""Handler for marketplace store agent listings."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.STORE_AGENT
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch approved store listings without embeddings."""
|
||||
from backend.api.features.store.embeddings import build_searchable_text
|
||||
|
||||
missing = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT
|
||||
slv.id,
|
||||
slv.name,
|
||||
slv.description,
|
||||
slv."subHeading",
|
||||
slv.categories
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
|
||||
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
AND uce."contentId" IS NULL
|
||||
LIMIT $1
|
||||
""",
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return [
|
||||
ContentItem(
|
||||
content_id=row["id"],
|
||||
content_type=ContentType.STORE_AGENT,
|
||||
searchable_text=build_searchable_text(
|
||||
name=row["name"],
|
||||
description=row["description"],
|
||||
sub_heading=row["subHeading"],
|
||||
categories=row["categories"] or [],
|
||||
),
|
||||
metadata={
|
||||
"name": row["name"],
|
||||
"categories": row["categories"] or [],
|
||||
},
|
||||
user_id=None, # Store agents are public
|
||||
)
|
||||
for row in missing
|
||||
]
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about store agent embedding coverage."""
|
||||
# Count approved versions
|
||||
approved_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
AND "isDeleted" = false
|
||||
"""
|
||||
)
|
||||
total_approved = approved_result[0]["count"] if approved_result else 0
|
||||
|
||||
# Count versions with embeddings
|
||||
embedded_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"StoreListingVersion" slv
|
||||
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
|
||||
WHERE slv."submissionStatus" = 'APPROVED'
|
||||
AND slv."isDeleted" = false
|
||||
"""
|
||||
)
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_approved,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_approved - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
|
||||
"""Handler for block definitions (Python classes)."""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.BLOCK
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch blocks without embeddings."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
# Get all available blocks
|
||||
all_blocks = get_blocks()
|
||||
|
||||
# Check which ones have embeddings
|
||||
if not all_blocks:
|
||||
return []
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
|
||||
# Query for existing embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
missing_blocks = [
|
||||
(block_id, block_cls)
|
||||
for block_id, block_cls in all_blocks.items()
|
||||
if block_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem
|
||||
items = []
|
||||
for block_id, block_cls in missing_blocks[:batch_size]:
|
||||
try:
|
||||
block_instance = block_cls()
|
||||
|
||||
# Build searchable text from block metadata
|
||||
parts = []
|
||||
if hasattr(block_instance, "name") and block_instance.name:
|
||||
parts.append(block_instance.name)
|
||||
if (
|
||||
hasattr(block_instance, "description")
|
||||
and block_instance.description
|
||||
):
|
||||
parts.append(block_instance.description)
|
||||
if hasattr(block_instance, "categories") and block_instance.categories:
|
||||
# Convert BlockCategory enum to strings
|
||||
parts.append(
|
||||
" ".join(str(cat.value) for cat in block_instance.categories)
|
||||
)
|
||||
|
||||
# Add input/output schema info
|
||||
if hasattr(block_instance, "input_schema"):
|
||||
schema = block_instance.input_schema
|
||||
if hasattr(schema, "model_json_schema"):
|
||||
schema_dict = schema.model_json_schema()
|
||||
if "properties" in schema_dict:
|
||||
for prop_name, prop_info in schema_dict[
|
||||
"properties"
|
||||
].items():
|
||||
if "description" in prop_info:
|
||||
parts.append(
|
||||
f"{prop_name}: {prop_info['description']}"
|
||||
)
|
||||
|
||||
searchable_text = " ".join(parts)
|
||||
|
||||
# Convert categories set of enums to list of strings for JSON serialization
|
||||
categories = getattr(block_instance, "categories", set())
|
||||
categories_list = (
|
||||
[cat.value for cat in categories] if categories else []
|
||||
)
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=block_id,
|
||||
content_type=ContentType.BLOCK,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"name": getattr(block_instance, "name", ""),
|
||||
"categories": categories_list,
|
||||
},
|
||||
user_id=None, # Blocks are public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process block {block_id}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about block embedding coverage."""
|
||||
from backend.data.block import get_blocks
|
||||
|
||||
all_blocks = get_blocks()
|
||||
total_blocks = len(all_blocks)
|
||||
|
||||
if total_blocks == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
block_ids = list(all_blocks.keys())
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
|
||||
|
||||
embedded_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*block_ids,
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_blocks,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_blocks - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarkdownSection:
|
||||
"""Represents a section of a markdown document."""
|
||||
|
||||
title: str # Section heading text
|
||||
content: str # Section content (including the heading line)
|
||||
level: int # Heading level (1 for #, 2 for ##, etc.)
|
||||
index: int # Section index within the document
|
||||
|
||||
|
||||
class DocumentationHandler(ContentHandler):
|
||||
"""Handler for documentation files (.md/.mdx).
|
||||
|
||||
Chunks documents by markdown headings to create multiple embeddings per file.
|
||||
Each section (## heading) becomes a separate embedding for better retrieval.
|
||||
"""
|
||||
|
||||
@property
|
||||
def content_type(self) -> ContentType:
|
||||
return ContentType.DOCUMENTATION
|
||||
|
||||
def _get_docs_root(self) -> Path:
|
||||
"""Get the documentation root directory."""
|
||||
# content_handlers.py is at: backend/backend/api/features/store/content_handlers.py
|
||||
# Need to go up to project root then into docs/
|
||||
# In container: /app/autogpt_platform/backend/backend/api/features/store -> /app/docs
|
||||
# In development: /repo/autogpt_platform/backend/backend/api/features/store -> /repo/docs
|
||||
this_file = Path(
|
||||
__file__
|
||||
) # .../backend/backend/api/features/store/content_handlers.py
|
||||
project_root = (
|
||||
this_file.parent.parent.parent.parent.parent.parent.parent
|
||||
) # -> /app or /repo
|
||||
docs_root = project_root / "docs"
|
||||
return docs_root
|
||||
|
||||
def _extract_doc_title(self, file_path: Path) -> str:
|
||||
"""Extract the document title from a markdown file."""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
lines = content.split("\n")
|
||||
|
||||
# Try to extract title from first # heading
|
||||
for line in lines:
|
||||
if line.startswith("# "):
|
||||
return line[2:].strip()
|
||||
|
||||
# If no title found, use filename
|
||||
return file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read title from {file_path}: {e}")
|
||||
return file_path.stem.replace("-", " ").replace("_", " ").title()
|
||||
|
||||
def _chunk_markdown_by_headings(
|
||||
self, file_path: Path, min_heading_level: int = 2
|
||||
) -> list[MarkdownSection]:
|
||||
"""
|
||||
Split a markdown file into sections based on headings.
|
||||
|
||||
Args:
|
||||
file_path: Path to the markdown file
|
||||
min_heading_level: Minimum heading level to split on (default: 2 for ##)
|
||||
|
||||
Returns:
|
||||
List of MarkdownSection objects, one per section.
|
||||
If no headings found, returns a single section with all content.
|
||||
"""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read {file_path}: {e}")
|
||||
return []
|
||||
|
||||
lines = content.split("\n")
|
||||
sections: list[MarkdownSection] = []
|
||||
current_section_lines: list[str] = []
|
||||
current_title = ""
|
||||
current_level = 0
|
||||
section_index = 0
|
||||
doc_title = ""
|
||||
|
||||
for line in lines:
|
||||
# Check if line is a heading
|
||||
if line.startswith("#"):
|
||||
# Count heading level
|
||||
level = 0
|
||||
for char in line:
|
||||
if char == "#":
|
||||
level += 1
|
||||
else:
|
||||
break
|
||||
|
||||
heading_text = line[level:].strip()
|
||||
|
||||
# Track document title (level 1 heading)
|
||||
if level == 1 and not doc_title:
|
||||
doc_title = heading_text
|
||||
# Don't create a section for just the title - add it to first section
|
||||
current_section_lines.append(line)
|
||||
continue
|
||||
|
||||
# Check if this heading should start a new section
|
||||
if level >= min_heading_level:
|
||||
# Save previous section if it has content
|
||||
if current_section_lines:
|
||||
section_content = "\n".join(current_section_lines).strip()
|
||||
if section_content:
|
||||
# Use doc title for first section if no specific title
|
||||
title = current_title if current_title else doc_title
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace(
|
||||
"_", " "
|
||||
)
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=section_content,
|
||||
level=current_level if current_level else 1,
|
||||
index=section_index,
|
||||
)
|
||||
)
|
||||
section_index += 1
|
||||
|
||||
# Start new section
|
||||
current_section_lines = [line]
|
||||
current_title = heading_text
|
||||
current_level = level
|
||||
else:
|
||||
# Lower level heading (e.g., # when splitting on ##)
|
||||
current_section_lines.append(line)
|
||||
else:
|
||||
current_section_lines.append(line)
|
||||
|
||||
# Don't forget the last section
|
||||
if current_section_lines:
|
||||
section_content = "\n".join(current_section_lines).strip()
|
||||
if section_content:
|
||||
title = current_title if current_title else doc_title
|
||||
if not title:
|
||||
title = file_path.stem.replace("-", " ").replace("_", " ")
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=section_content,
|
||||
level=current_level if current_level else 1,
|
||||
index=section_index,
|
||||
)
|
||||
)
|
||||
|
||||
# If no sections were created (no headings found), create one section with all content
|
||||
if not sections and content.strip():
|
||||
title = (
|
||||
doc_title
|
||||
if doc_title
|
||||
else file_path.stem.replace("-", " ").replace("_", " ")
|
||||
)
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
title=title,
|
||||
content=content.strip(),
|
||||
level=1,
|
||||
index=0,
|
||||
)
|
||||
)
|
||||
|
||||
return sections
|
||||
|
||||
def _make_section_content_id(self, doc_path: str, section_index: int) -> str:
|
||||
"""Create a unique content ID for a document section.
|
||||
|
||||
Format: doc_path::section_index
|
||||
Example: 'platform/getting-started.md::0'
|
||||
"""
|
||||
return f"{doc_path}::{section_index}"
|
||||
|
||||
def _parse_section_content_id(self, content_id: str) -> tuple[str, int]:
|
||||
"""Parse a section content ID back into doc_path and section_index.
|
||||
|
||||
Returns: (doc_path, section_index)
|
||||
"""
|
||||
if "::" in content_id:
|
||||
parts = content_id.rsplit("::", 1)
|
||||
return parts[0], int(parts[1])
|
||||
# Legacy format (whole document)
|
||||
return content_id, 0
|
||||
|
||||
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
|
||||
"""Fetch documentation sections without embeddings.
|
||||
|
||||
Chunks each document by markdown headings and creates embeddings for each section.
|
||||
Content IDs use the format: 'path/to/doc.md::section_index'
|
||||
"""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
logger.warning(f"Documentation root not found: {docs_root}")
|
||||
return []
|
||||
|
||||
# Find all .md and .mdx files
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
|
||||
if not all_docs:
|
||||
return []
|
||||
|
||||
# Build list of all sections from all documents
|
||||
all_sections: list[tuple[str, Path, MarkdownSection]] = []
|
||||
for doc_file in all_docs:
|
||||
doc_path = str(doc_file.relative_to(docs_root))
|
||||
sections = self._chunk_markdown_by_headings(doc_file)
|
||||
for section in sections:
|
||||
all_sections.append((doc_path, doc_file, section))
|
||||
|
||||
if not all_sections:
|
||||
return []
|
||||
|
||||
# Generate content IDs for all sections
|
||||
section_content_ids = [
|
||||
self._make_section_content_id(doc_path, section.index)
|
||||
for doc_path, _, section in all_sections
|
||||
]
|
||||
|
||||
# Check which ones have embeddings
|
||||
placeholders = ",".join([f"${i+1}" for i in range(len(section_content_ids))])
|
||||
existing_result = await query_raw_with_schema(
|
||||
f"""
|
||||
SELECT "contentId"
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
|
||||
AND "contentId" = ANY(ARRAY[{placeholders}])
|
||||
""",
|
||||
*section_content_ids,
|
||||
)
|
||||
|
||||
existing_ids = {row["contentId"] for row in existing_result}
|
||||
|
||||
# Filter to missing sections
|
||||
missing_sections = [
|
||||
(doc_path, doc_file, section, content_id)
|
||||
for (doc_path, doc_file, section), content_id in zip(
|
||||
all_sections, section_content_ids
|
||||
)
|
||||
if content_id not in existing_ids
|
||||
]
|
||||
|
||||
# Convert to ContentItem (up to batch_size)
|
||||
items = []
|
||||
for doc_path, doc_file, section, content_id in missing_sections[:batch_size]:
|
||||
try:
|
||||
# Get document title for context
|
||||
doc_title = self._extract_doc_title(doc_file)
|
||||
|
||||
# Build searchable text with context
|
||||
# Include doc title and section title for better search relevance
|
||||
searchable_text = f"{doc_title} - {section.title}\n\n{section.content}"
|
||||
|
||||
items.append(
|
||||
ContentItem(
|
||||
content_id=content_id,
|
||||
content_type=ContentType.DOCUMENTATION,
|
||||
searchable_text=searchable_text,
|
||||
metadata={
|
||||
"doc_title": doc_title,
|
||||
"section_title": section.title,
|
||||
"section_index": section.index,
|
||||
"heading_level": section.level,
|
||||
"path": doc_path,
|
||||
},
|
||||
user_id=None, # Documentation is public
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process section {content_id}: {e}")
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
def _get_all_section_content_ids(self, docs_root: Path) -> set[str]:
|
||||
"""Get all current section content IDs from the docs directory.
|
||||
|
||||
Used for stats and cleanup to know what sections should exist.
|
||||
"""
|
||||
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
|
||||
content_ids = set()
|
||||
|
||||
for doc_file in all_docs:
|
||||
doc_path = str(doc_file.relative_to(docs_root))
|
||||
sections = self._chunk_markdown_by_headings(doc_file)
|
||||
for section in sections:
|
||||
content_ids.add(self._make_section_content_id(doc_path, section.index))
|
||||
|
||||
return content_ids
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""Get statistics about documentation embedding coverage.
|
||||
|
||||
Counts sections (not documents) since each section gets its own embedding.
|
||||
"""
|
||||
docs_root = self._get_docs_root()
|
||||
|
||||
if not docs_root.exists():
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
# Get all section content IDs
|
||||
all_section_ids = self._get_all_section_content_ids(docs_root)
|
||||
total_sections = len(all_section_ids)
|
||||
|
||||
if total_sections == 0:
|
||||
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
|
||||
|
||||
# Count embeddings in database for DOCUMENTATION type
|
||||
embedded_result = await query_raw_with_schema(
|
||||
"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" = 'DOCUMENTATION'::{schema_prefix}"ContentType"
|
||||
"""
|
||||
)
|
||||
|
||||
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
|
||||
|
||||
return {
|
||||
"total": total_sections,
|
||||
"with_embeddings": with_embeddings,
|
||||
"without_embeddings": total_sections - with_embeddings,
|
||||
}
|
||||
|
||||
|
||||
# Content handler registry
|
||||
CONTENT_HANDLERS: dict[ContentType, ContentHandler] = {
|
||||
ContentType.STORE_AGENT: StoreAgentHandler(),
|
||||
ContentType.BLOCK: BlockHandler(),
|
||||
ContentType.DOCUMENTATION: DocumentationHandler(),
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Integration tests for content handlers using real DB.
|
||||
|
||||
Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs
|
||||
|
||||
These tests use the real database but mock OpenAI calls.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
from backend.api.features.store.embeddings import (
|
||||
EMBEDDING_DIM,
|
||||
backfill_all_content_types,
|
||||
ensure_content_embedding,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_real_db():
|
||||
"""Test StoreAgentHandler with real database queries."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list (may be empty if all have embeddings)
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None
|
||||
assert item.content_type.value == "STORE_AGENT"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_real_db():
|
||||
"""Test BlockHandler with real database queries."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Get stats from real DB
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0 # Should have at least some blocks
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be block UUID
|
||||
assert item.content_type.value == "BLOCK"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_real_fs():
|
||||
"""Test DocumentationHandler with real filesystem."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Get stats from real filesystem
|
||||
stats = await handler.get_stats()
|
||||
|
||||
# Stats should have correct structure
|
||||
assert "total" in stats
|
||||
assert "with_embeddings" in stats
|
||||
assert "without_embeddings" in stats
|
||||
assert stats["total"] >= 0
|
||||
assert stats["with_embeddings"] >= 0
|
||||
assert stats["without_embeddings"] >= 0
|
||||
|
||||
# Get missing items (max 1 to keep test fast)
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
# Items should be list
|
||||
assert isinstance(items, list)
|
||||
|
||||
if items:
|
||||
item = items[0]
|
||||
assert item.content_id is not None # Should be relative path
|
||||
assert item.content_type.value == "DOCUMENTATION"
|
||||
assert item.searchable_text != ""
|
||||
assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_embedding_stats_all_types():
|
||||
"""Test get_embedding_stats aggregates all content types."""
|
||||
stats = await get_embedding_stats()
|
||||
|
||||
# Should have structure with by_type and totals
|
||||
assert "by_type" in stats
|
||||
assert "totals" in stats
|
||||
|
||||
# Check each content type is present
|
||||
by_type = stats["by_type"]
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "BLOCK" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Check totals are aggregated
|
||||
totals = stats["totals"]
|
||||
assert totals["total"] >= 0
|
||||
assert totals["with_embeddings"] >= 0
|
||||
assert totals["without_embeddings"] >= 0
|
||||
assert "coverage_percent" in totals
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_ensure_content_embedding_blocks(mock_generate):
|
||||
"""Test creating embeddings for blocks (mocked OpenAI)."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * EMBEDDING_DIM
|
||||
|
||||
# Get one block without embedding
|
||||
handler = BlockHandler()
|
||||
items = await handler.get_missing_items(batch_size=1)
|
||||
|
||||
if not items:
|
||||
pytest.skip("No blocks without embeddings")
|
||||
|
||||
item = items[0]
|
||||
|
||||
# Try to create embedding (OpenAI mocked)
|
||||
result = await ensure_content_embedding(
|
||||
content_type=item.content_type,
|
||||
content_id=item.content_id,
|
||||
searchable_text=item.searchable_text,
|
||||
metadata=item.metadata,
|
||||
user_id=item.user_id,
|
||||
)
|
||||
|
||||
# Should succeed with mocked OpenAI
|
||||
assert result is True
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@patch("backend.api.features.store.embeddings.generate_embedding")
|
||||
async def test_backfill_all_content_types_dry_run(mock_generate):
|
||||
"""Test backfill_all_content_types processes all handlers in order."""
|
||||
# Mock OpenAI to return fake embedding
|
||||
mock_generate.return_value = [0.1] * EMBEDDING_DIM
|
||||
|
||||
# Run backfill with batch_size=1 to process max 1 per type
|
||||
result = await backfill_all_content_types(batch_size=1)
|
||||
|
||||
# Should have results for all content types
|
||||
assert "by_type" in result
|
||||
assert "totals" in result
|
||||
|
||||
by_type = result["by_type"]
|
||||
assert "BLOCK" in by_type
|
||||
assert "STORE_AGENT" in by_type
|
||||
assert "DOCUMENTATION" in by_type
|
||||
|
||||
# Each type should have correct structure
|
||||
for content_type, type_result in by_type.items():
|
||||
assert "processed" in type_result
|
||||
assert "success" in type_result
|
||||
assert "failed" in type_result
|
||||
|
||||
# Totals should aggregate
|
||||
totals = result["totals"]
|
||||
assert totals["processed"] >= 0
|
||||
assert totals["success"] >= 0
|
||||
assert totals["failed"] >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handler_registry():
|
||||
"""Test all handlers are registered in correct order."""
|
||||
from prisma.enums import ContentType
|
||||
|
||||
# All three types should be registered
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
# Check handler types
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_missing_items(mocker):
|
||||
"""Test StoreAgentHandler fetches approved agents without embeddings."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock database query
|
||||
mock_missing = [
|
||||
{
|
||||
"id": "agent-1",
|
||||
"name": "Test Agent",
|
||||
"description": "A test agent",
|
||||
"subHeading": "Test heading",
|
||||
"categories": ["AI", "Testing"],
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_missing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "agent-1"
|
||||
assert items[0].content_type == ContentType.STORE_AGENT
|
||||
assert "Test Agent" in items[0].searchable_text
|
||||
assert "A test agent" in items[0].searchable_text
|
||||
assert items[0].metadata["name"] == "Test Agent"
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_store_agent_handler_get_stats(mocker):
|
||||
"""Test StoreAgentHandler returns correct stats."""
|
||||
handler = StoreAgentHandler()
|
||||
|
||||
# Mock approved count query
|
||||
mock_approved = [{"count": 50}]
|
||||
# Mock embedded count query
|
||||
mock_embedded = [{"count": 30}]
|
||||
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
side_effect=[mock_approved, mock_embedded],
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 50
|
||||
assert stats["with_embeddings"] == 30
|
||||
assert stats["without_embeddings"] == 20
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_missing_items(mocker):
|
||||
"""Test BlockHandler discovers blocks without embeddings."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks to return test blocks
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Calculator Block"
|
||||
mock_block_instance.description = "Performs calculations"
|
||||
mock_block_instance.categories = [MagicMock(value="MATH")]
|
||||
mock_block_instance.input_schema.model_json_schema.return_value = {
|
||||
"properties": {"expression": {"description": "Math expression to evaluate"}}
|
||||
}
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-uuid-1": mock_block_class}
|
||||
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
mock_existing = []
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_existing,
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "block-uuid-1"
|
||||
assert items[0].content_type == ContentType.BLOCK
|
||||
assert "Calculator Block" in items[0].searchable_text
|
||||
assert "Performs calculations" in items[0].searchable_text
|
||||
assert "MATH" in items[0].searchable_text
|
||||
assert "expression: Math expression" in items[0].searchable_text
|
||||
assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_get_stats(mocker):
|
||||
"""Test BlockHandler returns correct stats."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock get_blocks
|
||||
mock_blocks = {
|
||||
"block-1": MagicMock(),
|
||||
"block-2": MagicMock(),
|
||||
"block-3": MagicMock(),
|
||||
}
|
||||
|
||||
# Mock embedded count query (2 blocks have embeddings)
|
||||
mock_embedded = [{"count": 2}]
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 2
|
||||
assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
|
||||
"""Test DocumentationHandler discovers docs without embeddings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory with test files
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
|
||||
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
|
||||
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
|
||||
|
||||
# Mock _get_docs_root to return temp dir
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
# Mock existing embeddings query (no embeddings exist)
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
# Check guide.md (content_id format: doc_path::section_index)
|
||||
guide_item = next(
|
||||
(item for item in items if item.content_id == "guide.md::0"), None
|
||||
)
|
||||
assert guide_item is not None
|
||||
assert guide_item.content_type == ContentType.DOCUMENTATION
|
||||
assert "Getting Started" in guide_item.searchable_text
|
||||
assert "This is a guide" in guide_item.searchable_text
|
||||
assert guide_item.metadata["doc_title"] == "Getting Started"
|
||||
assert guide_item.user_id is None
|
||||
|
||||
# Check api.mdx (content_id format: doc_path::section_index)
|
||||
api_item = next(
|
||||
(item for item in items if item.content_id == "api.mdx::0"), None
|
||||
)
|
||||
assert api_item is not None
|
||||
assert "API Reference" in api_item.searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_get_stats(tmp_path, mocker):
|
||||
"""Test DocumentationHandler returns correct stats."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Create temporary docs directory
|
||||
docs_root = tmp_path / "docs"
|
||||
docs_root.mkdir()
|
||||
(docs_root / "doc1.md").write_text("# Doc 1")
|
||||
(docs_root / "doc2.md").write_text("# Doc 2")
|
||||
(docs_root / "doc3.mdx").write_text("# Doc 3")
|
||||
|
||||
# Mock embedded count query (1 doc has embedding)
|
||||
mock_embedded = [{"count": 1}]
|
||||
|
||||
with patch.object(handler, "_get_docs_root", return_value=docs_root):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=mock_embedded,
|
||||
):
|
||||
stats = await handler.get_stats()
|
||||
|
||||
assert stats["total"] == 3
|
||||
assert stats["with_embeddings"] == 1
|
||||
assert stats["without_embeddings"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_title_extraction(tmp_path):
|
||||
"""Test DocumentationHandler extracts title from markdown heading."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test with heading
|
||||
doc_with_heading = tmp_path / "with_heading.md"
|
||||
doc_with_heading.write_text("# My Title\n\nContent here")
|
||||
title = handler._extract_doc_title(doc_with_heading)
|
||||
assert title == "My Title"
|
||||
|
||||
# Test without heading
|
||||
doc_without_heading = tmp_path / "no-heading.md"
|
||||
doc_without_heading.write_text("Just content, no heading")
|
||||
title = handler._extract_doc_title(doc_without_heading)
|
||||
assert title == "No Heading" # Uses filename
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_markdown_chunking(tmp_path):
|
||||
"""Test DocumentationHandler chunks markdown by headings."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test document with multiple sections
|
||||
doc_with_sections = tmp_path / "sections.md"
|
||||
doc_with_sections.write_text(
|
||||
"# Document Title\n\n"
|
||||
"Intro paragraph.\n\n"
|
||||
"## Section One\n\n"
|
||||
"Content for section one.\n\n"
|
||||
"## Section Two\n\n"
|
||||
"Content for section two.\n"
|
||||
)
|
||||
sections = handler._chunk_markdown_by_headings(doc_with_sections)
|
||||
|
||||
# Should have 3 sections: intro (with doc title), section one, section two
|
||||
assert len(sections) == 3
|
||||
assert sections[0].title == "Document Title"
|
||||
assert sections[0].index == 0
|
||||
assert "Intro paragraph" in sections[0].content
|
||||
|
||||
assert sections[1].title == "Section One"
|
||||
assert sections[1].index == 1
|
||||
assert "Content for section one" in sections[1].content
|
||||
|
||||
assert sections[2].title == "Section Two"
|
||||
assert sections[2].index == 2
|
||||
assert "Content for section two" in sections[2].content
|
||||
|
||||
# Test document without headings
|
||||
doc_no_sections = tmp_path / "no-sections.md"
|
||||
doc_no_sections.write_text("Just plain content without any headings.")
|
||||
sections = handler._chunk_markdown_by_headings(doc_no_sections)
|
||||
assert len(sections) == 1
|
||||
assert sections[0].index == 0
|
||||
assert "Just plain content" in sections[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_section_content_ids():
|
||||
"""Test DocumentationHandler creates and parses section content IDs."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Test making content ID
|
||||
content_id = handler._make_section_content_id("docs/guide.md", 2)
|
||||
assert content_id == "docs/guide.md::2"
|
||||
|
||||
# Test parsing content ID
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
|
||||
assert doc_path == "docs/guide.md"
|
||||
assert section_index == 2
|
||||
|
||||
# Test parsing legacy format (no section index)
|
||||
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
|
||||
assert doc_path == "docs/old-format.md"
|
||||
assert section_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_content_handlers_registry():
|
||||
"""Test all content types are registered."""
|
||||
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
|
||||
assert ContentType.BLOCK in CONTENT_HANDLERS
|
||||
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
|
||||
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
|
||||
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_handles_missing_attributes():
|
||||
"""Test BlockHandler gracefully handles blocks with missing attributes."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock block with minimal attributes
|
||||
mock_block_class = MagicMock()
|
||||
mock_block_instance = MagicMock()
|
||||
mock_block_instance.name = "Minimal Block"
|
||||
# No description, categories, or schema
|
||||
del mock_block_instance.description
|
||||
del mock_block_instance.categories
|
||||
del mock_block_instance.input_schema
|
||||
mock_block_class.return_value = mock_block_instance
|
||||
|
||||
mock_blocks = {"block-minimal": mock_block_class}
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_block_handler_skips_failed_blocks():
|
||||
"""Test BlockHandler skips blocks that fail to instantiate."""
|
||||
handler = BlockHandler()
|
||||
|
||||
# Mock one good block and one bad block
|
||||
good_block = MagicMock()
|
||||
good_instance = MagicMock()
|
||||
good_instance.name = "Good Block"
|
||||
good_instance.description = "Works fine"
|
||||
good_instance.categories = []
|
||||
good_block.return_value = good_instance
|
||||
|
||||
bad_block = MagicMock()
|
||||
bad_block.side_effect = Exception("Instantiation failed")
|
||||
|
||||
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
|
||||
|
||||
with patch(
|
||||
"backend.data.block.get_blocks",
|
||||
return_value=mock_blocks,
|
||||
):
|
||||
with patch(
|
||||
"backend.api.features.store.content_handlers.query_raw_with_schema",
|
||||
return_value=[],
|
||||
):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
|
||||
# Should only get the good block
|
||||
assert len(items) == 1
|
||||
assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_documentation_handler_missing_docs_directory():
|
||||
"""Test DocumentationHandler handles missing docs directory gracefully."""
|
||||
handler = DocumentationHandler()
|
||||
|
||||
# Mock _get_docs_root to return non-existent path
|
||||
fake_path = Path("/nonexistent/docs")
|
||||
with patch.object(handler, "_get_docs_root", return_value=fake_path):
|
||||
items = await handler.get_missing_items(batch_size=10)
|
||||
assert items == []
|
||||
|
||||
stats = await handler.get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
assert stats["without_embeddings"] == 0
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import fastapi
|
||||
import prisma.enums
|
||||
@@ -29,6 +29,8 @@ from backend.util.settings import Settings
|
||||
|
||||
from . import exceptions as store_exceptions
|
||||
from . import model as store_model
|
||||
from .embeddings import ensure_embedding
|
||||
from .hybrid_search import hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -49,54 +51,77 @@ async def get_store_agents(
|
||||
page_size: int = 20,
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get PUBLIC store agents from the StoreAgent view
|
||||
Get PUBLIC store agents from the StoreAgent view.
|
||||
|
||||
Search behavior:
|
||||
- With search_query: Uses hybrid search (semantic + lexical)
|
||||
- Fallback: If embeddings unavailable, gracefully degrades to lexical-only
|
||||
- Rationale: User-facing endpoint prioritizes availability over accuracy
|
||||
|
||||
Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
|
||||
search_used_hybrid = False
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
agents: list[dict[str, Any]] = []
|
||||
total = 0
|
||||
total_pages = 0
|
||||
|
||||
try:
|
||||
# If search_query is provided, use hybrid search (embeddings + tsvector)
|
||||
if search_query:
|
||||
from backend.api.features.store.hybrid_search import hybrid_search
|
||||
# Try hybrid search combining semantic and lexical signals
|
||||
# Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
|
||||
try:
|
||||
agents, total = await hybrid_search(
|
||||
query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by="relevance", # Use hybrid scoring for relevance
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
search_used_hybrid = True
|
||||
except Exception as e:
|
||||
# Log error but fall back to lexical search for better UX
|
||||
logger.error(
|
||||
f"Hybrid search failed (likely OpenAI unavailable), "
|
||||
f"falling back to lexical search: {e}"
|
||||
)
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Use hybrid search combining semantic and lexical signals
|
||||
agents, total = await hybrid_search(
|
||||
query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by="relevance", # Use hybrid scoring for relevance
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
if search_used_hybrid:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing Store agent from hybrid search results: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert raw results to StoreAgent models
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Store agent from search results: {e}")
|
||||
continue
|
||||
|
||||
else:
|
||||
# Non-search query path (original logic)
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
@@ -105,6 +130,14 @@ async def get_store_agents(
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == "rating":
|
||||
order_by.append({"rating": "desc"})
|
||||
@@ -113,7 +146,7 @@ async def get_store_agents(
|
||||
elif sorted_by == "name":
|
||||
order_by.append({"agent_name": "asc"})
|
||||
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
@@ -124,7 +157,7 @@ async def get_store_agents(
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
for agent in db_agents:
|
||||
try:
|
||||
# Create the StoreAgent object safely
|
||||
store_agent = store_model.StoreAgent(
|
||||
@@ -539,6 +572,7 @@ async def get_store_submissions(
|
||||
submission_models = []
|
||||
for sub in submissions:
|
||||
submission_model = store_model.StoreSubmission(
|
||||
listing_id=sub.listing_id,
|
||||
agent_id=sub.agent_id,
|
||||
agent_version=sub.agent_version,
|
||||
name=sub.name,
|
||||
@@ -592,35 +626,48 @@ async def delete_store_submission(
|
||||
submission_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a store listing submission as the submitting user.
|
||||
Delete a store submission version as the submitting user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user
|
||||
submission_id: ID of the submission to be deleted
|
||||
submission_id: StoreListingVersion ID to delete
|
||||
|
||||
Returns:
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
bool: True if successfully deleted
|
||||
"""
|
||||
logger.debug(f"Deleting store submission {submission_id} for user {user_id}")
|
||||
|
||||
try:
|
||||
# Verify the submission belongs to this user
|
||||
submission = await prisma.models.StoreListing.prisma().find_first(
|
||||
where={"agentGraphId": submission_id, "owningUserId": user_id}
|
||||
# Find the submission version with ownership check
|
||||
version = await prisma.models.StoreListingVersion.prisma().find_first(
|
||||
where={"id": submission_id}, include={"StoreListing": True}
|
||||
)
|
||||
|
||||
if not submission:
|
||||
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
|
||||
raise store_exceptions.SubmissionNotFoundError(
|
||||
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
|
||||
if (
|
||||
not version
|
||||
or not version.StoreListing
|
||||
or version.StoreListing.owningUserId != user_id
|
||||
):
|
||||
raise store_exceptions.SubmissionNotFoundError("Submission not found")
|
||||
|
||||
# Prevent deletion of approved submissions
|
||||
if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
"Cannot delete approved submissions"
|
||||
)
|
||||
|
||||
# Delete the submission
|
||||
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
|
||||
|
||||
logger.debug(
|
||||
f"Successfully deleted submission {submission_id} for user {user_id}"
|
||||
# Delete the version
|
||||
await prisma.models.StoreListingVersion.prisma().delete(
|
||||
where={"id": version.id}
|
||||
)
|
||||
|
||||
# Clean up empty listing if this was the last version
|
||||
remaining = await prisma.models.StoreListingVersion.prisma().count(
|
||||
where={"storeListingId": version.storeListingId}
|
||||
)
|
||||
if remaining == 0:
|
||||
await prisma.models.StoreListing.prisma().delete(
|
||||
where={"id": version.storeListingId}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -684,9 +731,15 @@ async def create_store_submission(
|
||||
logger.warning(
|
||||
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
|
||||
)
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
# Provide more user-friendly error message when agent_id is empty
|
||||
if not agent_id or agent_id.strip() == "":
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
"No agent selected. Please select an agent before submitting to the store."
|
||||
)
|
||||
else:
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
# Check if listing already exists for this agent
|
||||
existing_listing = await prisma.models.StoreListing.prisma().find_first(
|
||||
@@ -758,6 +811,7 @@ async def create_store_submission(
|
||||
logger.debug(f"Created store listing for agent {agent_id}")
|
||||
# Return submission details
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=listing.id,
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
name=name,
|
||||
@@ -869,81 +923,56 @@ async def edit_store_submission(
|
||||
# Currently we are not allowing user to update the agent associated with a submission
|
||||
# If we allow it in future, then we need a check here to verify the agent belongs to this user.
|
||||
|
||||
# Check if we can edit this submission
|
||||
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
|
||||
# Only allow editing of PENDING submissions
|
||||
if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING:
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
"Cannot edit a rejected submission"
|
||||
)
|
||||
|
||||
# For APPROVED submissions, we need to create a new version
|
||||
if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
|
||||
# Create a new version for the existing listing
|
||||
return await create_store_version(
|
||||
user_id=user_id,
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
store_listing_id=current_version.storeListingId,
|
||||
name=name,
|
||||
video_url=video_url,
|
||||
agent_output_demo_url=agent_output_demo_url,
|
||||
image_urls=image_urls,
|
||||
description=description,
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited."
|
||||
)
|
||||
|
||||
# For PENDING submissions, we can update the existing version
|
||||
elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING:
|
||||
# Update the existing version
|
||||
updated_version = await prisma.models.StoreListingVersion.prisma().update(
|
||||
where={"id": store_listing_version_id},
|
||||
data=prisma.types.StoreListingVersionUpdateInput(
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
|
||||
)
|
||||
|
||||
if not updated_version:
|
||||
raise DatabaseError("Failed to update store listing version")
|
||||
return store_model.StoreSubmission(
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
# Update the existing version
|
||||
updated_version = await prisma.models.StoreListingVersion.prisma().update(
|
||||
where={"id": store_listing_version_id},
|
||||
data=prisma.types.StoreListingVersionUpdateInput(
|
||||
name=name,
|
||||
sub_heading=sub_heading,
|
||||
slug=current_version.StoreListing.slug,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||
status=updated_version.submissionStatus,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
store_listing_version_id=updated_version.id,
|
||||
changes_summary=changes_summary,
|
||||
video_url=video_url,
|
||||
categories=categories,
|
||||
version=updated_version.version,
|
||||
)
|
||||
subHeading=sub_heading,
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
raise store_exceptions.InvalidOperationError(
|
||||
f"Cannot edit submission with status: {current_version.submissionStatus}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
|
||||
)
|
||||
|
||||
if not updated_version:
|
||||
raise DatabaseError("Failed to update store listing version")
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=current_version.StoreListing.id,
|
||||
agent_id=current_version.agentGraphId,
|
||||
agent_version=current_version.agentGraphVersion,
|
||||
name=name,
|
||||
sub_heading=sub_heading,
|
||||
slug=current_version.StoreListing.slug,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||
status=updated_version.submissionStatus,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
store_listing_version_id=updated_version.id,
|
||||
changes_summary=changes_summary,
|
||||
video_url=video_url,
|
||||
categories=categories,
|
||||
version=updated_version.version,
|
||||
)
|
||||
|
||||
except (
|
||||
store_exceptions.SubmissionNotFoundError,
|
||||
@@ -1022,38 +1051,78 @@ async def create_store_version(
|
||||
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
|
||||
)
|
||||
|
||||
# Get the latest version number
|
||||
latest_version = listing.Versions[0] if listing.Versions else None
|
||||
|
||||
next_version = (latest_version.version + 1) if latest_version else 1
|
||||
|
||||
# Create a new version for the existing listing
|
||||
new_version = await prisma.models.StoreListingVersion.prisma().create(
|
||||
data=prisma.types.StoreListingVersionCreateInput(
|
||||
version=next_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
storeListingId=store_listing_id,
|
||||
# Check if there's already a PENDING submission for this agent (any version)
|
||||
existing_pending_submission = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_first(
|
||||
where=prisma.types.StoreListingVersionWhereInput(
|
||||
storeListingId=store_listing_id,
|
||||
agentGraphId=agent_id,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
isDeleted=False,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Handle existing pending submission and create new one atomically
|
||||
async with transaction() as tx:
|
||||
# Get the latest version number first
|
||||
latest_listing = await prisma.models.StoreListing.prisma(tx).find_first(
|
||||
where=prisma.types.StoreListingWhereInput(
|
||||
id=store_listing_id, owningUserId=user_id
|
||||
),
|
||||
include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
|
||||
)
|
||||
|
||||
if not latest_listing:
|
||||
raise store_exceptions.ListingNotFoundError(
|
||||
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
|
||||
)
|
||||
|
||||
latest_version = (
|
||||
latest_listing.Versions[0] if latest_listing.Versions else None
|
||||
)
|
||||
next_version = (latest_version.version + 1) if latest_version else 1
|
||||
|
||||
# If there's an existing pending submission, delete it atomically before creating new one
|
||||
if existing_pending_submission:
|
||||
logger.info(
|
||||
f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate"
|
||||
)
|
||||
await prisma.models.StoreListingVersion.prisma(tx).delete(
|
||||
where={"id": existing_pending_submission.id}
|
||||
)
|
||||
logger.debug(
|
||||
f"Deleted existing pending submission {existing_pending_submission.id}"
|
||||
)
|
||||
|
||||
# Create a new version for the existing listing
|
||||
new_version = await prisma.models.StoreListingVersion.prisma(tx).create(
|
||||
data=prisma.types.StoreListingVersionCreateInput(
|
||||
version=next_version,
|
||||
agentGraphId=agent_id,
|
||||
agentGraphVersion=agent_version,
|
||||
name=name,
|
||||
videoUrl=video_url,
|
||||
agentOutputDemoUrl=agent_output_demo_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
storeListingId=store_listing_id,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Created new version for listing {store_listing_id} of agent {agent_id}"
|
||||
)
|
||||
# Return submission details
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=listing.id,
|
||||
agent_id=agent_id,
|
||||
agent_version=agent_version,
|
||||
name=name,
|
||||
@@ -1466,7 +1535,7 @@ async def review_store_submission(
|
||||
)
|
||||
|
||||
# Update the AgentGraph with store listing data
|
||||
await prisma.models.AgentGraph.prisma().update(
|
||||
await prisma.models.AgentGraph.prisma(tx).update(
|
||||
where={
|
||||
"graphVersionId": {
|
||||
"id": store_listing_version.agentGraphId,
|
||||
@@ -1481,6 +1550,23 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
embedding_success = await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=store_listing_version.name,
|
||||
description=store_listing_version.description,
|
||||
sub_heading=store_listing_version.subHeading,
|
||||
categories=store_listing_version.categories or [],
|
||||
tx=tx,
|
||||
)
|
||||
if not embedding_success:
|
||||
raise ValueError(
|
||||
f"Failed to generate embedding for listing {store_listing_version_id}. "
|
||||
"This is likely due to OpenAI API being unavailable. "
|
||||
"Please try again later or contact support if the issue persists."
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
@@ -1489,24 +1575,6 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (non-blocking)
|
||||
try:
|
||||
from backend.api.features.store.embeddings import ensure_embedding
|
||||
|
||||
await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=store_listing_version.name,
|
||||
description=store_listing_version.description,
|
||||
sub_heading=store_listing_version.subHeading,
|
||||
categories=store_listing_version.categories or [],
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't fail approval if embedding generation fails
|
||||
logger.warning(
|
||||
f"Failed to generate embedding for approved listing "
|
||||
f"{store_listing_version_id}: {e}"
|
||||
)
|
||||
|
||||
# If rejecting an approved agent, update the StoreListing accordingly
|
||||
if is_rejecting_approved:
|
||||
# Check if there are other approved versions
|
||||
@@ -1651,15 +1719,12 @@ async def review_store_submission(
|
||||
|
||||
# Convert to Pydantic model for consistency
|
||||
return store_model.StoreSubmission(
|
||||
listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
|
||||
agent_id=submission.agentGraphId,
|
||||
agent_version=submission.agentGraphVersion,
|
||||
name=submission.name,
|
||||
sub_heading=submission.subHeading,
|
||||
slug=(
|
||||
submission.StoreListing.slug
|
||||
if hasattr(submission, "storeListing") and submission.StoreListing
|
||||
else ""
|
||||
),
|
||||
slug=(submission.StoreListing.slug if submission.StoreListing else ""),
|
||||
description=submission.description,
|
||||
instructions=submission.instructions,
|
||||
image_urls=submission.imageUrls or [],
|
||||
@@ -1761,9 +1826,7 @@ async def get_admin_listings_with_versions(
|
||||
where = prisma.types.StoreListingWhereInput(**where_dict)
|
||||
include = prisma.types.StoreListingInclude(
|
||||
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
|
||||
order_by=prisma.types._StoreListingVersion_version_OrderByInput(
|
||||
version="desc"
|
||||
)
|
||||
order_by={"version": "desc"}
|
||||
),
|
||||
OwningUser=True,
|
||||
)
|
||||
@@ -1788,6 +1851,7 @@ async def get_admin_listings_with_versions(
|
||||
# If we have versions, turn them into StoreSubmission models
|
||||
for version in listing.Versions or []:
|
||||
version_model = store_model.StoreSubmission(
|
||||
listing_id=listing.id,
|
||||
agent_id=version.agentGraphId,
|
||||
agent_version=version.agentGraphVersion,
|
||||
name=version.name,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user